我有一个包含两列的数据框:“ID”和“Amount”,每行代表一个特定ID的交易和交易金额。我的示例使用以下DF:
val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),
(2, 50),(1, 30),(2, 120))).toDF("ID","Amount")
我想创建一个新列,用于标识所述金额是否为重复值,即是否出现在同一ID的任何其他事务中。
我已经找到了一种更普遍的方法,即在整个列“Amount”中,不考虑ID,使用以下函数:
def recurring_amounts(df: DataFrame, col: String) : DataFrame = {
var df_to_arr = df.select(col).rdd.map(r => r(0).asInstanceOf[Double]).collect()
var arr_to_map = df_to_arr.groupBy(identity).mapValues(_.size)
var map_to_df = arr_to_map.toSeq.toDF(col, "Count")
var df_reformat = map_to_df.withColumn("Amount", $"Amount".cast(DoubleType))
var df_out = df.join(df_reformat, Seq("Amount"))
return df_new
}
val df_output = recurring_amounts(df, "Amount")
返回:
+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120 | 3 |
| 1 | 120 | 3 |
| 2 | 40 | 1 |
| 2 | 50 | 1 |
| 1 | 30 | 1 |
| 2 | 120 | 3 |
+---+------+-----+
然后,我可以使用它来创建我想要的二进制变量,以指示金额是否重复出现(是,如果> 1,否则没有)。
但是,我的问题在这个例子中由值120说明,它是ID 1但不是ID 2的重复。因此,我想要的输出是:
+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120 | 2 |
| 1 | 120 | 2 |
| 2 | 40 | 1 |
| 2 | 50 | 1 |
| 1 | 30 | 1 |
| 2 | 120 | 1 |
+---+------+-----+
我一直试图想办法使用.over(Window.partitionBy("ID")
应用函数,但不知道如何去做。任何提示都将非常感激。
如果你擅长sql,你可以为你的Dataframe
编写sql查询。你需要做的第一件事就是将你的Dataframe
as注册到spark的内存中。之后,您可以在表的顶部编写sql。请注意,spark
是spark会话变量。
val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),(2, 50),(1, 30),(2, 120))).toDF("ID","Amount")
df.registerTempTable("transactions")
spark.sql("select *,count(*) over(partition by ID,Amount) as Count from transactions").show()
请让我知道,如果你有任何问题。