Pyspark - Create a JSON column whose keys come from another CSV column


The input dataset looks like this:

 |    id     | fields     | f1       | f2         | f3       | f4         |
 | --------  | --------   | -------- | --------   | -------- | --------   |
 |     1     | f1, f2, f3 | 3        |  2         |     0    |     2      |
 |     2     | f2, f4     | 2        |  4         |     2    |     5      |
 |     3     | f1         | 7        |  6         |     4    |     6      |

The expected output is:

 |    id     | fields     | json_field                      |
 | --------  | --------   | ------------------------------- |    
 |     1     | f1, f2, f3 | {"f1": 3, "f2": 2, "f3": 0}     |     
 |     2     | f2, f4     | {"f2": 4, "f4": 5}              |     
 |     3     | f1         | {"f1": 7}                       |

I have tried:

    input_df.select(
        col("id"),
        col("fields"),
        to_json(struct(split(col("fields"), ","))).alias("json_field")
    )

But it does not work as intended: `split` returns an array, so `struct(split(col("fields"), ","))` wraps that single array inside a struct, and `to_json` serializes it as one array-valued entry rather than one key per field name.

python apache-spark pyspark apache-spark-sql
1 Answer

You can use the `collect_list`, `map_from_entries`, and `struct` Spark SQL functions to build the map dynamically from the `fields` column:

data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])

input_df.createOrReplaceTempView("input_view")
# INPUT: input_view
# +---+----------+---+---+---+---+
# |id |fields    |f1 |f2 |f3 |f4 |
# +---+----------+---+---+---+---+
# |1  |f1, f2, f3|3  |2  |0  |2  |
# |2  |f2, f4    |2  |4  |2  |5  |
# |3  |f1        |7  |6  |4  |6  |
# +---+----------+---+---+---+---+

output_df = spark.sql(
    """
    WITH exploded_view AS (
      SELECT id, explode(split(fields, ', ')) AS field, f1, f2, f3, f4
      FROM input_view
    )
    SELECT
      id,
      collect_list(field) AS fields,
      map_from_entries(collect_list(struct(field,
        CASE field WHEN 'f1' THEN f1 WHEN 'f2' THEN f2
                   WHEN 'f3' THEN f3 WHEN 'f4' THEN f4 END))) AS json_field
    FROM exploded_view
    GROUP BY id, f1, f2, f3, f4
    ORDER BY id
    """)

output_df.show(truncate=False)
# OUTPUT:
# +---+------------+---------------------------+
# |id |fields      |json_field                 |
# +---+------------+---------------------------+
# |1  |[f1, f2, f3]|{f1 -> 3, f2 -> 2, f3 -> 0}|
# |2  |[f2, f4]    |{f2 -> 4, f4 -> 5}         |
# |3  |[f1]        |{f1 -> 7}                  |
# +---+------------+---------------------------+
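
Note that `json_field` here is a MapType column (which `show()` renders as `{f1 -> 3, ...}`), not the JSON string shown in the expected output. If the literal JSON text is needed, wrapping the map in `to_json` should do it; a minimal sketch (the `json_df` name is just illustrative):

# Serialize the map column from the query above into a JSON string
# such as {"f1":3,"f2":2,"f3":0}, matching the expected output.
json_df = output_df.selectExpr("id", "fields", "to_json(json_field) AS json_field")
json_df.show(truncate=False)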

In the DataFrame API, the equivalent code could be:

from pyspark.sql.functions import (collect_list, explode,
                                   map_from_entries, split, struct, when)

# Input data
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])

input_df.show(input_df.count(), False)

# Explode the 'fields' column into multiple rows
exploded_df = input_df.select("id", explode(split(input_df.fields, ", ")).alias("field"), "f1", "f2", "f3", "f4")

# Create the 'json_field' column using the 'field' column and the 'f1', 'f2', 'f3', 'f4' columns with map_from_entries
output_df = (exploded_df.groupBy("id", "f1", "f2", "f3", "f4")
             .agg(collect_list("field").alias("fields"),
                  map_from_entries(collect_list(struct("field",
                                                       when(exploded_df.field == "f1", exploded_df.f1)
                                                       .when(exploded_df.field == "f2", exploded_df.f2)
                                                       .when(exploded_df.field == "f3", exploded_df.f3)
                                                       .when(exploded_df.field == "f4", exploded_df.f4)
                                                       )
                                                )
                                   ).alias("json_field")
                  )
             .orderBy("id")
             .select("id", "fields", "json_field"))

output_df.show(output_df.count(), truncate=False)
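
The same caveat applies here: to turn the MapType `json_field` into a JSON string, one extra `to_json` call should suffice (a sketch reusing `output_df` from the snippet above; `json_df` is an illustrative name):

from pyspark.sql.functions import to_json

# Convert the MapType column to its JSON string representation,
# e.g. {"f1":3,"f2":2,"f3":0}.
json_df = output_df.withColumn("json_field", to_json("json_field"))
json_df.show(truncate=False)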