猿问

在python中从xgboost中提取决策规则

我想在 python 中为我即将推出的模型使用 xgboost。然而,由于我们的生产系统在 SAS 中,我试图从 xgboost 中提取决策规则,然后编写 SAS 评分代码以在 SAS 环境中实现该模型。

上面两个链接对xgboost部署特别是Shiutang-Li给出的代码有很大帮助。但是,我的预测分数并不完全匹配。


以下是我迄今为止尝试过的代码:


import numpy as np

import pandas as pd

import xgboost as xgb

from sklearn.grid_search import GridSearchCV

%matplotlib inline

import graphviz

from graphviz import Digraph


#Read the sample iris data:

iris =pd.read_csv("C:\\Users\\XXXX\\Downloads\\Iris.csv")

#Create dependent variable:

iris.loc[iris["class"] != 2,"class"] = 0

iris.loc[iris["class"] == 2,"class"] = 1


#Select independent and dependent variable:

X = iris[["sepal_length","sepal_width","petal_length","petal_width"]]

Y = iris["class"]


xgdmat = xgb.DMatrix(X, Y) # Create our DMatrix to make XGBoost more efficient


#Build the sample xgboost Model:


our_params = {'eta': 0.1, 'seed':0, 'subsample': 0.8, 'colsample_bytree': 0.8, 

             'objective': 'binary:logistic', 'max_depth':3, 'min_child_weight':1} 

Base_Model = xgb.train(our_params, xgdmat, num_boost_round = 10)


#Below code reads the dump file created by xgboost and writes a scoring code in SAS:


import re

def string_parser(s):

    if len(re.findall(r":leaf=", s)) == 0:

        out  = re.findall(r"[\w.-]+", s)

        tabs = re.findall(r"[\t]+", s)

        if (out[4] == out[8]):

            missing_value_handling = (" or missing(" + out[1] + ")")

        else:

            missing_value_handling = ""


        if len(tabs) > 0:

            return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 

                    '        if state = ' + out[0] + ' then do;\n' +

                    re.findall(r"[\t]+", s)[0].replace('\t', '    ') +

                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +

所以基本上,我想要做的是,将节点号保存在变量“状态”中,并相应地访问叶节点(我从上面链接中提到的 Shiutang-Li 的文章中了解到)。


慕仙森
浏览 337回答 3
3回答

蓝山帝景

我在获得匹配分数方面有类似的经验。我的理解是,除非您修复ntree_limit选项以匹配n_estimators您在模型拟合期间使用的选项,否则评分可能会提前停止。df['score']=&nbsp;xgclfpkl.predict(df[xg_features],&nbsp;ntree_limit=500)开始使用后ntree_limit,我开始获得匹配的分数。

Smart猫小萌

我有类似的经验,需要将 xgboost 评分代码从 R 提取到 SAS。最初,我遇到了与您在这里相同的问题,即在较小的树中,R 和 SAS 的分数没有太大差异,一旦树的数量增加到 100 或更多,我开始观察差异.我做了三件事来缩小差异:确保丢失的组朝着正确的方向前进,您需要明确。否则 SAS 会将缺失值视为所有数字中的最小值。规则应该类似于 SAS 中的以下内容。if sepal_width > 2.95000005 or missing(sepal_width) then state = 1;else state = 2;或者if sepal_width <= 2.95000005 and ~missing(sepal_width) then state = 1;else state = 2;我使用了一个叫做 R 包float来使分数有更多的小数位。&nbsp;as.numeric(float::fl(Quality))确保 SAS 数据与您在 Python 中训练的数据具有相同的形状。希望以上有帮助。

神不在的星期二

几点——首先,正则表达式叶返回值匹配并没有捕捉到垃圾堆里的“E-小数”科学记数法(默认)。显式示例(第二个是正确的修改!)-s = '3:leaf=9.95066429e-09'out = re.findall(r"[\d.-]+", s)out2 = re.findall(r"-?[\d.]+(?:e-?\d+)?", s)out2,out(易于修复但不易发现,因为我的模型中只有一片叶子受到影响!)其次,问题是关于二进制的,但在多类目标中,转储中的每个类都有单独的树,因此您T*C总共有树,其中T是提升轮C数,是类数。对于类c(在 {0,1,...,C-1} 中),您需要评估(并求和)树i*C +c的i = 0,...,T-1. 然后将其 softmax 以匹配来自 xgb 的预测。
随时随地看视频慕课网APP

相关分类

Python
我要回答