如何在LightGBM中获得每棵树的预测?

问题描述 投票:0回答:1

我正在使用 LightGBM 和 Python 来解决多类分类问题。我想知道每棵树的单独预测。然而,我还没有成功。您能告诉我如何访问给定测试样本的每棵树的预测吗?我的最终目标是了解 LGBMClassifier.predict_proba 如何估计类别概率。

python random-forest decision-tree lightgbm
1个回答
0
投票

对于 LightGBM 模型,每棵树的叶子都包含以目标函数为单位的原始值,并对叶子值进行求和以产生预测。

LightGBM 的

predict()
API 允许您自定义使用哪些迭代树。

这是一个 Python 示例,展示了如何从 LightGBM 模型中获取预测:

  • 所有树直到给定的迭代
  • 来自特定迭代的所有树
import lightgbm as lgb import numpy as np from sklearn.datasets import make_blobs X, y = make_blobs(n_samples=10_000, centers=3) dtrain = lgb.Dataset(X, label=y) bst = lgb.train( train_set=dtrain, params={ "objective": "multiclass", "num_class": 3, "num_leaves": 7 }, num_boost_round=10 ) # get predictions using all trees up to the 5th tree preds_first5 = bst.predict( X, num_iteration=5, raw_score=True ) # get predictions using all trees up to the 6th tree preds_first6 = bst.predict( X, num_iteration=6, raw_score=True ) # get predictions from exactly the 6th tree preds_tree6 = bst.predict( X, start_iteration=5, num_iteration=1, raw_score=True ) # confirm that those things are consistent np.testing.assert_allclose( actual=preds_tree6, desired=preds_first6 - preds_first5, rtol=1e-10 )
请注意,对于多类分类,对于每次 boosting 迭代,LightGBM 为每个类训练 1 棵树。使用上面示例中训练的模型尝试以下操作。

# 30 trees (3 classes * 10 iterations) assert bst.num_trees() == 30
LightGBM 的 

predict()

 API 可根据
迭代(而不是树)进行配置。对于多类分类,它返回形状为 [num_data, num_classes]
 的数组,其中元素 
[i, j]
 是数据中第 
i
 个样本属于第 
j
 类的预测概率。

bst.predict(X, start_iteration=5, num_iteration=1)
array([[0.37597866, 0.31173329, 0.31228805],
       [0.37597866, 0.31173329, 0.31228805],
       [0.37597866, 0.31173329, 0.31228805],
       ...])
如果您确实想查看多类分类模型的一棵 

tree 的预测,请使用 lightgbm

 Python 包,安装 
pandas
 并将模型转储到 DataFrame。

model_df = bst.trees_to_dataframe() # dump the structure of exactly the 6th tree model_df[model_df["tree_index"] == 5].head(50)

树索引节点深度节点索引左孩子右孩子父索引分割特征分割增益阈值决策类型缺少方向缺少类型价值重量数65515-S05-L05-S1第_1栏1344.09-3.43421左无001000066525-L05-S0南南-0.09382971043.4332067525-S15-S25-S45-S0第_1栏3407.111.3275左无0.04343662255.11668068535-S25-L15-S35-S1第_1栏40.54530.785617左无0.1576561210.15336569545-L15-S2南南0.1621071142.59317070545-S35-L35-L45-S2Column_041.9748-3.34207左无0.082376367.551619571555-L35-S3南南-0.0083110929.07268772555-L45-S3南南0.15089538.47910873535-S45-S55-L55-S1第_1栏28.14221.96724左无-0.08883831044.96331574545-S55-L25-L65-S4Column_039.4162-2.79966左无-0.034631587.733527375555-L25-S5南南-0.062837574.534723476555-L65-S5南南0.1246511987.133977545-L55-S4南南-0.0938066957.2293042
<=
<=
<=
<=
<=
<=
预测值将来自叶节点之一(其中

node_index

 具有 
L
 的行)。您必须研究树的规则才能找出样本落入哪一片叶子。

© www.soinside.com 2019 - 2024. All rights reserved.