我有这样的Dataframe(在Pyspark 2.3.1中):
from pyspark.sql import Row
my_data = spark.createDataFrame([
Row(a=[9, 3, 4], b=['a', 'b', 'c'], mask=[True, False, False]),
Row(a=[7, 2, 6, 4], b=['w', 'x', 'y', 'z'], mask=[True, False, True, False])
])
my_data.show(truncate=False)
#+------------+------------+--------------------------+
#|a |b |mask |
#+------------+------------+--------------------------+
#|[9, 3, 4] |[a, b, c] |[true, false, false] |
#|[7, 2, 6, 4]|[w, x, y, z]|[true, false, true, false]|
#+------------+------------+--------------------------+
现在我想使用mask
列来对a
和b
列进行子集化:
my_desired_output = spark.createDataFrame([
Row(a=[9], b=['a']),
Row(a=[7, 6], b=['w', 'y'])
])
my_desired_output.show(truncate=False)
#+------+------+
#|a |b |
#+------+------+
#|[9] |[a] |
#|[7, 6]|[w, y]|
#+------+------+
实现这一目标的“惯用”方法是什么?我当前的解决方案涉及到底层RDD的map
-ing和Numpy的子集化,这似乎是不优雅的:
import numpy as np
def subset_with_mask(row):
mask = np.asarray(row.mask)
a_masked = np.asarray(row.a)[mask].tolist()
b_masked = np.asarray(row.b)[mask].tolist()
return Row(a=a_masked, b=b_masked)
my_desired_output = spark.createDataFrame(my_data.rdd.map(subset_with_mask))
这是最好的方法,还是有更好的东西(更简洁和/或更高效)我可以使用Spark SQL工具做什么?
一种选择是使用UDF,您可以选择通过数组中的数据类型进行专门化:
import numpy as np
import pyspark.sql.functions as F
import pyspark.sql.types as T
def _mask_list(lst, mask):
return np.asarray(lst)[mask].tolist()
mask_array_int = F.udf(_mask_list, T.ArrayType(T.IntegerType()))
mask_array_str = F.udf(_mask_list, T.ArrayType(T.StringType()))
my_desired_output = my_data
my_desired_output = my_desired_output.withColumn(
'a', mask_array_int(F.col('a'), F.col('mask'))
)
my_desired_output = my_desired_output.withColumn(
'b', mask_array_str(F.col('b'), F.col('mask'))
)
上一个答案中提到的UDF可能是在Spark 2.4中添加的数组函数之前的方法。为了完整起见,这里是2.4之前的“纯SQL”实现。
from pyspark.sql.functions import *
df = my_data.withColumn("row", monotonically_increasing_id())
df1 = df.select("row", posexplode("a").alias("pos", "a"))
df2 = df.select("row", posexplode("b").alias("pos", "b"))
df3 = df.select("row", posexplode("mask").alias("pos", "mask"))
df1\
.join(df2, ["row", "pos"])\
.join(df3, ["row", "pos"])\
.filter("mask")\
.groupBy("row")\
.agg(collect_list("a").alias("a"), collect_list("b").alias("b"))\
.select("a", "b")\
.show()
输出:
+------+------+
| a| b|
+------+------+
|[7, 6]|[w, y]|
| [9]| [a]|
+------+------+
这是使用2个UDF进行压缩和解压缩列表的另一种方法:
from pyspark.sql.types import ArrayType, StructType, StructField, StringType
from pyspark.sql.functions import udf, col, lit
zip_schema = ArrayType(StructType((StructField("a", StringType()), StructField("b", StringType()))))
unzip_schema = ArrayType(StringType())
zip_udf = udf(my_zip, zip_schema)
unzip_udf = udf(my_unzip, unzip_schema)
df = my_data.withColumn("zipped", zip_udf(col("a"), col("b"), col("mask")))
.withColumn("a", unzip_udf(col("zipped"), lit(0)))
.withColumn("b", unzip_udf(col("zipped"), lit(1)))
.drop("zipped", "mask")
def my_unzip(zipped, indx):
return [str(x[indx]) for x in zipped]
def my_zip(a, b, mask):
return [(x[0], x[1]) for x in zip(a,b,mask) if x[2]]
my_zip负责根据掩码过滤数据并创建(cola,colb)元组,该元组也是返回列表的项目。
my_unzip将从my_zip创建的数据中提取特定indx的数据。
输出:
+------+------+
| a| b|
+------+------+
| [9]| [a]|
|[7, 6]|[w, y]|
+------+------+