Cosine similarity from tf-idf output in a Spark DataFrame (Scala)

Problem description

I am computing the cosine similarity between rows of a DataFrame using Spark with Scala.
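For reference, the cosine similarity of two vectors u and v is their dot product divided by the product of their Euclidean norms: cos(u, v) = (u · v) / (‖u‖ ‖v‖).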

The DataFrame has the following schema:

root
 |-- id: long (nullable = true)
 |-- features: vector (nullable = true)

A sample of the DataFrame looks like this:

+---+--------------------+
| id|            features|
+---+--------------------+
| 65|(10000,[48,70,87,...|
|191|(10000,[1,73,77,1...|
+---+--------------------+

The code that produces this is:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._

val df = spark.read.json("articles_line.json")
val tokenizer = new Tokenizer().setInputCol("desc").setOutputCol("words")
val wordsDF = tokenizer.transform(df)

def flattenWords = udf( (s: Seq[Seq[String]]) => s.flatMap(identity) )
val groupedDF = wordsDF.groupBy("id").
  agg(flattenWords(collect_list("words")).as("grouped_words"))
val hashingTF = new HashingTF().
  setInputCol("grouped_words").setOutputCol("rawFeatures").setNumFeatures(10000)
val featurizedData = hashingTF.transform(groupedDF)
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
val idfModel = idf.fit(featurizedData)
val rescaledData = idfModel.transform(featurizedData)
val asDense = udf((v: Vector) => v.toDense) // convert the sparse tf-idf vector to a dense one
val newDf = rescaledData.select('id, 'features)
    .withColumn("dense_features", asDense($"features"))

The final DataFrame looks like:

+-----+--------------------+--------------------+
|   id|            features|      dense_features|
+-----+--------------------+--------------------+
|21209|(10000,[128,288,2...|[0.0,0.0,0.0,0.0,...|
|21223|(10000,[8,18,32,4...|[0.0,0.0,0.0,0.0,...|
+-----+--------------------+--------------------+

I don't know what to do with the `dense_features` column in order to compute the cosine similarity. This article did not work for me. Any help is appreciated.

An example of a single `dense_features` row (truncated for brevity):

[0.0,0.0,0.0,0.0,7.08,0.0,0.0,0.0,0.0,2.24,0.0,0.0,0.0,0.0,0.0,...,9.59]
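One possible direction (a minimal sketch, not from the original post, assuming the `newDf` built above): L2-normalize the tf-idf vectors with Spark ML's Normalizer, after which the cosine similarity of two rows is just their dot product.

import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._

// L2-normalize the tf-idf vectors; cosine similarity of unit vectors is a plain dot product.
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("norm_features")
  .setP(2.0)
val normDf = normalizer.transform(newDf)

// Dot product of two (already normalized) vectors.
val dot = udf { (x: Vector, y: Vector) =>
  x.toArray.zip(y.toArray).map { case (a, b) => a * b }.sum
}

// All pairwise similarities via a self-join; id < id avoids self-pairs and duplicates.
val pairs = normDf.as("a").join(normDf.as("b"), $"a.id" < $"b.id")
  .select($"a.id".as("id_a"), $"b.id".as("id_b"),
          dot($"a.norm_features", $"b.norm_features").as("cosine_sim"))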
scala apache-spark-sql tf-idf cosine-similarity
1 Answer

This works fine for me. The full code:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df = spark.read.json("/user/dmitry.korniltsev/lab02/data/DO_record_per_line.json")
val cleaned_df = df
    .withColumn("desc", regexp_replace('desc, "[^\\w\\sа-яА-ЯЁё]", "")) // keep word chars, whitespace and Cyrillic letters
    .withColumn("desc", lower(trim(regexp_replace('desc, "\\s+", " ")))) // collapse whitespace and lowercase
    .where(length('desc) > 0) // drop rows whose description became empty

val tokenizer = new Tokenizer().setInputCol("desc").setOutputCol("words")
val wordsDF = tokenizer.transform(cleaned_df)
def flattenWords = udf( (s: Seq[Seq[String]]) => s.flatMap(identity) )
val hashingTF = new HashingTF()
    .setInputCol("words")
    .setOutputCol("rawFeatures")
    .setNumFeatures(20000)
val featurizedData = hashingTF.transform(wordsDF)
val idf = new IDF()
    .setInputCol("rawFeatures")
    .setOutputCol("features")
val idfModel = idf.fit(featurizedData)
val rescaledData = idfModel.transform(featurizedData)
val asDense = udf((v: Vector) => v.toDense)
val newDf = rescaledData
    .withColumn("dense_features", asDense($"features"))

val cosSimilarity = udf { (x: Vector, y: Vector) =>
    val v1 = x.toArray
    val v2 = y.toArray
    val l1 = scala.math.sqrt(v1.map(x => x * x).sum) // Euclidean norm of x
    val l2 = scala.math.sqrt(v2.map(x => x * x).sum) // Euclidean norm of y
    val scalar = v1.zip(v2).map(p => p._1 * p._2).sum // dot product
    scalar / (l1 * l2) // NaN if either vector is all zeros; handled below
}

val id_list = Seq(23325, 15072, 24506, 3879, 1067, 17019)
val filtered_df = newDf
    .filter(col("id").isin(id_list: _*))
    .select('id.alias("id_frd"), 'dense_features.alias("dense_frd"), 'lang.alias("lang_frd"))

// Compare each selected id against every other document in the same language.
val joinedDf = newDf.join(broadcast(filtered_df), 'id =!= 'id_frd && 'lang === 'lang_frd)
    .withColumn("cosine_sim", cosSimilarity(col("dense_frd"), col("dense_features")))

val filtered = joinedDf
    .filter(col("lang") === "en")
    .withColumn("cosine_sim", when(col("cosine_sim").isNaN, 0).otherwise(col("cosine_sim"))) // zero vectors yield NaN; treat as 0
    .withColumn("rank", row_number().over(
            Window.partitionBy(col("id_frd")).orderBy(col("cosine_sim").desc)))
    .filter(col("rank").between(2, 11)) // top 10 most similar per query id, skipping the document itself (rank 1)
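To inspect the output, one possibility (not part of the original answer) is to list, for each query id, its ten most similar documents in rank order:

filtered
  .select('id_frd, 'rank, 'id, 'cosine_sim)
  .orderBy('id_frd, 'rank)
  .show(60, truncate = false)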