在分类树上递归期间的累积条件

问题描述 投票:-1回答:1

我具有以下功能,该功能从sci-kit学习分类树生成代码:

def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    file = open('tree.py', 'a')
    file.write('def ' + mx_name + '(x):'+ '\n') 
    #col_name = ''
    def recurse(node, depth):
        global col_name
        indent = "    " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]

            file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
            col_name += "'"+name + '_' + '<=' + str(threshold) +"'"

            recurse(tree_.children_left[node], depth + 1)


            file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
            col_name += "'"+name + '_' + '>' + str(threshold) +"'"

            recurse(tree_.children_right[node], depth + 1)


        else:
            file.write(indent + 'return '+str(col_name) + '\n')
            #print(col_name)
            col_name = ""

    recurse(0, 1)
    file.close()

这样,对于给定的分类树,我在文件'tree.py'上获得以下输出:

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V2_>0.5'
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'

虽然我可以在IF端累积条件并返回条件的加法,但是当IF和ELSE(树节点的左侧/右侧)跟随时,我无法进行累加:

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V1_<=0.5''V2_>0.5' # 'V1<=0.5' must be added
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'

我将不胜感激任何建议。

python recursion scikit-learn tree classification
1个回答
0
投票

由于每个节点的左侧/右侧都同时递归,所以我刚刚创建了一个额外的变量,该变量保存了每一侧的输出。最后,我连接到变量col_name:

col_name = ""
names_list={}
def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    file = open('tree.py', 'a')
    file.write('def ' + mx_name + '(x):'+ '\n') 

    def recurse(node, depth):
        global col_name, names_list
        indent = "    " * depth
        names_list[node] = col_name
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]

            file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
            col_name += "'"+name + '_' + '<=' + str(threshold) +"'"

            recurse(tree_.children_left[node], depth + 1)


            file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
            col_name += names_list[node]
            col_name += "'"+name + '_' + '>' + str(threshold) +"'"

            recurse(tree_.children_right[node], depth + 1)


        else:
            file.write(indent + 'return '+str(col_name) + '\n')
            col_name = ""

    recurse(0, 1)
    file.close()

我想知道是否还有其他可行的方法。

© www.soinside.com 2019 - 2024. All rights reserved.