Pairwise combinations of array-column values in PySpark

Problem description (votes: 0, answers: 2)

Similar to this question (Scala), but I need the combinations (pairwise combinations of an array column) in PySpark.

Example input:

df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])

Expected output:

+------------+------------------------------------------------+
|array_col   |out                                             |
+------------+------------------------------------------------+
|[0, 1]      |[[0, 1]]                                        |
|[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
python arrays apache-spark pyspark combinations
2 Answers
6 votes

A native Spark approach; I have translated this answer to PySpark. The idea: pair every element with every element of the same array (a self cross product), then keep only the pairs whose first value is smaller than the second, so each combination appears exactly once.

Python 3.8+ (the walrus operator := is used because "array_col" is repeated several times in this script):

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            # Pair every element x with every element of the array, then
            # flatten into one list of (x, y) structs.
            F.flatten(F.transform(
                c:="array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
            )),
            # Turn each struct into a two-element array [x, y].
            lambda x: F.array(x["0"], x[c])
        ),
        # Keep each unordered pair only once: pairs with x < y.
        lambda x: x[0] < x[1]
    )
)
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
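To see what the final filter works on, here is a minimal sketch (using the same df as above; the ordered_pairs column name is just for illustration) that stops one step earlier and shows every ordered pair before the x[0] < x[1] condition is applied:

from pyspark.sql import functions as F

# Same pipeline as above, minus the final filter: for the row [0, 1] this
# yields [[0, 0], [0, 1], [1, 0], [1, 1]]; the filter then keeps only [0, 1].
df.withColumn(
    "ordered_pairs",
    F.transform(
        F.flatten(F.transform(
            "array_col",
            lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
        )),
        lambda x: F.array(x["0"], x["array_col"])
    )
).show(truncate=0)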

An alternative without the walrus operator:

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
            )),
            lambda x: F.array(x["0"], x["array_col"])
        ),
        lambda x: x[0] < x[1]
    )
)

An alternative for Spark 2.4+ (the Python functions F.transform and F.filter used above require Spark 3.1+, but the equivalent SQL higher-order functions are available from 2.4 via F.expr):

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.expr("""
        filter(
            transform(
                flatten(transform(
                    array_col,
                    x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
                )),
                x -> array(x["0"], x["array_col"])
            ),
            x -> x[0] < x[1]
        )
    """)
)

2 votes

A pandas_udf is an efficient and concise option in PySpark:

from pyspark.sql import functions as F
import pandas as pd
from itertools import combinations

@F.pandas_udf('array<array<int>>')
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))


df = df.withColumn('out', pudf('array_col'))
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+

Note: on some systems, instead of the DDL string 'array<array<int>>' you may need to provide the type from pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())).
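For example, a minimal sketch of the same UDF with the return type given as a DataType object instead of the DDL string (assuming the rest of the setup is unchanged):

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType
import pandas as pd
from itertools import combinations

# Same UDF as above, but the return type is built from pyspark.sql.types
# instead of the DDL string 'array<array<int>>'.
@F.pandas_udf(ArrayType(ArrayType(IntegerType())))
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))

df = df.withColumn('out', pudf('array_col'))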
