将回归树输出转换为 pandas 表

这段代码用 Python 拟合一棵回归树。我想把这个基于文本的输出转换成表格格式。

import pandas as pd

import numpy as np

from sklearn.tree import DecisionTreeRegressor

from sklearn import tree


# Training data: one row per game genre -> [genre name, feature value, target].
# The genre strings make NumPy hold every cell as a string, which is why the
# numeric columns are cast back with astype(int) below.
dataset = np.array([
    ["Asset Flip", 100, 1000],
    ["Text Based", 500, 3000],
    ["Visual Novel", 1500, 5000],
    ["2D Pixel Art", 3500, 8000],
    ["2D Vector Art", 5000, 6500],
    ["Strategy", 6000, 7000],
    ["First Person Shooter", 8000, 15000],
    ["Simulator", 9500, 20000],
    ["Racing", 12000, 21000],
    ["RPG", 14000, 25000],
    ["Sandbox", 15500, 27000],
    ["Open-World", 16500, 30000],
    ["MMOFPS", 25000, 52000],
    ["MMORPG", 30000, 80000],
])

# Feature matrix keeps a 2-D shape (one column) via list indexing;
# the target is the last column.
X = dataset[:, [1]].astype(int)
y = dataset[:, -1].astype(int)

# fit() returns the estimator itself, so construction and fitting chain.
regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# Render the learned splits as indented text and show them.
text_rule = tree.export_text(regressor)
print(text_rule)

我得到的输出如下:


# Print the text rendering of the fitted tree; its output is shown below.
print(text_rule)

|--- feature_0 <= 20750.00

|   |--- feature_0 <= 7000.00

|   |   |--- feature_0 <= 1000.00

|   |   |   |--- feature_0 <= 300.00

|   |   |   |   |--- value: [1000.00]

|   |   |   |--- feature_0 >  300.00

|   |   |   |   |--- value: [3000.00]

|   |   |--- feature_0 >  1000.00

|   |   |   |--- feature_0 <= 2500.00

|   |   |   |   |--- value: [5000.00]

|   |   |   |--- feature_0 >  2500.00

|   |   |   |   |--- feature_0 <= 4250.00

|   |   |   |   |   |--- value: [8000.00]

|   |   |   |   |--- feature_0 >  4250.00

|   |   |   |   |   |--- feature_0 <= 5500.00

|   |   |   |   |   |   |--- value: [6500.00]

|   |   |   |   |   |--- feature_0 >  5500.00

|   |   |   |   |   |   |--- value: [7000.00]

|   |--- feature_0 >  7000.00

|   |   |--- feature_0 <= 13000.00

|   |   |   |--- feature_0 <= 8750.00

|   |   |   |   |--- value: [15000.00]

|   |   |   |--- feature_0 >  8750.00

我想把这些规则转换成 pandas 表,形式类似下图。该怎么做?

http://img2.sycdn.imooc.com/64b6285100013c9c06470132.jpg

规则的图形(plot)版本如下,供参考。请注意,在表中我只显示了规则最左边的部分。

http://img4.sycdn.imooc.com/64b628610001ed6111510717.jpg


月关宝盒
浏览 99回答 2
2回答

蓝山帝景

import sklearn.tree
import pandas as pd


def tree_to_df(reg_tree, feature_names):
    """Flatten a fitted scikit-learn tree into a pandas DataFrame of rules.

    Leaves are visited depth-first, left branch before right branch. Each
    leaf produces one row:

    * the numbered columns hold the split conditions added since the
      previous leaf (conditions already shown in an earlier row are not
      repeated), padded with ``None``;
    * the ``Return`` column holds ``"return <value>"``, where ``<value>`` is
      ``tree_.value[leaf][0][0]``.

    Parameters
    ----------
    reg_tree : fitted estimator exposing ``tree_`` (e.g. DecisionTreeRegressor)
    feature_names : sequence of str
        Display name for each feature index the tree splits on.

    Returns
    -------
    pandas.DataFrame
        One row per leaf. A tree consisting of a single leaf yields one row
        that has only the ``Return`` column (previously that row was dropped).
    """
    tree_ = reg_tree.tree_
    # Leaf nodes carry this sentinel in tree_.feature instead of an index.
    undefined = sklearn.tree._tree.TREE_UNDEFINED

    rows = []     # finished rows: one list of condition strings per leaf
    returns = []  # "return <value>" for each leaf, parallel to rows
    pending = []  # conditions collected since the last leaf was emitted

    def recurse(node):
        feature = tree_.feature[node]
        if feature == undefined:
            # Leaf: close the current row and start collecting a new one.
            rows.append(list(pending))
            returns.append("return " + str(tree_.value[node][0][0]))
            pending.clear()
            return
        name = feature_names[feature]
        threshold = tree_.threshold[node]
        # Add the rule for the left branch, then search it.
        pending.append(name + " <= " + str(threshold))
        recurse(tree_.children_left[node])
        # Add the rule for the right branch, then search it.
        pending.append(name + " > " + str(threshold))
        recurse(tree_.children_right[node])

    recurse(0)
    # Rows of unequal length are padded with None by the DataFrame constructor.
    df = pd.DataFrame(rows)
    df["Return"] = returns
    return df


