I stumbled across some strange PySpark behavior. Basically, it lets you call the where function on a column that doesn't exist in the DataFrame:
from pyspark.sql.functions import col

print(spark.version)
df = spark.read.format("csv").option("header", True).load("abfss://some_abfs_path/df.csv")
print(type(df), len(df.columns), df.count())
c = df.columns[0]  # A column name before renaming
df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])  # Add suffix to column names
print(c in df.columns)
try:
    df.select(c)
except Exception:
    print("SO THIS DOESN'T WORK, WHICH MAKES SENSE.")
# BUT WHY DOES THIS WORK:
print(df.where(col(c).isNotNull()).count())
# IT'S USING c AS f"{c}_new"
print(df.where(col(f"{c}_new").isNotNull()).count())
Output:
3.1.2
<class 'pyspark.sql.dataframe.DataFrame'> 102 1226791
False
SO THIS DOESN'T WORK, WHICH MAKES SENSE.
1226791
1226791
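As you can see, the strange part is that even though c is no longer one of df's columns after the rename, it can still be used in the where function.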
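My intuition is that PySpark compiles the where before the select rename. But if that's the case, it would be a terrible design, and it still wouldn't explain why both the old and the new column names work.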
Any insights appreciated, thanks.
I'm running this on Azure Databricks.
When in doubt, use df.explain() to find out what's happening under the hood. It confirms your intuition:
Spark context available as 'sc' (master = local[*], app id = local-1709748307134).
SparkSession available as 'spark'.
>>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
>>> c = df.columns[0]
>>> from pyspark.sql.functions import *
>>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])
>>> df.explain()
== Physical Plan ==
*(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
+- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
>>> df = df.where(col(c).isNotNull())
>>> df.explain()
== Physical Plan ==
*(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
+- *(1) Filter isnotnull(VendorID#17)
+- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [isnotnull(VendorID#17)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [IsNotNull(VendorID)], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
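Reading the plan from bottom to top: FileScan reads the data, Filter discards the rows you don't need, and Project applies the aliases. This is a sensible way for Spark to build its DAG (discard data as early as possible so you don't waste time operating on it), but as you've noticed, it can lead to unexpected behavior. If you want to avoid this, use df.checkpoint() to materialize the DataFrame before the df.where() statement (a checkpoint directory has to be set first, as in the session below). You then get the expected error when you try to reference the old column name: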
>>> from pyspark.sql.functions import *
>>> spark.sparkContext.setCheckpointDir("file:/tmp/")
>>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
>>> c = df.columns[0]
>>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])
>>> df = df.checkpoint()
>>> df.explain()
== Physical Plan ==
*(1) Scan ExistingRDD[VendorID_new#51,tpep_pickup_datetime_new#52,tpep_dropoff_datetime_new#53,passenger_count_new#54,trip_distance_new#55,RatecodeID_new#56,store_and_fwd_flag_new#57,PULocationID_new#58,DOLocationID_new#59,payment_type_new#60,fare_amount_new#61,extra_new#62,mta_tax_new#63,tip_amount_new#64,tolls_amount_new#65,improvement_surcharge_new#66,total_amount_new#67]
>>> df = df.where(col(c).isNotNull())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/sql/dataframe.py", line 3325, in filter
jdf = self._jdf.filter(condition._jc)
File "/opt/homebrew/opt/apache-spark/libexec/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/errors/exceptions/captured.py", line 185, in deco
raise converted from None
pyspark.errors.exceptions.captured.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `VendorID` cannot be resolved. Did you mean one of the following? [`VendorID_new`, `extra_new`, `RatecodeID_new`, `mta_tax_new`, `DOLocationID_new`].;
'Filter isnotnull('VendorID)
+- LogicalRDD [VendorID_new#51, tpep_pickup_datetime_new#52, tpep_dropoff_datetime_new#53, passenger_count_new#54, trip_distance_new#55, RatecodeID_new#56, store_and_fwd_flag_new#57, PULocationID_new#58, DOLocationID_new#59, payment_type_new#60, fare_amount_new#61, extra_new#62, mta_tax_new#63, tip_amount_new#64, tolls_amount_new#65, improvement_surcharge_new#66, total_amount_new#67], false
>>>