如何在MLReader上创建泛型函数

问题描述 投票:1回答:1

我在Spark 1.6.3工作。以下是两个执行相同操作的函数:

def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  CountVectorizerModel.read.load(tempPath.toString)
}

def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  IDFModel.read.load(tempPath.toString)
}

我想使这些功能通用。我所依赖的是CountVectorizerModel对象和IDFModel之间的共同特征是MLReadable [T],它本身必须采用CountVectorizerModel或IDFModel类型。这是一种递归的父类循环,我无法找到解决方案。

相比之下,通用模型编写器很容易,因为MLWritable是我感兴趣的所有模型扩展的共同特征:

def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  model.write.overwrite().save(tempPath.toString)
  Files.readAllBytes(tempPath)
}

如何制作一个将spark-ml模型转换为字节数组的通用读取器?

scala apache-spark apache-spark-ml
1个回答
2
投票

要使其工作,您需要访问特定的MlReadable对象。

import org.apache.spark.ml.util.MLReadable

def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
  val tempPath: Path = ???
  ...
  obj.read.load(tempPath.toString)
}

以后可以用作:

val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)

请注意,尽管第一次出现,但这里没有任何递归 - MLReadable[M]指的是伴侣对象,而不是类。所以例如CountVectorizerModel objectMLReadable,而CountVectorizeModel class不是。

在内部,Spark MLReader以不同的方式处理这个问题 - it creates an instance of the class using reflection,然后是sets its Params。但是这条路径对你来说不是很有用*。

如果需要与当前API的兼容性,您可以尝试隐藏可读对象:

def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
  ...
}

然后

implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel

modelFromBytes[CountVectorizerModel](bytes)

*从技术上讲,可以通过反射获得伴侣对象

def modelFromBytesCV[M <: MLWritable](
    modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
  val tempPath: Path = ???
  ...
  val cls = Class.forName(ct.runtimeClass.getName + "$");
  cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
    .read.load(tempPath.toString)) 
}

但我不认为这是一条值得探讨的道路。特别是我们不能在这里真正提供严格的类型界限 - 使用MLWritable是一个限制人为错误的黑客,但对编译器来说却无用。

© www.soinside.com 2019 - 2024. All rights reserved.