正如主题描述的那样,我有一个PySpark Dataframe,我需要将三列融合成行。每列基本上代表一个类别中的单个事实。最终目标是将数据汇总到每个类别的单个总数中。
这个数据框中有数千万行,所以我需要一种方法在spark集群上进行转换,而不会将任何数据带回驱动程序(在本例中为Jupyter)。
以下是几个商店的数据框摘录:
+-----------+----------------+-----------------+----------------+
| store_id |qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+-----------+----------------+-----------------+----------------+
| 100| 30| 105| 35|
| 200| 55| 85| 65|
| 300| 20| 125| 90|
+-----------+----------------+-----------------+----------------+
下面是所需的结果数据框,每个商店有多行,其中原始数据框的列已融入新数据框的行中,新类别列中每个原始列有一行:
+-----------+--------+-----------+
| product_id|CATEGORY|qty_on_hand|
+-----------+--------+-----------+
| 100| milk| 30|
| 100| bread| 105|
| 100| eggs| 35|
| 200| milk| 55|
| 200| bread| 85|
| 200| eggs| 65|
| 300| milk| 20|
| 300| bread| 125|
| 300| eggs| 90|
+-----------+--------+-----------+
最后,我想汇总结果数据框以获得每个类别的总数:
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
| milk| 105|
| bread| 315|
| eggs| 190|
+--------+-----------------+
更新:有人建议这个问题是重复的,可以回答here。情况并非如此,因为解决方案将行转换为列,我需要执行相反的操作,将列熔化为行。
我们可以使用explode()函数来解决这个问题。在Python中,使用melt
可以完成同样的事情。
# Loading the requisite packages
from pyspark.sql.functions import col, explode, array, struct, expr, sum
# Creating the DataFrame
df = sqlContext.createDataFrame([(100,30,105,35),(200,55,85,65),(300,20,125,90)],('store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'))
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
| 100| 30| 105| 35|
| 200| 55| 85| 65|
| 300| 20| 125| 90|
+--------+----------------+-----------------+----------------+
编写下面的函数,这将是explode
这个DataFrame -
def to_explode(df, by):
# Filter dtypes and split into column names and type description
cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
# Spark SQL supports only homogeneous columns
assert len(set(dtypes)) == 1, "All columns have to be of the same type"
# Create and explode an array of (column_name, column_value) structs
kvs = explode(array([
struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
])).alias("kvs")
return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])
将此DataFrame上的函数应用于explode
-
df = to_explode(df, ['store_id'])\
.drop('store_id')
df.show()
+-----------------+-----------+
| CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk| 30|
|qty_on_hand_bread| 105|
| qty_on_hand_eggs| 35|
| qty_on_hand_milk| 55|
|qty_on_hand_bread| 85|
| qty_on_hand_eggs| 65|
| qty_on_hand_milk| 20|
|qty_on_hand_bread| 125|
| qty_on_hand_eggs| 90|
+-----------------+-----------+
现在,我们需要从qty_on_hand_
列中删除字符串CATEGORY
。它可以使用expr()函数完成。注意expr
遵循基于1的子串索引,而不是0 -
df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
| milk| 30|
| bread| 105|
| eggs| 35|
| milk| 55|
| bread| 85|
| eggs| 65|
| milk| 20|
| bread| 125|
| eggs| 90|
+--------+-----------+
最后,使用qty_on_hand
函数汇总由CATEGORY
分组的列agg() -
df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
| eggs| 190|
| bread| 315|
| milk| 105|
+--------+-----------------+
使用pyspark的col,when, functions
模块可能的方法
>>> from pyspark.sql import functions as F
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import StringType
>>> concat_udf = F.udf(lambda cols: "".join([str(x) if x is not None else "*" for x in cols]), StringType())
>>> rdd = sc.parallelize([[100,30,105,35],[200,55,85,65],[300,20,125,90]])
>>> df = rdd.toDF(['store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'])
>>> df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
| 100| 30| 105| 35|
| 200| 55| 85| 65|
| 300| 20| 125| 90|
+--------+----------------+-----------------+----------------+
#adding one more column with arrayed values of all three columns
>>> df_1=df.withColumn("new_col", concat_udf(F.array("qty_on_hand_milk", "qty_on_hand_bread","qty_on_hand_eggs")))
#convert it into array<int> for carrying out agg operations
>>> df_2=df_1.withColumn("new_col_1",split(col("new_col"), ",\s*").cast("array<int>").alias("new_col_1"))
#posexplode gives you the position along with usual explode which helps in categorizing
>>> df_3=df_2.select("store_id", posexplode("new_col_1").alias("col_1","qty"))
#if else conditioning for category column
>>> df_3.withColumn("category",F.when(col("col_1") == 0, "milk").when(col("col_1") == 1, "bread").otherwise("eggs")).select("store_id","category","qty").show()
+--------+--------+---+
|store_id|category|qty|
+--------+--------+---+
| 100| milk| 30|
| 100| bread|105|
| 100| eggs| 35|
| 200| milk| 55|
| 200| bread| 85|
| 200| eggs| 65|
| 300| milk| 20|
| 300| bread|125|
| 300| eggs| 90|
+--------+--------+---+
#aggregating to find sum
>>> df_3.withColumn("category",F.when(col("col_1") == 0, "milk").when(col("col_1") == 1, "bread").otherwise("eggs")).select("category","qty").groupBy('category').sum().show()
+--------+--------+
|category|sum(qty)|
+--------+--------+
| eggs| 190|
| bread| 315|
| milk| 105|
+--------+--------+
>>> df_3.printSchema()
root
|-- store_id: long (nullable = true)
|-- col_1: integer (nullable = false)
|-- qty: integer (nullable = true)