我在Spark 1.6.3工作。以下是两个执行相同操作的函数:
def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
CountVectorizerModel.read.load(tempPath.toString)
}
def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
IDFModel.read.load(tempPath.toString)
}
我想使这些功能通用。我所依赖的是CountVectorizerModel对象和IDFModel之间的共同特征是MLReadable [T],它本身必须采用CountVectorizerModel或IDFModel类型。这是一种递归的父类循环,我无法找到解决方案。
相比之下,通用模型编写器很容易,因为MLWritable是我感兴趣的所有模型扩展的共同特征:
def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
model.write.overwrite().save(tempPath.toString)
Files.readAllBytes(tempPath)
}
如何制作一个将spark-ml模型转换为字节数组的通用读取器?
要使其工作,您需要访问特定的MlReadable
对象。
import org.apache.spark.ml.util.MLReadable
def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
val tempPath: Path = ???
...
obj.read.load(tempPath.toString)
}
以后可以用作:
val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)
请注意,尽管第一次出现,但这里没有任何递归 - MLReadable[M]
指的是伴侣对象,而不是类。所以例如CountVectorizerModel
object是MLReadable
,而CountVectorizeModel
class不是。
在内部,Spark MLReader
以不同的方式处理这个问题 - it creates an instance of the class using reflection,然后是sets its Params
。但是这条路径对你来说不是很有用*。
如果需要与当前API的兼容性,您可以尝试隐藏可读对象:
def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
...
}
然后
implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel
modelFromBytes[CountVectorizerModel](bytes)
*从技术上讲,可以通过反射获得伴侣对象
def modelFromBytesCV[M <: MLWritable](
modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
val tempPath: Path = ???
...
val cls = Class.forName(ct.runtimeClass.getName + "$");
cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
.read.load(tempPath.toString))
}
但我不认为这是一条值得探讨的道路。特别是我们不能在这里真正提供严格的类型界限 - 使用MLWritable
是一个限制人为错误的黑客,但对编译器来说却无用。