Readers often have trouble finding good material on implementing the ID3 decision tree algorithm in Python. This tutorial collects a complete working example: Shannon entropy calculation, dataset splitting for both discrete and continuous features, recursive tree construction driven by information gain, and plotting the finished tree with matplotlib.
The complete example is as follows:
# -*- coding:utf-8 -*-
from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # Build a dictionary of counts for every possible class
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Compute the Shannon entropy using base-2 logarithms
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
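For reference, the quantity computed by calcShannonEnt is the information entropy that drives every split in ID3, as defined in Zhou Zhihua's Machine Learning (listed in the references below):

$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$

where $p_k$ is the proportion of samples in $D$ belonging to class $k$. For example, a node holding 8 positive and 9 negative samples has entropy $-\frac{8}{17}\log_2\frac{8}{17} - \frac{9}{17}\log_2\frac{9}{17} \approx 0.998$.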
# Split a data set on a discrete feature: keep every sample whose value
# on the given axis equals value, with that column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
# Split a data set on a continuous feature. direction selects which side
# of the threshold to keep: 0 keeps samples with value > threshold,
# 1 keeps samples with value <= threshold
def splitContinuousDataSet(dataSet, axis, value, direction):
    retDataSet = []
    for featVec in dataSet:
        if direction == 0:
            if featVec[axis] > value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis] <= value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet
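As a quick illustration of the two split helpers (the toy rows below are invented for this example, not taken from the watermelon data):

rows = [['distinct', 0.46, 1], ['blur', 0.24, 0], ['distinct', 0.36, 1]]
splitDataSet(rows, 0, 'distinct')        # -> [[0.46, 1], [0.36, 1]]
splitContinuousDataSet(rows, 1, 0.3, 0)  # value > 0.3  -> [['distinct', 1], ['distinct', 1]]
splitContinuousDataSet(rows, 1, 0.3, 1)  # value <= 0.3 -> [['blur', 0]]

In both helpers the tested column is removed from the returned samples, which is why the recursion can always treat the last column as the class label.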
# Choose the best feature to split the data set on
def chooseBestFeatureToSplit(dataSet, labels):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Handle continuous features
        if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
            # Generate n-1 candidate split points (midpoints of adjacent sorted values)
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j+1]) / 2.0)
            bestSplitEntropy = 10000
            slen = len(splitList)
            # Compute the entropy of the split induced by the j-th candidate
            # point, and keep track of the best split point
            for j in range(slen):
                value = splitList[j]
                newEntropy = 0.0
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                newEntropy += prob0 * calcShannonEnt(subDataSet0)
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy += prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            # Record the best split point for the current feature in a dictionary
            bestSplitDict[labels[i]] = splitList[bestSplit]
            infoGain = baseEntropy - bestSplitEntropy
        # Handle discrete features
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            # Compute the entropy of each partition induced by this feature
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # If the best feature is continuous, binarize it at the recorded split
    # point, i.e. replace each value with whether it is <= bestSplitValue
    if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature] <= bestSplitValue:
                dataSet[i][bestFeature] = 1
            else:
                dataSet[i][bestFeature] = 0
    return bestFeature
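The candidate points above implement the bi-partition strategy for continuous attributes described in Zhou Zhihua's Machine Learning: for the sorted values $a^1, \dots, a^n$ of a feature, the $n-1$ midpoints $t_j = (a^j + a^{j+1})/2$ are tried, and the feature's information gain is the best gain over all candidates,

$\mathrm{Gain}(D, a) = \max_{t}\left(\mathrm{Ent}(D) - \frac{|D_t^{-}|}{|D|}\mathrm{Ent}(D_t^{-}) - \frac{|D_t^{+}|}{|D|}\mathrm{Ent}(D_t^{+})\right)$

where $D_t^{-}$ and $D_t^{+}$ are the samples with value $\le t$ and $> t$ respectively. Minimizing the weighted entropy (bestSplitEntropy) over candidates is equivalent to maximizing this gain, since $\mathrm{Ent}(D)$ is constant for the node.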
# If all features have been used up but the samples at a node still have
# mixed class values, fall back to a majority vote
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # Return the class with the highest count
    return max(classCount, key=classCount.get)
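A quick check of the voting helper:

majorityCnt([1, 0, 1, 1])  # counts are {1: 3, 0: 1} -> returns 1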
# Main routine: recursively grow the decision tree
def createTree(dataSet, labels, data_full, labels_full):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    del(labels[bestFeat])
    # Grow one subtree for each value of bestFeat
    for value in uniqueVals:
        subLabels = labels[:]
        if type(dataSet[0][bestFeat]).__name__ == 'str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, data_full, labels_full)
    # Feature values that occur in the full data set but not in this subset
    # get a leaf decided by majority vote over the current node's samples
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    return myTree
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# Count the number of leaf nodes in the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
# Compute the maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
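Applied to the tree printed at the end of this post, these two helpers give the values that createPlot later uses to scale the layout:

tree = {'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}
getNumLeafs(tree)   # -> 5
getTreeDepth(tree)  # -> 2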
# Draw a node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)
# Draw the text on an arrow
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# Load the data: drop the first column (an index), and take the last
# column as the class label
df = pd.read_csv('watermelon_4_3.csv')
data = df.values[:, 1:].tolist()
data_full = data[:]
labels = df.columns.values[1:-1].tolist()
labels_full = labels[:]
myTree = createTree(data, labels, data_full, labels_full)
print(myTree)
createPlot(myTree)
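The watermelon_4_3.csv file itself is not included with the post. If you want to exercise the code without it, a hypothetical stand-in with the same layout can be built in memory; the column names and rows below are only illustrative, inferred from the printed tree:

# Hypothetical stand-in for watermelon_4_3.csv: index column first,
# feature columns (discrete or continuous) in the middle, class label last
df = pd.DataFrame({
    'id':      [1, 2, 3, 4],
    'texture': ['distinct', 'distinct', 'little_blur', 'blur'],
    'touch':   ['hard_smooth', 'soft_stick', 'soft_stick', 'hard_smooth'],
    'density': [0.697, 0.608, 0.481, 0.243],
    'label':   [1, 1, 1, 0],
})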
Running the script prints the following tree:
{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}
The decision tree rendered by createPlot(myTree) appears as a figure in the original post, with texture at the root and touch and the binarized density test as internal nodes.
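The nested-dict tree can also be queried directly. The post stops at plotting, but a minimal classifier sketch in the style of Machine Learning in Action (the classify helper below is an addition for illustration, not part of the original listing) looks like this:

# Minimal classifier sketch: walk the tree, looking up the test vector's
# value for each node's feature until a leaf is reached
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat

# Usage sketch: the continuous feature was renamed during training
# ('density<=0.38149999999999995'), so the matching entry of the test
# vector must already be the 0/1 outcome of that comparison
# classify(myTree, ['texture', 'touch', 'density<=0.38149999999999995'],
#          ['distinct', 'hard_smooth', 0])   # -> 1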
References:
Machine Learning in Action (《机器学习实战》)
Machine Learning, Zhou Zhihua (《机器学习》周志华著)
This concludes the walkthrough of the ID3 decision tree implementation in Python. We hope it serves as a useful reference.