决策树算法总结
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
决策树研发二部
目录
1. 算法介绍 (1)
1.1.分支节点选取 (1)
1.2.构建树 (3)
1.3.剪枝 (10)
2. sk-learn中的使用 (12)
3. sk-learn中源码分析 (13)
1.算法介绍
决策树算法是机器学习中的经典算法之一,既可以作为分类算法,也可以作为回归算法。决策树算法又被发展出很多不同的版本,按照时间上分,目前主要包括,ID3、C4.5和CART版本算法。其中ID3版本的决策树算法是最早出现的,可以用来做分类算法。C4.5是针对ID3的不足出现的优化版本,也用来做分类。CART也是针对ID3优化出现的,既可以做分类,可以做回归。
决策树算法的本质其实很类似我们的if-elseif-else语句,通过条件作为分支依据,最终的数学模型就是一颗树。不过在决策树算法中我们需要重点考虑选取分支条件的理由,以及谁先判断谁后判断,包括最后对过拟合的处理,也就是剪枝。这是我们之前写if语句时不会考虑的问题。
决策树算法主要分为以下3个步骤:
1.分支节点选取
2.构建树
3.剪枝
1.1.分支节点选取
分支节点选取,也就是寻找分支节点的最优解。既然要寻找最优,那么必须要有一个衡量标准,也就是需要量化这个优劣性。常用的衡量指标有熵和基尼系数。
熵:熵用来表示信息的混乱程度,值越大表示越混乱,包含的信息量也就越多。比如,A班有10个男生1个女生,B班有5个男生5个女生,那么B班的熵值就比A班大,也就是B班信息越混乱。
基尼系数:同上,也可以作为信息混乱程度的衡量指标。
有了量化指标后,就可以衡量使用某个分支条件前后,信息混乱程度的收敛效果了。使用分支前的混乱程度,减去分支后的混乱程度,结果越大,表示效果越好。
#计算熵值
def entropy(dataSet):
tNum = len(dataSet)
print(tNum)
#用来保存标签对应的个数的,比如,男:6,女:5
labels = {}
for node in dataSet:
curL = node[-1] #获取标签
if curL not in labels.keys():
labels[curL] = 0 #如果没有记录过该种标签,就记录并初始化为0 labels[curL] += 1 #将标签记录个数加1
#此时labels中保存了所有标签和对应的个数
res = 0
#计算公式为-p*logp,p为标签出现概率
for node in labels:
p = float(labels[node]) / tNum
res -= p * log(p, 2)
return res
#计算基尼系数
def gini(dataSet):
tNum = len(dataSet)
print(tNum)
# 用来保存标签对应的个数的,比如,男:6,女:5
labels = {}
for node in dataSet:
curL = node[-1] # 获取标签
if curL not in labels.keys():
labels[curL] = 0 # 如果没有记录过该种标签,就记录并初始化为0 labels[curL] += 1 # 将标签记录个数加1
# 此时labels中保存了所有标签和对应的个数
res = 1
# 计算公式为-p*logp,p为标签出现概率
for node in labels:
p = float (labels[node]) / tNum res -= p * p return res
1.2. 构建树
ID3算法:利用信息熵增益,决定选取哪个特征作为分支节点。分支前的总样本熵值-分支后的熵值总和=信息熵增益。
T1的信息熵增益:1 – 13/20*0.961 - 7/20*0.863 = 0.073 T2的信息熵增益:1 – 12/20*0.812 - 8/20*0.544 = 0.295 所以使用T2作为分支特征更优。
ID3算法建树:
依据前面的逻辑,递归寻找最优分支节点,直到下面情况结束 1. 叶节点已经属于同一标签
2. 虽然叶节点不属于同一标签,但是特征已经用完了
3. 熵小于预先设置的阈值
4. 树的深度达到了预先设置的阈值 ID3算法的不足:
1.取值多的特征比取值少的特征更容易被选取。
2.不包含剪枝操作,过拟合严重
3.特征取值必须是离散的,或者有限的区间的。
于是有了改进算法C4.5
C4.5算法:基于ID3算法进行了改进,首先,针对ID3的不足1,采用信息增益率取代ID3中使用信息增益而造成的偏向于选取取值较多的特征作为分裂点的问题。针对ID3的不足2,采用剪枝操作,缓解过拟合问题。针对ID3的不足3,采用将连续值先排列,然后逐个尝试分裂,找到连续值中的最佳分裂点。
信息增益率的计算:先计算信息增益,然后除以spliteInfo。spliteInfo为分裂后的子集合的函数,假设分裂后的子集合个数为sub1和sub2,total为分裂前的个数。spliteInfo = -sub1 / total * log(sub1 / total) – sub2 / total * log(sub2 / total)
#index:特征序号
#value:特征值
#该方法表示将index对应特征的值为value的集合返回,返回集合中不包含index对应的特征
def spliteDataSet(dataSet, index, value):
newDataSet = []
for node in dataSet:
if node[index] == value:
#[0,index)列的数据
newData = node[:index]
#[index+1,最后]列的数据
newData.extend(node[index + 1:])
newDataSet.append(newData)
return newDataSet;
#选择最优分裂项
def chooseBestFeature(dataSet):
#特征个数
featureNum = len(dataSet[0]) - 1
#计算整体样本的熵值
baseEntropy = entropy(dataSet)
print("baseEntropy = %f"%(baseEntropy))
#保存最大的信息增益率