ID3 Algorithm Implementation

Constructing the Dataset

def createDataSet():
    """ the dataset for test """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    # name of each feature
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
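
A quick sanity check of the constructor; the printed values simply echo the lists defined above:

myDat, labels = createDataSet()
print(myDat)    # [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(labels)   # ['no surfacing', 'flippers']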


Information Gain

Computing the Shannon Entropy of the Dataset

$Ent(D) = - \sum_k p_k \log_2 p_k$, where $p_k$ is the proportion of samples in $D$ that belong to class $k$.
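
For the toy dataset above (2 'yes' and 3 'no' samples) this works out to

$Ent(D) = -\frac{2}{5}\log_2\frac{2}{5} - \frac{3}{5}\log_2\frac{3}{5} \approx 0.971$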

from collections import defaultdict
from math import log


def calcShannonEnt(dataSet):
    ''' calculate the Shannon entropy, a measure of the disorder of the data.
    dataSet is like [[1, 1, 'yes'], [1, 0, 'no'], ...];
    the last element of each vector is the class label it belongs to,
    which must be a discrete value.
    return the Shannon entropy
    '''
    numEntries = len(dataSet)
    # use a lambda so that missing keys default to zero
    labelCounts = defaultdict(lambda: 0)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # -sum[p * log2(p)]
    return shannonEnt
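
Running it on the toy dataset should reproduce the hand calculation above:

myDat, labels = createDataSet()
print(calcShannonEnt(myDat))  # ≈ 0.971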


Splitting the Dataset by a Given Feature

def splitDataSet(dataSet, axis, value):
    """split the dataset on the given axis (feature index),
    keeping only the samples whose feature equals value
    and removing that feature from each of them.
    return the split data
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:]   # copy
            reducedFeatVec.pop(axis)      # delete the axis feature
            retDataSet.append(reducedFeatVec)
    return retDataSet
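
For example, splitting the toy dataset on feature 0 with value 1 should keep the first three samples and drop that column:

myDat, labels = createDataSet()
print(splitDataSet(myDat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))  # [[1, 'no'], [1, 'no']]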


Choosing the Best Split

def chooseBestFeatureToSplit(dataSet):
    """ choose the best feature according to information gain
    every sample must have the same number of features
    return the index of the chosen feature
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
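
The information gain of a feature $a$ is the drop in entropy after splitting on it, $Gain(D, a) = Ent(D) - \sum_v \frac{|D^v|}{|D|} Ent(D^v)$. On the toy dataset, splitting on 'no surfacing' (feature 0) gives the larger gain, so the function should return index 0:

myDat, labels = createDataSet()
print(chooseBestFeatureToSplit(myDat))  # 0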


Recursively Building the Decision Tree

Stopping Conditions for the Recursion

• If all class labels in the current subset are identical, return that label directly.
• If all features have been used and the samples still cannot be separated, take a majority vote and return the most common label.

import operator


def majorityCnt(classList):
    """ used when the algorithm cannot continue to classify
    because all attributes have been used up;
    the majority class is chosen.
    return the chosen class
    """
    classCount = defaultdict(lambda: 0)
    for vote in classList:
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
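
For example, with one 'yes' vote and two 'no' votes the majority label wins:

print(majorityCnt(['yes', 'no', 'no']))  # no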


Building the Tree

def createTree(dataSet, labels):
    """ recursively build the tree
    delete the label that has been used
    return the tree as a nested dictionary
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # if all data is the same class, return
        return classList[0]
    if len(dataSet[0]) == 1:
        # if cannot continue, return
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # init label
    myTree = {bestFeatLabel: {}}
    # have used the feature
    del(labels[bestFeat])
    # split the data according to feature value
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


Result

myDat, labels = createDataSet()
myTree = createTree(myDat,labels)
print(myTree)
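
On this dataset the call above should print a nested dictionary equivalent to:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}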


Visualization

Drawing Tree Nodes

# draw myTree
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """ bbox: the outline of the text\n
    axes fraction: fraction of the axes from the lower left\n
    xy: the point being annotated\n
    xytext: the coordinate of the text\n
    va: verticalalignment\n
    ha: horizontalalignment
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

def createPlotTest():
    """ facecolor sets the figure background color\n
    subplot(numrows, numcols, fignum)\n
    createPlot.ax1: an attribute attached to the function\n
    clf(): clear the figure
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
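
A quick smoke test of the two node styles. Note that plotNode stores the axes on createPlot.ax1, so the createPlot function defined further below must already exist (e.g. when the whole file has been loaded) before this is called:

createPlotTest()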


Recursively Computing the Tree Depth and the Number of Leaves

def getNumleafs(myTree):
    """ get the number of leaves recursively
    """
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumleafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """ get the depth of the tree
    """
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
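
Applied to the tree built in the Result section above, these should give:

print(getNumleafs(myTree))   # 3
print(getTreeDepth(myTree))  # 2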


Drawing the Decision Tree

def retrieveTree(i):
    """ simple test data
    """
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {
                       0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
                   ]
    return listOfTrees[i]

def plotMidText(cntrPt, parentPt, txtString):
    """plot the text in the middle of two nodes
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    """ recursively plot the tree
    """
    numLeafs = getNumleafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # plot the edge label
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in list(secondDict.keys()):
        # test whether the child is a dictionary; if not, it is a leaf node
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """ main entry of the plot\n
    calculate the size of the graph and set the style
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumleafs(inTree))   # store the width of the tree
    plotTree.totalD = float(getTreeDepth(inTree))  # store the depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
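
Finally, the tree built earlier can be drawn directly; the hard-coded trees from retrieveTree(i) also work as test inputs:

createPlot(myTree)
# createPlot(retrieveTree(1))  # a deeper test tree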