鸢尾花数据集----决策树神经网络

合集下载
  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。

鸢尾花数据集——决策树与神经网络
为方便理解两种不同的预测分类算法,我们均调用 sklearn 里 datasets 的鸢尾花数据集。
决策树1(复杂):
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler

# Keep Chinese characters in the plots from rendering as garbage: select the
# SimHei font and disable the Unicode minus sign.
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False

# Data preparation: the loaded data set is dict-like, with the samples
# ('data') and the labels ('target') already separated.
dataset = datasets.load_iris()
# The labels are already numeric (values 0-2), so no further encoding is needed.

data = dataset['data']  # value stored under the key; an array
target = dataset['target']
# input = torch.FloatTensor(dataset['data'])
# y = torch.LongTensor(dataset['target'])

x = np.array(data)
y = np.array(target)
# The data has 150 rows and 4 columns; keep only the first two features so
# the result can be drawn on a 2-D plot.
x = x[:, :2]
# Split the data 70 % / 30 % into training and test sets (fixed seed).
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)

# Two-step pipeline: standardise the features ('ss'), then classify with an
# entropy-criterion decision tree limited to depth 3 ('DTC').
model = Pipeline([
    ('ss', StandardScaler()),
    ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3)),
])
# clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
model = model.fit(x_train, y_train)
y_test_hat = model.predict(x_test)  # predicted labels for the test data
# print(y_test)      true labels of the 45 test samples
# [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2 1 2 1 2 2 0 1 0 1 2 2 0 2 2 1]
# print(y_test_hat)  predicted labels of the 45 test samples
# [0 1 2 0 2 2 2 0 0 2 1 0 2 2 1 0 1 1 0 0 1 0 2 0 2 1 0 0 1 2 1 2 1 2 1 0 1 0 2 2 2 0 1 2 2]


# Save the fitted tree in Graphviz .dot format.
# Convert it to an image with: dot -Tpng -o 1.png 1.dot
# The `with` block closes the file as soon as the export is written, so the
# content is flushed to disk and the handle is not leaked.
with open('.\\iris_tree.dot', 'w') as f:
    # named_steps['DTC'] is the fitted decision-tree step of the pipeline.
    tree.export_graphviz(model.named_steps['DTC'], out_file=f)

# Plotting: colour the plane by predicted class and overlay the samples.
N, M = 100, 100  # number of sample points along the horizontal / vertical axis
x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of column 0
x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of column 1
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2)  # grid of sample points
x_show = np.stack((x1.flat, x2.flat), axis=1)  # grid points to classify

# # Meaningless filler, only there to supply the other two dimensions.
# # Before enabling this, make sure x = x[:, :2] is commented out.
# x3 = np.ones(x1.size) * np.average(x[:, 2])
# x4 = np.ones(x1.size) * np.average(x[:, 3])
# x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1)  # test points

cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
y_show_hat = model.predict(x_show)  # predicted label for every grid point

y_show_hat = y_show_hat.reshape(x1.shape)  # give it the same shape as the grid
plt.figure(facecolor='w')
plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light)  # show the predicted regions
plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100, cmap=cm_dark, marker='o')  # test data
plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)  # all data
plt.xlabel("花萼长度", fontsize=15)  # sepal length
plt.ylabel("花萼宽度", fontsize=15)  # sepal width
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.grid(True)
plt.title(u'鸢尾花数据的决策树分类', fontsize=17)
plt.show()

# Prediction results on the test set (y_test_hat was predicted from x_test).
y_test = y_test.reshape(-1)

result = (y_test_hat == y_test)  # True where the prediction is right, False where it is wrong
acc = np.mean(result)
print('准确度: %.2f%%' % (100 * acc))

# Overfitting: test-set error rate as a function of the tree's maximum depth.
depth = np.arange(1, 45)
err_list = []
for d in depth:  # one fit per depth, 1 through 44
    clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)
    clf = clf.fit(x_train, y_train)
    y_test_hat = clf.predict(x_test)  # predictions for the test data
    result = (y_test_hat == y_test)  # True where the prediction is right, False where it is wrong
    err = 1 - np.mean(result)
    err_list.append(err)
    # The printed value is the error rate, so label it as such.
    print(d, ' 错误率: %.2f%%' % (100 * err))
plt.figure(facecolor='w')
plt.plot(depth, err_list, 'ro-', lw=2)
plt.xlabel(u'决策树深度', fontsize=15)
plt.ylabel(u'错误率', fontsize=15)
plt.title(u'决策树深度与过拟合', fontsize=17)
plt.grid(True)

