慕容3067478
我相信这个答案比其他答案更正确:from sklearn.tree import _treedef tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)这将打印出有效的Python函数。以下是尝试返回其输入的树的示例输出,该数字介于0和10之间。def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]以下是我在其他答案中看到的一些绊脚石:使用tree_.threshold == -2来决定一个节点是否为叶是不是一个好主意。如果它是一个阈值为-2的真实决策节点怎么办?相反,你应该看看tree.feature或tree.children_*。该行features = [feature_names[i] for i in tree_.feature]与我的sklearn版本崩溃,因为某些值为tree.tree_.feature-2(特别是对于叶节点)。递归函数中不需要多个if语句,只需一个就可以了。
翻翻过去那场雪
我修改了Zelazny7提交的代码来打印一些伪代码:def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)如果您get_code(dt, df.columns)使用相同的示例,您将获得:if ( col1 <= 0.5 ) {return [[ 1. 0.]]} else {if ( col2 <= 4.5 ) {return [[ 0. 1.]]} else {if ( col1 <= 2.5 ) {return [[ 1. 0.]]} else {return [[ 0. 1.]]}}}