我是Spark的新手,我无法找到解决问题的方法,非常感谢任何建议或帮助。
我有一个Pyspark.sql.dataframe,其中有两个包含字符串的数组列。两个列数组的长度都不一致,并且某些行还具有Null条目。我需要比较这两列,并且必须在B列中的每一行中都在OVERRIDE列中找到该数组的元素时,删除该元素。
+---------------+---------------+
| OVERRIDE | B |
+---------------+---------------+
| ['a']| ['a','b']|
| null| ['b']|
| null| ['a','c']|
| ['d','g']| ['d','g']|
| null| null|
| ['f']| ['f']|
+---------------+---------------+
最后应该看起来像这样:
+---------------+---------------+
| OVERRIDE | B |
+---------------+---------------+
| ['a']| ['b']|
| null| ['b']|
| null| ['a','c']|
| ['d','g']| null|
| null| null|
| ['f']| null|
+---------------+---------------+
我尝试过
from pyspark.sql.functions import array_remove, array_intersect
df = df.withColumn('B', array_remove(df.B, df.OVERRIDE))
还有
df = df.withColumn('B', array_remove(df.B, array_intersect(df.OVERRIDE, df.B)))
但是了解到array_remove()不能遍历该列,而是只能使用一个元素(例如'a')将其删除,然后在B列的所有行中将其删除。
我必须构建一个udf函数,如果是的话,我该怎么做?
您可以使用udf
执行此操作
@udf(returnType=ArrayType(StringType()))
def removeFromRight(override,b):
filtered_list=[x for x in b if x not in override]
if(len(filtered_list)==0):
filtered_list=[None]
return filtered_list
test1=test.withColumn("new_overridden_col",removeFromRight(col("override"),col("b")))
test1.show()