plt.show()

from sklearn import tree  # package required for the export

# Write the fitted tree to a .dot file at a fixed absolute path; the `with`
# block closes the file so the output is flushed and the handle is released.
with open('D:\\py_project\\iris_tree.dot', 'w') as f:
    # named_steps['DTC'] is the fitted decision-tree step of the pipeline.
    tree.export_graphviz(model.named_steps['DTC'], out_file=f)
决策树2:
# The data set is loaded from a local file; its content is the same as the
# data set provided by `from sklearn import datasets`.
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import tree

with open(r'D:\py_project\8.iris.txt', "r", encoding='UTF-8') as fp:
    data = fp.read().splitlines()
# Split every record on commas. Blank lines (e.g. a trailing newline at the
# end of the file) are skipped so they cannot produce short rows.
lit = [line.split(',', 5) for line in data if line.strip()]
feature = np.array(lit)
# Column 4 of each record is the species name.
lable = [row[4] for row in feature]

# Columns 0-3 are the numeric measurements.
X = feature[:, 0:4]
X = np.array(X, dtype=float)
print(X)

def iris_type(lable):
    """Convert iris species names to integer class ids.

    Args:
        lable: Iterable of species names. Each must be 'Iris-setosa',
            'Iris-versicolor' or 'Iris-virginica'.

    Returns:
        A list with 0, 1 or 2 (in that species order) for every input name,
        in input order.

    Raises:
        KeyError: If a name is not one of the three known species.
    """
    it = {
        'Iris-setosa': 0,
        'Iris-versicolor': 1,
        'Iris-virginica': 2,
    }
    return [it[name] for name in lable]
# Encode the species names as integer class ids.
lable = iris_type(lable)
Y = np.array(lable)

# 70 % of the samples are used for training; the rest are held out for testing.
x_train, x_test, y_train, y_test = train_test_split(X, Y, train_size=0.7)
clf = tree.DecisionTreeClassifier().fit(x_train, y_train)
y_test_hat = clf.predict(x_test)
count = len(y_test)
# Number of misclassified test samples, counted in one vectorised comparison.
err = int(np.sum(y_test != y_test_hat))

print("正确率ACC:", float((count - err) / count))
神经网络:
import numpy as np
from collections import Counter
from sklearn import datasets
import torch.nn.functional as Fun
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch

dataset = datasets.load_iris()
dataut = dataset['data']       # feature values
priciple = dataset['target']   # class labels

# Convert features to float tensors and labels to long tensors.
# NOTE: `input` shadows the builtin of the same name; the name is kept
# because the training code further down refers to it.
input = torch.FloatTensor(dataut)
label = torch.LongTensor(priciple)

# Define the BP (back-propagation) neural network
class Net(torch.nn.Module):
    """Fully connected network with a single hidden layer.

    Args:
        n_feature: Number of input features per sample.
        n_hidden: Number of units in the hidden layer.
        n_output: Number of output values per sample.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)  # output layer

    def forward(self, x):
        """Return the output-layer values for `x`; no softmax is applied."""
        x = Fun.relu(self.hidden(x))  # ReLU activation on the hidden layer
        x = self.out(x)
        return x

net = Net(n_feature=4, n_hidden=20, n_output=3)
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)  # SGD: stochastic gradient descent
loss_func = torch.nn.CrossEntropyLoss()  # loss function for classification problems

# Training: 500 update steps, each over the whole data set.
for t in range(500):
    out = net(input)              # forward pass on the inputs
    loss = loss_func(out, label)  # compare the output with the labels
    optimizer.zero_grad()         # clear gradients left from the previous step
    loss.backward()               # back-propagation: compute gradients
    optimizer.step()              # apply the gradients

# Evaluation needs no gradients, so skip building the autograd graph.
# NOTE: the accuracy is measured on the same samples that were used for training.
with torch.no_grad():
    out = net(input)  # raw scores; Fun.softmax(out) would turn them into probabilities
prediction = torch.max(out, 1)[1]  # [1] is the index of the max along dim 1, [0] the value
pred_y = prediction.numpy()
target_y = label.numpy()
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
print("莺尾花预测准确率", accuracy)
鸢尾花数据集:
共150个样本,分为三种类别:setosa、versicolor、virginica
每行依次为:花萼长度、花萼宽度、花瓣长度、花瓣宽度、种类
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

相关文档
最新文档