我有一个图像数据集(id,url,features),我在其中执行了所有图像之间的余弦相似度。结果是具有以下结构的pyspark数据帧:
>>> cos_df.printSchema()
root
|-- id: integer (nullable = true)
|-- url: string (nullable = true)
|-- vec: vector (nullable = true)
vec是包含余弦相似性结果的列(DenseVector)。我要做的是创建一个列“similar_urls”或更新“vec”,并为每一行输入基于vec列值的前N个最相似的项目。
例如,如果我取id = 26,我想查看“vec”并查找前N个项的索引(id和索引是相同的)并将vec的值替换为顶部的url列表N项。
我试图做的是:
我陷入了第一步,因为我似乎无法将我的“vec”值转换为数组/列表以找到前10个值。
from pyspark.sql.functions import udf
def convert_to_array(vec):
return type(vec)
test_udf = udf(convert_to_array, StringType())
cos_df = cos_df.withColumn("vec", test_udf("vec"))
当我尝试查看vec值的类型时,它会返回
net.razorvine.pickle.objects.ClassDictConstructor@2db673eb
你知道那是什么类型,我怎么操纵它,以便我可以改变vec?
P.S:我也对任何其他能够解决特定问题的解决方案持开放态度!
想出了解决方案1.它似乎工作。
from pyspark.sql.functions import udf
def convert_to_array(vec):
vec_list = vec.tolist()
sorted_top = sorted(range(len(vec_list)), key=lambda i: vec_list[i], reverse=True)[1:16]
return sorted_top
test_udf = udf(convert_to_array, ArrayType(IntegerType()))
cos_df = cos_df.withColumn("similar_url", test_udf("vec"))