我想从保存的(或未保存的)DecisionTreeClassificationModel
中获取树节点的权重。但是我找不到任何类似的东西。
模型如何实际执行分类而不知道其中任何一个。以下是模型中保存的Params:
{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel"
"timestamp":1551207582648
"sparkVersion":"2.3.2"
"uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282"
"paramMap":{
"cacheNodeIds":false
"maxBins":32
"minInstancesPerNode":1
"predictionCol":"prediction"
"minInfoGain":0.0
"rawPredictionCol":"rawPrediction"
"featuresCol":"features"
"probabilityCol":"probability"
"checkpointInterval":10
"seed":956191873026065186
"impurity":"gini"
"maxMemoryInMB":256
"maxDepth":2
"labelCol":"indexed"
}
"numFeatures":1
"numClasses":2
}
通过使用treeWeights
:
treeWeights
返回每棵树的权重
1.5.0版中的新功能。
所以
模型如何实际执行分类而不知道其中任何一个。
存储权重,而不是作为元数据的一部分。如果你有model
from pyspark.ml.classification import RandomForestClassificationModel
model: RandomForestClassificationModel = ...
并将其保存到磁盘
path: str = ...
model.save(path)
你会看到作者创建了treesMetadata
子目录。如果加载内容(默认编写器使用Parquet):
import os
trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))
你会看到以下结构:
trees_metadata.printSchema()
root
|-- treeID: integer (nullable = true)
|-- metadata: string (nullable = true)
|-- weights: double (nullable = true)
其中weights
列包含由treeID
识别的树的重量。
类似地,节点数据存储在data
子目录中(例如参见Extract and Visualize Model Trees from Sparklyr):
spark.read.parquet(os.path.join(path, "data")).printSchema()
root
|-- id: integer (nullable = true)
|-- prediction: double (nullable = true)
|-- impurity: double (nullable = true)
|-- impurityStats: array (nullable = true)
| |-- element: double (containsNull = true)
|-- gain: double (nullable = true)
|-- leftChild: integer (nullable = true)
|-- rightChild: integer (nullable = true)
|-- split: struct (nullable = true)
| |-- featureIndex: integer (nullable = true)
| |-- leftCategoriesOrThreshold: array (nullable = true)
| | |-- element: double (containsNull = true)
| |-- numCategories: integer (nullable = true)
DecisionTreeClassificationModel
也提供等效信息(减去树数据和树权重)。