Cannot load a pyspark.ml.Pipeline model with a custom transformer because of a missing required positional argument of the class

Question · Votes: 0 · Answers: 2

I have a custom transformer Python class that I use with a Spark MLlib pipeline. I want to save the model and load it in another Spark session. I can save the model, but I cannot load it afterwards because the Confidence transformer requires labels and the positional argument is missing. I am using pyspark==3.3.0.

from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import LogisticRegression


class Confidence(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    A custom Transformer that cleans up the model output and adds a column with a confidence metric based on a t-distribution.
    """
    
    labels = Param(
        Params._dummy(),
        "labels",
        "Count of labels for degrees of freedom",
        typeConverter=TypeConverters.toInt)

    def __init__(self, labels: int):
        super(Confidence, self).__init__()
        self._setDefault(labels=labels)
        
    def getLabels(self):
        return self.getOrDefault(self.labels)

    def setLabels(self, value):
        self._set(labels=value)

    def _transform(self, df):
        return df.withColumn("labelCount", lit(self.getLabels()))
# String indexer to convert the feature column into a numeric label column
stringIndexer = StringIndexer(inputCol="feature", outputCol="label").fit(train_df)
# Logistic regression estimator
lr = LogisticRegression()
# Get the label count from the string indexer to pass to the confidence transformer
labelCount = len(stringIndexer.labels)
confidence = Confidence(labels=labelCount)

# Create pipeline and fit model
pipeline = Pipeline().setStages([stringIndexer, lr, confidence])
pipeline_model = pipeline.fit(train_df)

basePath = "/tmp/mllib-persistence-example"
pipeline_model.write().overwrite().save(basePath + "/model")
model_loaded = Pipeline.load(basePath + "/model")

I get this error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<command-1923614380916072> in <cell line: 1>()
----> 1 model_loaded = Pipeline.load(basePath + "/model")

/databricks/spark/python/pyspark/ml/util.py in load(cls, path)
    444     def load(cls, path: str) -> RL:
    445         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 446         return cls.read().load(path)
    447 
    448 

/databricks/spark/python/pyspark/ml/pipeline.py in load(self, path)
    247             return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path)
    248         else:
--> 249             uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
    250             return Pipeline(stages=stages)._resetUid(uid)
    251 

/databricks/spark/python/pyspark/ml/pipeline.py in load(metadata, sc, path)
    437                 stageUid, index, len(stageUids), stagesDir
    438             )
--> 439             stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
    440             stages.append(stage)
    441         return (metadata["uid"], stages)

/databricks/spark/python/pyspark/ml/util.py in loadParamsInstance(path, sc)
    727             pythonClassName = metadata["class"].replace("org.apache.spark", "pyspark")
    728         py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
--> 729         instance = py_type.load(path)
    730         return instance
    731 

/databricks/spark/python/pyspark/ml/util.py in load(cls, path)
    444     def load(cls, path: str) -> RL:
    445         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 446         return cls.read().load(path)
    447 
    448 

/databricks/spark/python/pyspark/ml/util.py in load(self, path)
    638         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    639         py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
--> 640         instance = py_type()
    641         cast("Params", instance)._resetUid(metadata["uid"])
    642         DefaultParamsReader.getAndSetParams(instance, metadata)

TypeError: __init__() missing 1 required positional argument: 'labels'
python apache-spark databricks transformer-model mlflow
2 Answers

0 votes

If you are trying to load the fitted pipeline, shouldn't it be:

model_loaded = PipelineModel.load(basePath + "/model")
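For context, a minimal round trip under that assumption, reusing the fitted pipeline_model from the question; test_df is a hypothetical hold-out DataFrame. Note that the constructor fix from the other answer is still needed either way, since stage deserialization instantiates each stage class the same way for Pipeline.load and PipelineModel.load.

# Save the fitted pipeline, including the custom transformer stage
pipeline_model.write().overwrite().save(basePath + "/model")

# Load it back as a fitted PipelineModel rather than an unfitted Pipeline
model_loaded = PipelineModel.load(basePath + "/model")

# The loaded model can then be applied directly to new data
predictions = model_loaded.transform(test_df)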


0 votes

Set the default for labels to None in the constructor:

def __init__(self, labels: int = None):
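The reason is visible in the last traceback frame: DefaultParamsReader re-creates each stage with a bare py_type() call and only afterwards restores the saved params via getAndSetParams, so the constructor must be callable with no arguments. A minimal sketch of the transformer with that change (the guard around _set is an assumption to keep a bare Confidence() call valid; the reader fills in labels from the saved metadata):

from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit


class Confidence(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """Adds a labelCount column; loadable because the constructor needs no arguments."""

    labels = Param(
        Params._dummy(),
        "labels",
        "Count of labels for degrees of freedom",
        typeConverter=TypeConverters.toInt)

    def __init__(self, labels: int = None):
        super(Confidence, self).__init__()
        # Only set the param when a value is given, so a bare Confidence() call
        # (as made by DefaultParamsReader) still succeeds.
        if labels is not None:
            self._set(labels=labels)

    def getLabels(self):
        return self.getOrDefault(self.labels)

    def setLabels(self, value):
        self._set(labels=value)
        return self

    def _transform(self, df):
        return df.withColumn("labelCount", lit(self.getLabels()))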