我有以下函数,它根据 scikit-learn 的分类树生成代码:
def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):
    """Append Python source code for a fitted decision tree to ``tree.py``.

    The generated function takes one argument ``x`` that is indexed by
    feature name (``x['V1']``) and, at each leaf, returns the quoted
    conditions collected since the previous leaf was written.

    NOTE: the collected conditions are cleared after every leaf, so a leaf
    reached through an ``else`` branch only reports the conditions added
    after the preceding leaf; conditions of ancestors that were already
    reported by an earlier leaf are not repeated.

    Args:
        tree: fitted estimator exposing a ``tree_`` attribute with
            ``feature``, ``threshold``, ``children_left`` and
            ``children_right`` arrays.
        feature_names: sequence mapping feature index to column name.
        mx_name: name of the function written to the file.
        rm_file: when true, delete an existing ``./tree.py`` first;
            otherwise the new function is appended to it.
    """
    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            # Best effort: a missing file is exactly the state we want.
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    # Conditions gathered since the last leaf. Kept local to this call so the
    # function neither requires nor is affected by a module-level variable.
    pending = []

    # The context manager closes the file even if a write or the recursion
    # raises part-way through.
    with open('tree.py', 'a') as out:
        out.write('def ' + mx_name + '(x):' + '\n')

        def recurse(node, depth):
            indent = " " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # Internal node: emit the test, then both subtrees.
                name = feature_name[node]
                threshold = str(tree_.threshold[node])
                out.write(indent + "if x['" + name + "'] <= " + threshold + ':' + '\n')
                pending.append("'" + name + '_' + '<=' + threshold + "'")
                recurse(tree_.children_left[node], depth + 1)
                out.write(indent + "else: # if x['" + name + "'] > " + threshold + '\n')
                pending.append("'" + name + '_' + '>' + threshold + "'")
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf: the quoted pieces are written back to back; Python
                # concatenates adjacent string literals in the generated code.
                out.write(indent + 'return ' + "".join(pending) + '\n')
                del pending[:]

        recurse(0, 1)
这样,对于给定的分类树,我在文件'tree.py'上获得以下输出:
# Sample of the code that the generator above writes to tree.py.
def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            # Adjacent string literals are concatenated by Python, so this
            # evaluates to 'V1_<=0.5V2_<=0.5'.
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            # Only the innermost condition is reported here; the enclosing
            # 'V1_<=0.5' condition is absent.
            return 'V2_>0.5'
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'
虽然我可以在 if 分支上累积条件并返回拼接后的结果,但当 if 之后跟着 else(即树节点的左/右分支)时,之前累积的条件会丢失。我期望得到的输出如下:
# Desired generated code: every leaf returns all conditions on its path
# from the root, not just the most recent ones.
def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            # Adjacent string literals are concatenated by Python, so this
            # evaluates to 'V1_<=0.5V2_<=0.5'.
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V1_<=0.5''V2_>0.5' # 'V1<=0.5' must be added
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'
我将不胜感激任何建议。
由于每个节点的左右两侧都会被递归遍历,我增加了一个额外的变量 names_list,用来保存进入每个节点时已经累积的条件;在处理右分支时,再把保存的内容拼接回变量 col_name:
# Module-level state used by the recursive writer inside mxTreeToCode.
col_name = ""    # quoted conditions accumulated along the branch being written
names_list = {}  # node id -> value of col_name when that node was entered


def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):
    """Append Python source code for a fitted decision tree to ``tree.py``.

    The generated function takes one argument ``x`` that is indexed by
    feature name (``x['V1']``). Each leaf returns the quoted conditions of
    every test on the path from the root to that leaf, written back to back
    so that Python concatenates them into a single string.

    The module-level ``col_name`` and ``names_list`` are updated as a side
    effect: ``col_name`` is empty again when the call returns, and
    ``names_list`` maps every visited node id to the conditions accumulated
    before that node.

    Args:
        tree: fitted estimator exposing a ``tree_`` attribute with
            ``feature``, ``threshold``, ``children_left`` and
            ``children_right`` arrays.
        feature_names: sequence mapping feature index to column name.
        mx_name: name of the function written to the file.
        rm_file: when true, delete an existing ``./tree.py`` first;
            otherwise the new function is appended to it.
    """
    global col_name

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            # Best effort: a missing file is exactly the state we want.
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    # Start from a clean path so that a previous call that aborted half-way
    # cannot leak its partial conditions into this output.
    col_name = ""

    # The context manager closes the file even if a write or the recursion
    # raises part-way through.
    with open('tree.py', 'a') as out:
        out.write('def ' + mx_name + '(x):' + '\n')

        def recurse(node, depth):
            global col_name, names_list
            indent = " " * depth
            # Remember the path leading to this node so the right-hand branch
            # can be rebuilt after the left subtree has been written.
            names_list[node] = col_name
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = str(tree_.threshold[node])
                out.write(indent + "if x['" + name + "'] <= " + threshold + ':' + '\n')
                col_name += "'" + name + '_' + '<=' + threshold + "'"
                recurse(tree_.children_left[node], depth + 1)
                out.write(indent + "else: # if x['" + name + "'] > " + threshold + '\n')
                # Restore the saved prefix explicitly instead of relying on
                # col_name having been emptied by the last leaf on the left.
                col_name = names_list[node] + "'" + name + '_' + '>' + threshold + "'"
                recurse(tree_.children_right[node], depth + 1)
            else:
                out.write(indent + 'return ' + str(col_name) + '\n')
                # The path was consumed by this leaf; ancestors re-add their
                # saved prefix before descending to the right.
                col_name = ""

        recurse(0, 1)
我想知道是否还有其他可行的方法。