Opencv2.4.9源码分析——GradientBoostedTrees详解
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
Opencv2.4.9源码分析——Gradient
Boosted Trees
一、原理
梯度提升树(GBT,Gradient Boosted Trees,或称为梯度提升决策树)算法是由Friedman 于1999年首次完整的提出,该算法可以实现回归、分类和排序。GBT的优点是特征属性无需进行归一化处理,预测速度快,可以应用不同的损失函数等。
从它的名字就可以看出,GBT包括三个机器学习的优化算法:决策树方法、提升方法和梯度下降法。前两种算法在我以前的文章中都有详细的介绍,在这里我只做简单描述。
决策树是一个由根节点、中间节点、叶节点和分支构成的树状模型,分支代表着数据的走向,中间节点包含着训练时产生的分叉决策准则,叶节点代表着最终的数据分类结果或回归值,在预测的过程中,数据从根节点出发,沿着分支在到达中间节点时,根据该节点的决策准则实现分叉,最终到达叶节点,完成分类或回归。
提升算法是由一系列“弱学习器”构成,这些弱学习器通过某种线性组合实现一个强学习器,虽然这些弱学习器的分类或回归效果可能仅仅比随机分类或回归要好一点,但最终的强学习器却可以得到一个很好的预测结果。
二、源码分析
下面介绍OpenCV的GBT源码。
首先给出GBT算法所需参数的结构体CvGBTreesParams:
[cpp] view plain copy 在CODE上查看代码片派生到我的代码片CvGBTreesParams::CvGBTreesParams( int _loss_function_type, int _weak_count,
float _shrinkage, float _subsample_portion,
int _max_depth, bool _use_surrogates )
: CvDTreeParams( 3, 10, 0, false, 10, 0, false, false, 0 )
{
loss_function_type = _loss_function_type;
weak_count = _weak_count;
shrinkage = _shrinkage;
subsample_portion = _subsample_portion;
max_depth = _max_depth;
use_surrogates = _use_surrogates;
}
loss_function_type表示损失函数的类型,CvGBTrees::SQUARED_LOSS为平方损失函数,CvGBTrees::ABSOLUTE_LOSS为绝对值损失函数,CvGBTrees::HUBER_LOSS为Huber损失函数,CvGBTrees::DEVIANCE_LOSS为偏差损失函数,前三种用于回归问题,后一种用于分类问题
weak_count表示GBT的优化迭代次数,对于回归问题来说,weak_count也就是决策树的数量,对于分类问题来说,weak_count×K为决策树的数量,K表示类别数量
shrinkage表示收缩因子v
subsample_portion表示训练样本占全部样本的比例,为不大于1的正数
max_depth表示决策树的最大深度
use_surrogates表示是否使用替代分叉节点,为true,表示使用替代分叉节点CvDTreeParams结构详见我的关于决策树的文章
CvGBTrees类的一个构造函数:
[cpp] view plain copy 在CODE上查看代码片派生到我的代码片
CvGBTrees::CvGBTrees( const cv::Mat& trainData, int tflag,
const cv::Mat& responses, const cv::Mat& varIdx,
const cv::Mat& sampleIdx, const cv::Mat& varType,
const cv::Mat& missingDataMask,
CvGBTreesParams _params )
{
data = 0; //表示样本数据集合
weak = 0; //表示一个弱学习器
default_model_name = "my_boost_tree";
// orig_response表示样本的响应值,sum_response表示拟合函数Fm(x),sum_response_tmp表示Fm+1(x)
orig_response = sum_response = sum_response_tmp = 0;
// subsample_train和subsample_test分别表示训练样本集和测试样本集
subsample_train = subsample_test = 0;
// missing表示缺失的特征属性,sample_idx表示真正用到的样本的索引
missing = sample_idx = 0;
class_labels = 0; //表示类别标签
class_count = 1; //表示类别的数量
delta = 0.0f; //表示Huber损失函数中的参数δ
clear(); //清除一些全局变量和已有的所有弱学习器
//GBT算法的学习
train(trainData, tflag, responses, varIdx, sampleIdx, varType, missingDataMask, _params, false);
}
GBT算法的学习构建函数:
[cpp] view plain copy 在CODE上查看代码片派生到我的代码片
bool