Writing model output to a text file in Spark Scala


I fitted the following logistic regression model using Spark MLlib:

import org.apache.spark.ml.feature.FeatureHasher
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import spark.implicits._

val df = spark.read.option("header","true").option("inferSchema","true").csv("car_milage-6f50d.csv")
val hasher = new FeatureHasher().setInputCols(Array("mpg","displacement","hp","torque")).setOutputCol("features")
val transformed = hasher.transform(df)
val Array(training, test) = transformed.randomSplit(Array(0.8, 0.2))
val lr = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("automatic")
  .setMaxIter(20)
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1,0.3))
  .addGrid(lr.elasticNetParam, Array(0.9,1))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(10)
  .setParallelism(2)

val model = cv.fit(training)
val results = model.transform(test).select("features", "automatic", "prediction")

val predictionAndLabels = results.select("prediction","automatic").as[(Double, Double)].rdd

Finally, I obtained these model evaluation metrics:

val mMetrics = new MulticlassMetrics(predictionAndLabels)
mMetrics.confusionMatrix
mMetrics.labels
mMetrics.accuracy

As a final step, I need to write these evaluation metrics (mMetrics) to a file (either a text file or a CSV file). Can anyone help me with how to do this?

I just tried, but I could not find any write method associated with these values.

Thanks

scala apache-spark apache-spark-mllib
1 Answer

1 vote

Looking at the method summary of MulticlassMetrics, I think you should be able to do this:

val confusionMatrixOutput = mMetrics.confusionMatrix.toArray
// parallelize is a SparkContext method, so go through spark.sparkContext
val confusionMatrixOutputFinal = spark.sparkContext.parallelize(confusionMatrixOutput)
confusionMatrixOutputFinal.coalesce(1).saveAsTextFile("C:/confusionMatrixOutput.txt")
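
Note that saveAsTextFile produces a directory of part files rather than a single file. If you would rather have the confusion matrix as one CSV-style file, a minimal sketch (assuming the SparkSession is named spark and mMetrics is in scope; the output path is just an example) is to reshape the matrix into comma-joined rows and use the DataFrame writer:

import spark.implicits._

// Turn the confusion matrix into one comma-separated string per matrix row
val cm = mMetrics.confusionMatrix
val csvRows = (0 until cm.numRows).map { i =>
  (0 until cm.numCols).map(j => cm(i, j)).mkString(",")
}

// coalesce(1) keeps everything in a single part file (still written inside a directory)
csvRows.toDF("row")
  .coalesce(1)
  .write
  .text("C:/confusionMatrixCsv")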

You should be able to do the same with mMetrics.labels:

val labelsOutput = mMetrics.labels
val labelsOutputFinal = spark.sparkContext.parallelize(labelsOutput)
labelsOutputFinal.coalesce(1).saveAsTextFile("C:/labelsOutput.txt")

Accuracy is just a Double, so you can easily print it like this:

val accuracy = mMetrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")

You should be able to write all the statistics of your logistic regression model to a single file like this:

import java.io._
import org.apache.spark.mllib.evaluation.MulticlassMetrics

object MulticlassMetricsOutputWriter {

  def main(args: Array[String]): Unit = {

    // All your other code can be added here

    val mMetrics = new MulticlassMetrics(predictionAndLabels)
    val labels = mMetrics.labels

    // Create a new file and pass it to the PrintWriter
    val pw = new PrintWriter(new File("C:/mllib_lr_output.txt"))

    // Confusion matrix (Matrix.toString gives a readable row-by-row layout)
    pw.write("ConfusionMatrix:\n" + mMetrics.confusionMatrix.toString + "\n")

    // Labels
    pw.write("labels: " + labels.mkString(", ") + "\n")

    // False positive rate by label
    labels.foreach { l =>
      pw.write(s"FPR($l) = " + mMetrics.falsePositiveRate(l) + "\n")
    }

    // True positive rate by label
    labels.foreach { l =>
      pw.write(s"TPR($l) = " + mMetrics.truePositiveRate(l) + "\n")
    }

    // F-measure by label
    labels.foreach { l =>
      pw.write(s"F1-Score($l) = " + mMetrics.fMeasure(l) + "\n")
    }

    // Precision by label
    labels.foreach { l =>
      pw.write(s"Precision($l) = " + mMetrics.precision(l) + "\n")
    }

    // Recall by label
    labels.foreach { l =>
      pw.write(s"Recall($l) = " + mMetrics.recall(l) + "\n")
    }

    val accuracy = mMetrics.accuracy
    val weightedFalsePositiveRate = mMetrics.weightedFalsePositiveRate
    val weightedFMeasure = mMetrics.weightedFMeasure
    val weightedPrecision = mMetrics.weightedPrecision
    val weightedRecall = mMetrics.weightedRecall
    val weightedTruePositiveRate = mMetrics.weightedTruePositiveRate

    pw.write("Summary Statistics" + "\n")
    pw.write(s"Accuracy = $accuracy" + "\n")
    pw.write(s"weightedFalsePositiveRate = $weightedFalsePositiveRate" + "\n")
    pw.write(s"weightedFMeasure = $weightedFMeasure" + "\n")
    pw.write(s"weightedPrecision = $weightedPrecision" + "\n")
    pw.write(s"weightedRecall = $weightedRecall" + "\n")
    pw.write(s"weightedTruePositiveRate = $weightedTruePositiveRate" + "\n")

    // Close the PrintWriter
    pw.close()
  }
}
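
Note that the PrintWriter approach runs entirely on the driver, which is fine here because everything pulled from MulticlassMetrics (the confusion matrix, labels, per-label rates and weighted summaries) is already a local value, so there is no need to parallelize anything before writing it out.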

-1 vote

If your dataframe is not empty, you can write it out as follows:

<dataFrame>.write.format("csv").save("/your/location/data.csv")

Or, if you want it in a single file:

<dataFrame>.coalesce(1).write.csv("/your/location/data.csv")
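
To tie this back to the question, a minimal sketch (assuming the SparkSession is named spark and the mMetrics object is in scope; the output directory is just an example) is to collect the summary metrics into a small DataFrame and write it as one CSV file with a header:

import spark.implicits._

// Put the summary metrics into a tiny two-column DataFrame
val summary = Seq(
  ("accuracy", mMetrics.accuracy),
  ("weightedPrecision", mMetrics.weightedPrecision),
  ("weightedRecall", mMetrics.weightedRecall),
  ("weightedFMeasure", mMetrics.weightedFMeasure),
  ("weightedFalsePositiveRate", mMetrics.weightedFalsePositiveRate)
).toDF("metric", "value")

// coalesce(1) keeps the output in a single part file
summary.coalesce(1)
  .write
  .option("header", "true")
  .csv("/your/location/metrics_csv")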