输入数据集如下所示
| id | fields | f1 | f2 | f3 | f4 |
| -------- | -------- | -------- | -------- | -------- | -------- |
| 1 | f1, f2, f3 | 3 | 2 | 0 | 2 |
| 2 | f2, f4 | 2 | 4 | 2 | 5 |
| 3 | f1 | 7 | 6 | 4 | 6 |
预期输出是
| id | fields | json_field |
| -------- | -------- | ------------------------------- |
| 1 | f1, f2, f3 | {"f1": 3, "f2": 2, "f3": 0} |
| 2 | f2, f4 | {"f2": 4, "f4": 5} |
| 3 | f1 | {"f1": 7} |
我已经尝试过了
input_df.selet(
col("id"),
col("fields"),
to_json(struct(split(col("fields"), ","))).alias("json_field")
)
但它无法正常工作。
您可以使用 Spark SQL 函数 collect_list、map_from_entries 和 struct,基于 fields 列动态地创建一个 map(注意:这样得到的是 map 类型的列;如果最终需要的是问题中那样的 JSON 字符串,可以再对该 map 列应用 to_json):
# Input data
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])
input_df.createOrReplaceTempView("input_view")
# INPUT: input_view
# +---+----------+---+---+---+---+
# |id |fields    |f1 |f2 |f3 |f4 |
# +---+----------+---+---+---+---+
# |1  |f1, f2, f3|3  |2  |0  |2  |
# |2  |f2, f4    |2  |4  |2  |5  |
# |3  |f1        |7  |6  |4  |6  |
# +---+----------+---+---+---+---+
# Explode `fields` into one row per field name, then regroup per id and
# rebuild a field -> value map with map_from_entries.
# NOTE: `.stripMargin` and `|` margins are Scala, not Python -- a plain
# triple-quoted string is used for the SQL instead.
output_df = spark.sql("""
    WITH exploded_view AS (
        SELECT id, explode(split(fields, ', ')) AS field, f1, f2, f3, f4
        FROM input_view
    )
    SELECT
        id,
        collect_list(field) AS fields,
        map_from_entries(collect_list(struct(field,
            CASE field WHEN 'f1' THEN f1 WHEN 'f2' THEN f2
                       WHEN 'f3' THEN f3 WHEN 'f4' THEN f4 END))) AS json_field
    FROM exploded_view
    GROUP BY id, f1, f2, f3, f4
    ORDER BY id
""")
# `show(false)` is Scala; the Python API takes truncate=False.
output_df.show(truncate=False)
# OUTPUT:
# +---+------------+---------------------------+
# |id |fields      |json_field                 |
# +---+------------+---------------------------+
# |1  |[f1, f2, f3]|{f1 -> 3, f2 -> 2, f3 -> 0}|
# |2  |[f2, f4]    |{f2 -> 4, f4 -> 5}         |
# |3  |[f1]        |{f1 -> 7}                  |
# +---+------------+---------------------------+
在DataFrame API中,相同的代码可以是:
# Input data
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])
input_df.show(input_df.count(), False)

# Step 1: one row per (id, field-name) pair.
exploded_df = input_df.select(
    "id",
    explode(split(input_df.fields, ", ")).alias("field"),
    "f1", "f2", "f3", "f4",
)

# Step 2: expression that picks the value column matching the exploded name.
value_expr = when(exploded_df["field"] == "f1", exploded_df["f1"])
for name in ("f2", "f3", "f4"):
    value_expr = value_expr.when(exploded_df["field"] == name, exploded_df[name])

# Step 3: regroup per id and assemble the field -> value map.
output_df = (
    exploded_df
    .groupBy("id", "f1", "f2", "f3", "f4")
    .agg(
        collect_list("field").alias("fields"),
        map_from_entries(collect_list(struct("field", value_expr))).alias("json_field"),
    )
    .orderBy("id")
    .select("id", "fields", "json_field")
)
output_df.show(output_df.count(), truncate=False)