朴素贝叶斯matlab实现
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
clc
clear
close all
data=importdata('data.txt');
wholeData=data.data;
%交叉验证选取训练集和测试集
cv=cvpartition(size(wholeData,1),'holdout',0.04);%0.04表明测试数据集占总数据集的比例
cvpartition(n,'holdout',p)创建一个随机分区,用于在n个观测值上进行保持验证。该分区将观察分为训练集和测试(或保持)集。参数p必须是标量,当0
trainData=wholeData(training(cv),:);
testData=wholeData(test(cv),:);
label=data.textdata;
attributeNumber=size(trainData,2);size(A,2):获取矩阵A的列数。attributeValueNumber=5;
%将分类标签转化为数据(因为在分类数据集中有3个类别,分别是R、B、L所以将类别转换为数字)
sampleNumber=size(label,1);
labelData=zeros(sampleNumber,1);
for i=1:sampleNumber(测试集的行数)
if label{i,1}=='R'
labelData(i,1)=1;
elseif label{i,1}=='B'
labelData(i,1)=2;
else
labelData(i,1)=3;
end
end
trainLabel=labelData(training(cv),:);
trainSampleNumber=size(trainLabel,1);
testLabel=labelData(test(cv),:);
%计算每个分类的样本的概率
labelProbability=tabulate(trainLabel);
tabulate函数的功能是创建向量X信息数据频率表。其函数使用格式:
tbl = tabulate(x)
创建的TBL(数据频率表)的结构:第一列:x的唯一值第二列:每个值的实例数量第三列:每个值的百分比
%P_yi,计算P(yi)
P_y1=labelProbability(1,3)/100;(第一行,第三个元素)
P_y2=labelProbability(2,3)/100;
P_y3=labelProbability(3,3)/100;
count_1=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=1情况下,第i个属性取j值的数量统计
count_2=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=2情况下,第i个属性取j值的数量统计
count_3=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=3情况下,第i个属性取j值的数量统计
%统计每一个特征的每个取值的数量
for jj=1:3
for j=1:trainSampleNumber
for ii=1:attributeNumber
for k=1:attributeValueNumber
if jj==1
if trainLabel(j,1)==1&&trainData(j,ii)==k
count_1(ii,k)=count_1(ii,k)+1;
end
elseif jj==2
if trainLabel(j,1)==2&&trainData(j,ii)==k
count_2(ii,k)=count_2(ii,k)+1;
end
else
if trainLabel(j,1)==3&&trainData(j,ii)==k
count_3(ii,k)=count_3(ii,k)+1;
end
end
end
end
end
end
%计算第i个属性取j值的概率,P_a_y1是分类为y=1前提下取值,其他依次类推。P_a_y1=count_1./labelProbability(1,2);
P_a_y2=count_2./labelProbability(2,2);
P_a_y3=count_3./labelProbability(3,2);
%使用测试集进行数据测试
labelPredictNumber=zeros(3,1);
predictLabel=zeros(size(testData,1),1);
for kk=1:size(testData,1)
testDataTemp=testData(kk,:);
Pxy1=1;
Pxy2=1;
Pxy3=1;
%计算P(x|yi)
for iii=1:attributeNumber
Pxy1=Pxy1*P_a_y1(iii,testDataTemp(iii));
Pxy2=Pxy2*P_a_y2(iii,testDataTemp(iii));
Pxy3=Pxy3*P_a_y3(iii,testDataTemp(iii));
end
%计算P(x|yi)*P(yi)
PxyPy1=P_y1*Pxy1;
PxyPy2=P_y2*Pxy2;
PxyPy3=P_y3*Pxy3;
if PxyPy1>PxyPy2&&PxyPy1>PxyPy3
predictLabel(kk,1)=1;
disp(['this item belongs to No.',num2str(1),' label or the R labe l'])
labelPredictNumber(1,1)=labelPredictNumber(1,1)+1;
elseif PxyPy2>PxyPy1&&PxyPy2>PxyPy3
predictLabel(kk,1)=2;
labelPredictNumber(2,1)=labelPredictNumber(2,1)+1;
disp(['this item belongs to No.',num2str(2),' label or the B labe l'])
elseif PxyPy3>PxyPy2&&PxyPy3>PxyPy1