是否可以从表中创建任意 Scikit learn 决策树?

问题描述 投票:0回答:1

假设我有一个描述决策树的表 - 每个节点及其特征、阈值(分割点)及其左右子节点 - 或预测值,如果这是叶节点(此处用

status==-1 表示) 
)。

tree_structure = """ # (simplified example for clarity)
#   left_daughter right_daughter split_var split_point status prediction
1               2             3        2  394.250000      1          0
2               -1             -1       -1   0.0         -1          1
3               -1             -1       -1   0.0         -1          2
"""

# Convert the tree structure to a DataFrame
lines = tree_structure.strip().split("\n")
header = lines[0].split()
tree_data = [line.split() for line in lines[1:]]
df_tree = pd.DataFrame(tree_data, columns=header)

是否可以将此表/DataFrame 转换为 scikit learn 决策树?

python pandas scikit-learn decision-tree
1个回答
0
投票

traverse_tree(df,node_idx,样本) 我创建了一个递归函数

traverse_tree
,它遍历由您 df 表示的决策树,以对您提供的样本进行预测。

我包含的参数:

  • df
    :决策树的df
  • node_idx
    :是索引 df 中的当前节点,即我要评估的
  • sample
    :表示正在进行预测的输入样本的字典

如果当前节点(node_idx)是叶节点,该函数将返回该节点的预测值。 如果当前节点不是叶子节点,它将根据

split_var
split_point
列评估分割条件。这里出现了递归性,因为它将根据分割条件在左或右子节点上调用自身。

至于

predict(df, samples)
,它将使用 df 表示的决策树来预测样本列表的输出。 我开始迭代样本列表中的每个样本。然后,对于每个样本,我从树的根(索引 0)开始调用 traverse_tree 函数,并收集预测。

import pandas as pd

tree_structure = """
1               2             3        2  394.250000      1          0
2               -1            -1       -1   0.0         -1          1
3               -1            -1       -1   0.0         -1          2
"""

header = ["index", "left_daughter", "right_daughter", "split_var", "split_point", "status", "prediction"]

tree_data = [line.split() for line in tree_structure.strip().split("\n")]
df_tree = pd.DataFrame(tree_data, columns=header).apply(pd.to_numeric)  

def traverse_tree(df, node_idx, sample):
    node = df.iloc[node_idx]
    
    if node['status'] == -1:
        return node['prediction']
    
    if sample[int(node['split_var'])] <= node['split_point']:
        return traverse_tree(df, int(node['left_daughter']) - 1, sample)
    else:
        return traverse_tree(df, int(node['right_daughter']) - 1, sample)

def predict(df, samples):
    return [traverse_tree(df, 0, sample) for sample in samples]

samples = [{2: 395}, {2: 390}]
predictions = predict(df_tree, samples)
print(predictions)  # Output will be based on the decision tree structure

© www.soinside.com 2019 - 2024. All rights reserved.