Can't load an MLflow model with a custom transformer in a Spark ML pipeline: class logged with mlflow.spark.log_model is missing a required positional argument


I have a custom transformer Python class that I use in a Spark MLlib pipeline. I want to log the model to MLflow, register it in the MLflow Model Registry, and then load it back with Spark. I'm able to log the model, but I can't load it afterwards, because the Confidence transformer requires labels and loading fails with a missing positional argument. I'm using mlflow==2.2.2 and pyspark==3.3.0.

from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit
import mlflow
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import LogisticRegression


class Confidence(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    A custom Transformer which does some cleanup of the output of the model and creates a column a confidence metric based on a T distribution.
    """
    
    labels = Param(
        Params._dummy(),
        "labels",
        "Count of labels for degrees of freedom",
        typeConverter=TypeConverters.toInt)

    def __init__(self, labels: int):
        super(Confidence, self).__init__()
        self._setDefault(labels=labels)
        
    def getLabels(self):
        return self.getOrDefault(self.labels)

    def setLabels(self, value):
        self._set(labels=value)

    def _transform(self, df):
        return df.withColumn("labelCount", lit(self.getLabels()))
# String indexer to convert the feature column
stringIndexer = StringIndexer(inputCol="feature", outputCol="label").fit(train_df)
# Logistic regression estimator
lr = LogisticRegression()
# Get the label count from the string indexer to pass to Confidence
labelCount = len(stringIndexer.labels)
confidence = Confidence(labels=labelCount)

with mlflow.start_run():
    # Create the pipeline and fit the model
    pipeline = Pipeline(stages=[stringIndexer, lr, confidence])
    pipeline_model = pipeline.fit(train_df)

    # Log the model to MLflow
    mlflow.spark.log_model(pipeline_model, 'model')
# In a separate cell: load the model back from the run
import mlflow

logged_model = 'runs:/edffe7204da747d0a8cfd816120ebcad/model'
loaded_model = mlflow.spark.load_model(logged_model)

I get this error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<command-3447613862254112> in <cell line: 5>()
      3 
      4 # Load model
----> 5 loaded_model = mlflow.spark.load_model(logged_model)

/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/mlflow/spark.py in load_model(model_uri, dfs_tmpdir, dst_path)
    803     sparkml_model_uri = append_to_uri_path(model_uri, flavor_conf["model_data"])
    804     local_sparkml_model_path = os.path.join(local_mlflow_model_path, flavor_conf["model_data"])
--> 805     return _load_model(
    806         model_uri=sparkml_model_uri,
    807         dfs_tmpdir_base=dfs_tmpdir,

/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/mlflow/spark.py in _load_model(model_uri, dfs_tmpdir_base, local_model_path)
    735     dfs_tmpdir = generate_tmp_dfs_path(dfs_tmpdir_base or MLFLOW_DFS_TMP.get())
    736     if databricks_utils.is_in_cluster() and databricks_utils.is_dbfs_fuse_available():
--> 737         return _load_model_databricks(
    738             dfs_tmpdir, local_model_path or _download_artifact_from_uri(model_uri)
    739         )

/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/mlflow/spark.py in _load_model_databricks(dfs_tmpdir, local_model_path)
    727     # errors on passthrough-enabled clusters when attempting to copy permission bits for directories
    728     shutil_copytree_without_file_permissions(src_dir=local_model_path, dst_dir=fuse_dfs_tmpdir)
--> 729     return PipelineModel.load(dfs_tmpdir)
    730 
    731 

/databricks/spark/python/pyspark/ml/util.py in load(cls, path)
    444     def load(cls, path: str) -> RL:
    445         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 446         return cls.read().load(path)
    447 
    448 

/databricks/spark/python/pyspark/ml/pipeline.py in load(self, path)
    284             return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path)
    285         else:
--> 286             uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
    287             return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
    288 

/databricks/spark/python/pyspark/ml/pipeline.py in load(metadata, sc, path)
    437                 stageUid, index, len(stageUids), stagesDir
    438             )
--> 439             stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
    440             stages.append(stage)
    441         return (metadata["uid"], stages)

/databricks/spark/python/pyspark/ml/util.py in loadParamsInstance(path, sc)
    727             pythonClassName = metadata["class"].replace("org.apache.spark", "pyspark")
    728         py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
--> 729         instance = py_type.load(path)
    730         return instance
    731 

/databricks/spark/python/pyspark/ml/util.py in load(cls, path)
    444     def load(cls, path: str) -> RL:
    445         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 446         return cls.read().load(path)
    447 
    448 

/databricks/spark/python/pyspark/ml/util.py in load(self, path)
    638         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    639         py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
--> 640         instance = py_type()
    641         cast("Params", instance)._resetUid(metadata["uid"])
    642         DefaultParamsReader.getAndSetParams(instance, metadata)

TypeError: __init__() missing 1 required positional argument: 'labels'
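The last frame of the traceback (instance = py_type()) points at the root cause: when pyspark's DefaultParamsReader rebuilds a pipeline stage, it first constructs the class with no arguments and only afterwards restores the saved params with getAndSetParams. A class mixed with DefaultParamsReadable therefore needs a constructor that works with zero arguments. Below is a minimal sketch of one way to restructure Confidence to satisfy that, using pyspark's keyword_only helper and _set the way pyspark's built-in stages do; the labels=None default is an assumption of mine, there only so the no-arg construction during load succeeds:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit


class Confidence(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """Same transformer, but constructible with no arguments so the
    py_type() call in DefaultParamsReader succeeds during load."""

    labels = Param(
        Params._dummy(),
        "labels",
        "Count of labels for degrees of freedom",
        typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, labels=None):
        super(Confidence, self).__init__()
        # keyword_only stores the explicitly passed kwargs in
        # self._input_kwargs; set the param only when a value was given.
        kwargs = self._input_kwargs
        if kwargs.get("labels") is not None:
            self._set(**kwargs)

    def getLabels(self):
        return self.getOrDefault(self.labels)

    def setLabels(self, value):
        return self._set(labels=value)

    def _transform(self, df):
        return df.withColumn("labelCount", lit(self.getLabels()))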

I've tried adding various things to my class, and many variants of them; this is the furthest I've gotten trying to save the labels metadata with the custom class:

    def write(self):
        # Attempt to persist the labels value as extra metadata
        # (requires `import json` at the top of the file)
        metadata = {"labels": self.getLabels()}
        return PipelineModel([self]).write().option("metadata", json.dumps(metadata))
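
With the constructor fixed as sketched above, a custom write() shouldn't be necessary: DefaultParamsWritable already serializes every Param that has been set, and DefaultParamsReader restores it right after the no-arg construction. A quick, hypothetical way to check the round-trip in isolation before involving MLflow (the path is illustrative):

# Standalone save/load round-trip check for the fixed class
confidence = Confidence(labels=3)
confidence.write().overwrite().save("/tmp/confidence_stage")
restored = Confidence.load("/tmp/confidence_stage")
assert restored.getLabels() == 3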
Tags: python, apache-spark, databricks, transformer-model, mlflow