PySpark Dataframe将列熔化为行

问题描述 投票:2回答:2

正如主题描述的那样,我有一个PySpark Dataframe,我需要将三列融合成行。每列基本上代表一个类别中的单个事实。最终目标是将数据汇总到每个类别的单个总数中。

这个数据框中有数千万行,所以我需要一种方法在spark集群上进行转换,而不会将任何数据带回驱动程序(在本例中为Jupyter)。

以下是几个商店的数据框摘录: +-----------+----------------+-----------------+----------------+ | store_id |qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs| +-----------+----------------+-----------------+----------------+ | 100| 30| 105| 35| | 200| 55| 85| 65| | 300| 20| 125| 90| +-----------+----------------+-----------------+----------------+

下面是所需的结果数据框,每个商店有多行,其中原始数据框的列已融入新数据框的行中,新类别列中每个原始列有一行: +-----------+--------+-----------+ | product_id|CATEGORY|qty_on_hand| +-----------+--------+-----------+ | 100| milk| 30| | 100| bread| 105| | 100| eggs| 35| | 200| milk| 55| | 200| bread| 85| | 200| eggs| 65| | 300| milk| 20| | 300| bread| 125| | 300| eggs| 90| +-----------+--------+-----------+

最后,我想汇总结果数据框以获得每个类别的总数: +--------+-----------------+ |CATEGORY|total_qty_on_hand| +--------+-----------------+ | milk| 105| | bread| 315| | eggs| 190| +--------+-----------------+

更新:有人建议这个问题是重复的,可以回答here。情况并非如此,因为解决方案将行转换为列,我需要执行相反的操作,将列熔化为行。

python dataframe pyspark aggregate melt
2个回答
0
投票

我们可以使用explode()函数来解决这个问题。在Python中,使用melt可以完成同样的事情。

# Loading the requisite packages
from pyspark.sql.functions import col, explode, array, struct, expr, sum
# Creating the DataFrame
df = sqlContext.createDataFrame([(100,30,105,35),(200,55,85,65),(300,20,125,90)],('store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'))
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+

编写下面的函数,这将是explode这个DataFrame -

def to_explode(df, by):

    # Filter dtypes and split into column names and type description
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # Spark SQL supports only homogeneous columns
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"

    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
      struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
    ])).alias("kvs")

    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])

将此DataFrame上的函数应用于explode-

df = to_explode(df, ['store_id'])\
     .drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+

现在,我们需要从qty_on_hand_列中删除字符串CATEGORY。它可以使用expr()函数完成。注意expr遵循基于1的子串索引,而不是0 -

df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+

最后,使用qty_on_hand函数汇总由CATEGORY分组的列agg() -

df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+

0
投票

使用pyspark的col,when, functions模块可能的方法

>>> from pyspark.sql import functions as F
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import StringType
>>> concat_udf = F.udf(lambda cols: "".join([str(x) if x is not None else "*" for x in cols]), StringType())

>>> rdd = sc.parallelize([[100,30,105,35],[200,55,85,65],[300,20,125,90]])
>>> df = rdd.toDF(['store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'])

>>> df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+

#adding one more column with arrayed values of all three columns
>>> df_1=df.withColumn("new_col", concat_udf(F.array("qty_on_hand_milk", "qty_on_hand_bread","qty_on_hand_eggs")))
#convert it into array<int> for carrying out agg operations
>>> df_2=df_1.withColumn("new_col_1",split(col("new_col"), ",\s*").cast("array<int>").alias("new_col_1"))
#posexplode gives you the position along with usual explode which helps in categorizing
>>> df_3=df_2.select("store_id",  posexplode("new_col_1").alias("col_1","qty"))
#if else conditioning for category column
>>> df_3.withColumn("category",F.when(col("col_1") == 0, "milk").when(col("col_1") == 1, "bread").otherwise("eggs")).select("store_id","category","qty").show()
+--------+--------+---+
|store_id|category|qty|
+--------+--------+---+
|     100|    milk| 30|
|     100|   bread|105|
|     100|    eggs| 35|
|     200|    milk| 55|
|     200|   bread| 85|
|     200|    eggs| 65|
|     300|    milk| 20|
|     300|   bread|125|
|     300|    eggs| 90|
+--------+--------+---+

#aggregating to find sum
>>> df_3.withColumn("category",F.when(col("col_1") == 0, "milk").when(col("col_1") == 1, "bread").otherwise("eggs")).select("category","qty").groupBy('category').sum().show()
+--------+--------+
|category|sum(qty)|
+--------+--------+
|    eggs|     190|
|   bread|     315|
|    milk|     105|
+--------+--------+
>>> df_3.printSchema()
root
 |-- store_id: long (nullable = true)
 |-- col_1: integer (nullable = false)
 |-- qty: integer (nullable = true)
© www.soinside.com 2019 - 2024. All rights reserved.