# This returns a pandas DataFrame. For the regressor in the question (the
# output suggests feature_names=["feature"]) it looks like:
#
#                      0                   1                   2                 3          Return
# 0   feature <= 20750.0   feature <= 7000.0   feature <= 1000.0  feature <= 300.0   return 1000.0
# 1      feature > 300.0                None                None              None   return 3000.0
# 2     feature > 1000.0   feature <= 2500.0                None              None   return 5000.0
# 3     feature > 2500.0   feature <= 4250.0                None              None   return 8000.0
# 4     feature > 4250.0   feature <= 5500.0                None              None   return 6500.0
# 5     feature > 5500.0                None                None              None   return 7000.0
# 6     feature > 7000.0  feature <= 13000.0   feature <= 8750.0              None  return 15000.0
# 7     feature > 8750.0  feature <= 10750.0                None              None  return 20000.0
# 8    feature > 10750.0                None                None              None  return 21000.0
# 9    feature > 13000.0  feature <= 16000.0  feature <= 14750.0              None  return 25000.0
# 10   feature > 14750.0                None                None              None  return 27000.0
# 11   feature > 16000.0                None                None              None  return 30000.0
# 12   feature > 20750.0  feature <= 27500.0                None              None  return 52000.0
# 13   feature > 27500.0                None                None              None  return 80000.0

慕姐4208626

# If you are working with a classification decision tree, you can try this:
import pandas as pd

text = """|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1"""


def tree_parser(text, leaf_prefixes=("class:", "weights:", "value:", "truncated branch")):
    """Convert ``sklearn.tree.export_text`` output into equal-length column lists.

    Every leaf line (a line whose text starts with one of *leaf_prefixes*)
    produces one row. For a leaf at depth ``d``:

    * ``Column<i>`` for ``i < d`` holds the most recent line seen at depth
      ``i``, which is the condition on the path to that leaf;
    * ``Column<d>`` holds the leaf text itself;
    * ``Column<i>`` for ``i > d`` holds ``None``.

    Depth is the number of ``|`` characters before the ``---`` marker.
    ``Column0`` (depth 0, never used by export_text) is still produced so
    callers that drop it keep working.

    Parameters
    ----------
    text : str
        export_text output, one tree node per line. Blank lines are ignored.
    leaf_prefixes : tuple of str
        Node texts starting with any of these are treated as leaves.

    Returns
    -------
    dict
        ``{"Column0": [...], ..., "ColumnN": [...]}``. All lists have one
        entry per leaf, so the dict can be passed straight to ``pd.DataFrame``.
    """
    parsed = []
    for line in text.splitlines():
        if not line.strip():
            continue  # blank lines carry no node
        depth = line.split("---", 1)[0].count("|")
        # Remove only the drawing prefix "|   |--- ". Dashes inside the
        # condition survive, so negative thresholds like "-0.16" stay intact.
        node = line.lstrip("| ").lstrip("-").strip()
        parsed.append((depth, node))

    max_level = max((depth for depth, _ in parsed), default=0)
    result = {"Column" + str(i): [] for i in range(max_level + 1)}
    path = {}  # depth -> most recent node text seen at that depth
    for depth, node in parsed:
        path[depth] = node
        if not node.startswith(leaf_prefixes):
            continue  # internal nodes only update the current path
        for i in range(max_level + 1):
            result["Column" + str(i)].append(path.get(i) if i <= depth else None)
    return result


result = tree_parser(text)
df = pd.DataFrame(result)
df = df.drop(columns=["Column0"])
df.to_csv("treeout1.csv", index=False)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python