Similar to this question (Scala), but I need the combinations in PySpark (pairwise combinations of an array column).
Example input:
df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])
Expected output:
+------------+------------------------------------------------+
|array_col |out |
+------------+------------------------------------------------+
|[0, 1] |[[0, 1]] |
|[2, 3, 4] |[[2, 3], [2, 4], [3, 4]] |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
Native Spark approach. I have translated this answer to PySpark.
Python 3.8+ (the walrus operator := is used, because "array_col" is repeated several times in this script):
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                c := "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
            )),
            lambda x: F.array(x["0"], x[c])
        ),
        lambda x: x[0] < x[1]
    )
)
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col |out |
# +------------+------------------------------------------------+
# |[0, 1] |[[0, 1]] |
# |[2, 3, 4] |[[2, 3], [2, 4], [3, 4]] |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
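To see how the chain works, here is a sketch that materializes each stage separately (assuming the df defined above; the names step1 through step4 are just for illustration, not part of the original answer):

# step1: for each element x, zip it with every element of the array,
#        giving an array of arrays of (x, y) structs
step1 = F.transform(
    "array_col",
    lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
)
# step2: flatten into one array of (x, y) structs (all ordered pairs)
step2 = F.flatten(step1)
# step3: turn each struct into a plain two-element array [x, y]
step3 = F.transform(step2, lambda s: F.array(s["0"], s["array_col"]))
# step4: keep only x < y, so each unordered pair appears exactly once
step4 = F.filter(step3, lambda p: p[0] < p[1])

df.select("array_col", step4.alias("out")).show(truncate=0)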
Alternative without the walrus operator:
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
            )),
            lambda x: F.array(x["0"], x["array_col"])
        ),
        lambda x: x[0] < x[1]
    )
)
Alternative for Spark 2.4+ (F.transform and F.filter were only added to the Python API in Spark 3.1, so older versions need the SQL expression form):
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.expr("""
        filter(
            transform(
                flatten(transform(
                    array_col,
                    x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
                )),
                x -> array(x["0"], x["array_col"])
            ),
            x -> x[0] < x[1]
        )
    """)
)
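The same expression also works with selectExpr, if you prefer not to use withColumn; a minimal sketch, assuming the df from above:

df.selectExpr(
    "array_col",
    """
    filter(
        transform(
            flatten(transform(
                array_col,
                x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
            )),
            x -> array(x["0"], x["array_col"])
        ),
        x -> x[0] < x[1]
    ) AS out
    """
).show(truncate=0)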
pandas_udf is an efficient and concise approach in PySpark:
from pyspark.sql import functions as F
import pandas as pd
from itertools import combinations

@F.pandas_udf('array<array<int>>')
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))

df = df.withColumn('out', pudf('array_col'))
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col |out |
# +------------+------------------------------------------------+
# |[0, 1] |[[0, 1]] |
# |[2, 3, 4] |[[2, 3], [2, 4], [3, 4]] |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
Note: on some systems, instead of the DDL string 'array<array<int>>' you may need to provide the types from pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())).
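A minimal sketch of that variant, assuming the same imports and UDF body as above:

from pyspark.sql.types import ArrayType, IntegerType

# Same UDF as above, with the return type given as DataType objects
# instead of a DDL string:
@F.pandas_udf(ArrayType(ArrayType(IntegerType())))
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))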