继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

决策树简单代码

Coder_zheng
关注TA
已关注
手记 71
粉丝 23
获赞 45
from math import log
import time
import pandas as pd 
import numpy as np 

def createDataSet():
    dataSet =[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no'],]
    labels =['Tree','leaves']
    return dataSet,labels

#计算香农熵
def calcShannonEnt(dataSet):
    numEntries =len(dataSet)
    labelCounts={}
    for feaVec in dataSet:
        currentLabel =feaVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt =0.0
    for key in labelCounts:
        prob =float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt

#去掉已经决策过的属性
def splitDataSet(dataSet,axis,value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec =featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#根据信息增益算法,选取最优属性
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1 #因为数据集的最后一项是标签
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain =0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList =[example[i] for example in dataSet]
        print(featList)
        uniqueVals =set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet =splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy +=prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy -newEntropy
        if infoGain >bestInfoGain:
            bestInfoGain =infoGain
            bestFeature =i
    return bestFeature
#因为我们递归构建决策树是根据属性的消耗进行计算的,所以可能会存在最后属性用完了,但是分类还没有算完,
#这时候就会采用多数表决的方式计算节点分类
def majorityCnt(classList):
    classCount ={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] =0
        classCount[vote]+=1
    return max(classCount)

def createTree(dataSet,labels):
    classList =[example[-1] for example in dataSet]
    if classList.count(classList[0]) ==len(classList): #类别相同则停止划分
        return classList[0]
    if len(dataSet[0])==1:#所有特征已经用完
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel =labels[bestFeat]
    myTree ={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]#为了不改变原始列表的内容复制了一下
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat, value),subLabels)
    return myTree
     
def main():
    data,label =createDataSet()
    t1 =time.clock()
    myTree =createTree(data,label)
    t2 =time.clock()
    print(myTree)
    print('execure time:',t2-t1)


if __name__=='__main__':
    main()
    

运行结果
图片描述

打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP