如何从结构体数组中选择列?

问题描述 投票:0回答:1
root
 |-- InvoiceNo: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- collect_list(items): array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- StockCode: string (nullable = true)
 |    |    |-- Description: string (nullable = true)
 |    |    |-- UnitPrice: double (nullable = true)
 |    |    |-- Country: string (nullable = true)

这是我的架构,我尝试创建新列totalPrice。

.withColumn('TotalPrice',col('Quantity') * col('UnitPrice'))\

就像这样,但我无法从数组结构中获取 UnitPrice ..该怎么做?

python apache-spark pyspark apache-spark-sql
1个回答
0
投票

要在 PySpark DataFrame 中创建新列 TotalPrice,您需要从嵌套数组结构中提取 UnitPrice 和 Quantity 并执行必要的计算。由于您的数据包含每张发票的项目数组,因此您需要使用 PySpark 函数来处理数组列并聚合结果。

以下是在 PySpark 中实现此目标的方法:

使用爆炸将项目数组转换为单独的行。 计算每件商品的总价格。 如果需要,将结果聚合回原始结构。 以下是如何执行这些步骤的示例:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, sum as _sum
# Initialize Spark session
spark = SparkSession.builder.appName("CalculateTotalPrice").getOrCreate()

# Sample data
data = [
   ("001", 2, 123, "2023-05-20", [{"StockCode": "A001", "Description": 
   "Item A", "UnitPrice": 10.0, "Country": "UK"}, {"StockCode": "A002", 
  "Description": "Item B", "UnitPrice": 20.0, "Country": "UK"}]),
   ("002", 1, 124, "2023-05-21", [{"StockCode": "A003", "Description": 
   "Item C", "UnitPrice": 15.0, "Country": "USA"}])
   ]

  # Define schema
schema = """
      InvoiceNo STRING,
      Quantity INT,
      CustomerID INT,
      InvoiceDate STRING,
      collect_list_items ARRAY<STRUCT<StockCode: STRING, Description: 
      STRING, UnitPrice: DOUBLE, Country: STRING>>
     """

     # Create DataFrame
df = spark.createDataFrame(data, schema)

# Explode the array to individual rows
exploded_df = df.withColumn("item", explode("collect_list_items"))

# Calculate the totalPrice for each item
exploded_df = exploded_df.withColumn("totalPrice", col("item.UnitPrice") 
* col("Quantity"))

# If you want to aggregate the totalPrice back to the original structure
total_price_df = exploded_df.groupBy("InvoiceNo", "Quantity", 
"CustomerID", "InvoiceDate").agg(_sum("totalPrice").alias("totalPrice"))

# Show the result
total_price_df.show(truncate=False)

在此示例中:

我们定义 DataFrame 的架构。 使用示例数据创建 DataFrame。 使用explode 函数将数组collect_list(items) 展开为单独的行。 通过将单价乘以数量来计算每个商品的总价格。 汇总结果以获得每张发票的总价格。 这将为您提供一个新的 DataFrame,其中包含每张发票的计算总价格。

© www.soinside.com 2019 - 2024. All rights reserved.