如何在pyspark中对数组中的标签进行编码

问题描述 投票:2回答:2

例如,我在name中具有带分类特征的DataFrame:

 from pyspark.sql import SparkSession

 spark = SparkSession.builder.master("local").appName("example")
 .config("spark.some.config.option", "some-value").getOrCreate()

 features = [(['a', 'b', 'c'], 1),
             (['a', 'c'], 2),
             (['d'], 3),
             (['b', 'c'], 4), 
             (['a', 'b', 'd'], 5)]

 df = spark.createDataFrame(features, ['name','id'])
 df.show()

日期:

+---------+----+
|     name| id |
+---------+----+
|[a, b, c]|   1|
|   [a, c]|   2|
|      [d]|   3|
|   [b, c]|   4|
|[a, b, d]|   5|
+---------+----+

我想得到什么:

+--------+--------+--------+--------+----+
| name_a | name_b | name_c | name_d | id |
+--------+--------+--------+--------+----+
| 1      | 1      | 1      | 0      | 1  |
+--------+--------+--------+--------+----+
| 1      | 0      | 1      | 0      | 2  |
+--------+--------+--------+--------+----+
| 0      | 0      | 0      | 1      | 3  |
+--------+--------+--------+--------+----+
| 0      | 1      | 1      | 0      | 4  |
+--------+--------+--------+--------+----+
| 1      | 1      | 0      | 1      | 5  |
+--------+--------+--------+--------+----+

我找到了same queston,但没有任何帮助。我试图使用来自VectorIndexerPySpark.ML,但我遇到了一些问题,将name字段转换为vector type

 from pyspark.ml.feature import VectorIndexer

 indexer = VectorIndexer(inputCol="name", outputCol="indexed", maxCategories=5)
 indexerModel = indexer.fit(df)

我收到以下错误:

Column name must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType

我找到了一个解决方案here但它看起来过于复杂。但是,我不确定它是否只能用VectorIndexer完成。

python apache-spark pyspark pyspark-sql
2个回答
2
投票

如果你想使用Spark ML的输出,最好使用CountVectorizer

from pyspark.ml.feature import CountVectorizer

# Add binary=True if needed
df_enc = (CountVectorizer(inputCol="name", outputCol="name_vector")
    .fit(df)
    .transform(df))
df_enc.show(truncate=False)
+---------+---+-------------------------+
|name     |id |name_vector              |
+---------+---+-------------------------+
|[a, b, c]|1  |(4,[0,1,2],[1.0,1.0,1.0])|
|[a, c]   |2  |(4,[0,1],[1.0,1.0])      |
|[d]      |3  |(4,[3],[1.0])            |
|[b, c]   |4  |(4,[1,2],[1.0,1.0])      |
|[a, b, d]|5  |(4,[0,2,3],[1.0,1.0,1.0])|
+---------+---+-------------------------+

否则收集不同的值:

from pyspark.sql.functions import array_contains, col, explode

names = [
    x[0] for x in 
    df.select(explode("name").alias("name")).distinct().orderBy("name").collect()]

并使用array_contains选择列:

df_sep = df.select("*", *[
    array_contains("name", name).alias("name_{}".format(name)).cast("integer") 
    for name in names]
)
df_sep.show()
+---------+---+------+------+------+------+
|     name| id|name_a|name_b|name_c|name_d|
+---------+---+------+------+------+------+
|[a, b, c]|  1|     1|     1|     1|     0|
|   [a, c]|  2|     1|     0|     1|     0|
|      [d]|  3|     0|     0|     0|     1|
|   [b, c]|  4|     0|     1|     1|     0|
|[a, b, d]|  5|     1|     1|     0|     1|
+---------+---+------+------+------+------+

0
投票

来自explodepyspark.sql.functionspivot

from pyspark.sql import functions as F
features = [(['a', 'b', 'c'], 1),
             (['a', 'c'], 2),
             (['d'], 3),
             (['b', 'c'], 4),
             (['a', 'b', 'd'], 5)]
df = spark.createDataFrame(features, ['name','id'])
df.show()
+---------+---+
|     name| id|
+---------+---+
|[a, b, c]|  1|
|   [a, c]|  2|
|      [d]|  3|
|   [b, c]|  4|
|[a, b, d]|  5|
+---------+---+

df = df.withColumn('exploded', F.explode('name'))

df.drop('name').groupby('id').pivot('exploded').count().show()
+---+----+----+----+----+
| id|   a|   b|   c|   d|
+---+----+----+----+----+
|  5|   1|   1|null|   1|
|  1|   1|   1|   1|null|
|  3|null|null|null|   1|
|  2|   1|null|   1|null|
|  4|null|   1|   1|null|
+---+----+----+----+----+

id排序,将null转换为0

df.drop('name').groupby('id').pivot('exploded').count().na.fill(0).sort(F.col('id').asc()).show()
+---+---+---+---+---+
| id|  a|  b|  c|  d|
+---+---+---+---+---+
|  1|  1|  1|  1|  0|
|  2|  1|  0|  1|  0|
|  3|  0|  0|  0|  1|
|  4|  0|  1|  1|  0|
|  5|  1|  1|  0|  1|
+---+---+---+---+---+

explode为给定数组或映射中的每个元素返回一个新行。然后,您可以使用pivot“转置”新列。

© www.soinside.com 2019 - 2024. All rights reserved.