Merge multiple rows into a single record per key with Scala Spark, without duplicating records

Question · Votes: 0 · Answers: 2

I have the following DataFrame; its primary key is (code_id, date). No duplicates should be produced for that key, yet there are multiple rows per key:
+-------+----------+-------+-------+-------+-----------+----------+
|code_id|date      |import1|import2|import3|postal_code|country   |
+-------+----------+-------+-------+-------+-----------+----------+
|9900   |2024-02-21|50.0   |null   |null   |28032      |california|
|9900   |2024-02-21|null   |null   |7.0    |28032      |california|
|9900   |2024-02-21|null   |100.0  |null   |28032      |california|
|5000   |2024-02-21|230.0  |null   |null   |28032      |california|
|5000   |2024-02-21|null   |46.4   |null   |28032      |california|
|5000   |2024-02-21|null   |null   |180.2  |28032      |california|
+-------+----------+-------+-------+-------+-----------+----------+

I want to group it by ("code_id", "date"), and the result I expect is:

+-------+----------+-------+-------+-------+-----------+----------+
|code_id|date      |import1|import2|import3|postal_code|country   |
+-------+----------+-------+-------+-------+-----------+----------+
|9900   |2024-02-21|50.0   |100.0  |7.0    |28032      |california|
|5000   |2024-02-21|230.0  |46.4   |180.2  |28032      |california|
+-------+----------+-------+-------+-------+-----------+----------+

Is this possible? I have tried a groupBy, but I am not happy with its result:

import org.apache.spark.sql.functions.{collect_set, struct}

data.groupBy("code_id", "date").agg(collect_set(struct("import1", "import2", "import3", "postal_code", "country")))

Thanks!

scala apache-spark
2 Answers
0 votes

A similar solution can be found at the link below:

Get the first non-null value in a group
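
A minimal sketch of what that approach looks like on this question's DataFrame, assuming first with ignoreNulls = true inside the groupBy (column names taken from the question):

import org.apache.spark.sql.functions.first

// keep the first non-null value of every non-key column per (code_id, date) group
data.groupBy("code_id", "date").agg(
  first("import1", ignoreNulls = true).as("import1"),
  first("import2", ignoreNulls = true).as("import2"),
  first("import3", ignoreNulls = true).as("import3"),
  first("postal_code", ignoreNulls = true).as("postal_code"),
  first("country", ignoreNulls = true).as("country")
).show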

Hope this helps!


0 votes

A simple solution is to use groupBy, aggregating each of the "import?" columns with collect_list (then picking the first element of each list via element_at), and taking first for the remaining columns:

// assumes spark.implicits._ is in scope (as in spark-shell) for toDF
import org.apache.spark.sql.functions.{col, collect_list, element_at, first}

val data = Seq(
  (9900, "2024-02-21", Some(50.0), None, None, "28032", "california"),
  (9900, "2024-02-21", None, None, Some(7.0), "28032", "california"),
  (9900, "2024-02-21", None, Some(100.0), None, "28032", "california"),
  (5000, "2024-02-21", Some(230.0), None, None, "28032", "california"),
  (5000, "2024-02-21", None, Some(46.4), None, "28032", "california"),
  (5000, "2024-02-21", None, None, Some(180.2), "28032", "california") 
).toDF("code_id", "date", "import1", "import2", "import3", "postal_code", "country")

val keys = Seq("code_id", "date")
val imports = data.columns.filter(_.matches("import\\d+"))
val rest = data.columns diff keys diff imports

// collect_list skips nulls, so element_at(_, 1) yields the single non-null value per group
val aggCols = imports.map(c => element_at(collect_list(col(c)), 1).as(c))
val restCols = rest.map(c => first(col(c)).as(c))

data.
  groupBy(keys.map(col): _*).agg(aggCols.head, aggCols.tail ++ restCols: _*).
  show

// +-------+----------+-------+-------+-------+-----------+----------+
// |code_id|      date|import1|import2|import3|postal_code|   country|
// +-------+----------+-------+-------+-------+-----------+----------+
// |   9900|2024-02-21|   50.0|  100.0|    7.0|      28032|california|
// |   5000|2024-02-21|  230.0|   46.4|  180.2|      28032|california|
// +-------+----------+-------+-------+-------+-----------+----------+

Alternatively, using first (with ignoreNulls = true) over a Window partitioned by the key also works, although given the desired output it still needs an explicit rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) frame.

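For completeness, a minimal sketch of that Window variant, reusing the data DataFrame from above; the final dropDuplicates is an assumed step to collapse each group to a single row:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, first}

// frame over the whole partition so every row sees all values for its key
val w = Window.
  partitionBy("code_id", "date").
  rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val merged = data.columns.
  filter(_.matches("import\\d+")).
  foldLeft(data) { (df, c) =>
    // first(_, ignoreNulls = true) takes the first non-null value in the frame
    df.withColumn(c, first(col(c), ignoreNulls = true).over(w))
  }.
  dropDuplicates("code_id", "date") // assumed: collapse each key to a single row

merged.show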