我有由日期时间、ID 和速度组成的数据,我希望使用 PySpark 获取每个 ID 的速度直方图数据(起点/终点和计数)。样本数据:
df = spark.createDataFrame(
[
("2023-06-01 07:09:17", "abc", 4.5),
("2023-06-01 07:09:18", "abc", 9.1),
("2023-06-01 07:09:19", "abc", 3.2),
("2023-06-01 07:10:06", "ddc", 5.1),
("2023-06-01 07:09:07", "ddc", 3.6),
("2023-06-01 07:09:08", "ddc", 2.6)
],
["date_time", "id", "velocity"]
)
我对输出的格式不太挑剔。最初我是使用 Spark 的
rdd.histogram(bins)
函数绘制直方图,但这涵盖了所有速度值(没有分组)。这段代码是:
df.filter(col("velocity").isNotNull()).rdd.histogram(list(range(0, 100, 1)))
但是,我不知道如何对分组数据执行此操作。我尝试过这两件事:
# This throws an error: 'GroupedData' object has no attribute 'rdd'
df.filter(col("velocity").isNotNull()).groupBy("id").rdd.histogram(list(range(0, 100, 1)))
# This throws a much longer error, but ends with: TypeError: 'str' object is not callable
# I think this has to do with the rdd.groupBy method
df.filter(col("velocity").isNotNull()).rdd.groupBy("id").histogram(list(range(0, 100, 1)))
# This throws a long error, with this TypeError: TypeError: '>' not supported between instances of 'tuple' and 'int'
df.filter(col("velocity").isNotNull()).select("id", "velocity").rdd.groupByKey().histogram(list(range(0, 100, 1)))
您可以按
id
进行分组,然后简单地计算您感兴趣的所有区间内的 velocity
值的数量。如下所示:
result = df.filter(col("velocity").isNotNull())\
.groupBy("id")\
.agg( *[sum(
when((col("velocity") >= i) & (col("velocity") < (i+1)), 1)
.otherwise(0)
).alias(f"between_{i}_{i+1}") for i in range(10)])
result.show()
+---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
| id|between_0_1|between_1_2|between_2_3|between_3_4|between_4_5|between_5_6|between_6_7|between_7_8|between_8_9|between_9_10|
+---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
|abc| 0| 0| 0| 1| 1| 0| 0| 0| 0| 1|
|ddc| 0| 0| 1| 1| 0| 1| 0| 0| 0| 0|
+---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
如果您喜欢简单的数组:
result\
.select("id", array(
[col(f"between_{i}_{i+1}") for i in range(10)]
).alias("histogram") )\
.show(truncate = False)
+---+------------------------------+
|id |histogram |
+---+------------------------------+
|abc|[0, 0, 0, 1, 1, 0, 0, 0, 0, 1]|
|ddc|[0, 0, 1, 1, 0, 1, 0, 0, 0, 0]|
+---+------------------------------+