Extracting feature split nodes from a scikit-learn decision tree

I have a decision tree trained on a dataset with seven features. I'd like an easy way to extract, for each node, which feature the data is split on and which threshold (or category) is used as the split criterion.

Here is the small template I used to generate the tree:

from sklearn.tree import DecisionTreeClassifier

clfSCReduced = DecisionTreeClassifier()
clfSCReduced.fit(featuresSC, labelsSC)

Any suggestions?


holdtom
2 Answers

猛跑小猪

From the DecisionTreeClassifier documentation: you can get the fitted Tree object as clfSCReduced.tree_. The scikit-learn docs have an example ("Understanding the decision tree structure") of how to get this information out of the tree. The example produces the following output:

The binary tree structure has 5 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 3] <= 0.800000011920929 else to node 2.
	node=1 leaf node.
	node=2 test node: go to node 3 if X[:, 2] <= 4.950000047683716 else to node 4.
		node=3 leaf node.
		node=4 leaf node.

Rules used to predict sample 0:
decision id node 0 : (X_test[0, 3] (= 2.4) > 0.800000011920929)
decision id node 2 : (X_test[0, 2] (= 5.1) > 4.950000047683716)

The following samples [0, 1] share the node [0 2] in the tree
It is 40.0 % of all nodes.

I have copied their example below for completeness:

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes, resp. In this
# case the values of nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - left_child, id of the left child of the node
#   - right_child, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node
#
# Using those arrays, we can parse the tree structure:

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()

# First let's retrieve the decision path of each sample. The decision_path
# method allows to retrieve the node indicator functions. A non zero element of
# indicator matrix at the position (i, j) indicates that the sample i goes
# through the node j.

node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaves ids reached by each sample.

leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] == node_id:
        continue

    if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision id node %s : (X_test[%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))

# For a group of samples, we have the following common node.
sample_ids = [0, 1]
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))

common_node_id = np.arange(n_nodes)[common_nodes]

print("\nThe following samples %s share the node %s in the tree"
      % (sample_ids, common_node_id))
print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))
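For reference, here is a minimal sketch of just the extraction step the question asks about: collecting every split node together with its feature index and threshold. The iris setup and the names clf, tree, and split_ids are my own illustration, not part of the answer above; substitute your fitted clfSCReduced.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for the question's clfSCReduced, fitted on iris.
iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
tree = clf.tree_

# Split (test) nodes are exactly the nodes whose two children differ;
# for leaf nodes, children_left[i] == children_right[i] == -1.
split_ids = np.where(tree.children_left != tree.children_right)[0]

for node_id in split_ids:
    print("node %d splits on feature %d at threshold %f"
          % (node_id, tree.feature[node_id], tree.threshold[node_id]))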

慕斯王

I'll give the main idea here as well. The following code is from the sklearn docs, with a few small changes to achieve your goal.

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes, resp. In this
# case the values of nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - left_child, id of the left child of the node
#   - right_child, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print("\n")

# First let's retrieve the decision path of each sample. The decision_path
# method allows to retrieve the node indicator functions. A non zero element of
# indicator matrix at the position (i, j) indicates that the sample i goes
# through the node j.

node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaves ids reached by each sample.

leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.

# HERE IS WHAT YOU WANT
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        # continue  # <-- commented out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id]))  # <-- added
    else:  # <-- added else to iterate through the decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]],  # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

This prints the following at the end:

Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011920929)
decision id node 2 : (X[0, 2] (= 5.1) > 4.950000047683716)
leaf node 4 reached, no decision here
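As a shorter alternative not mentioned in either answer: scikit-learn 0.21+ ships sklearn.tree.export_text, which prints every split (feature name plus threshold) in one call. A sketch, assuming the same iris estimator as in the answers above:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0).fit(iris.data, iris.target)

# Each test node is rendered as "feature <= threshold", indented by depth.
print(export_text(estimator, feature_names=iris.feature_names))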
