这是我的代码和注释,你可以通过直接复制代码到你的pycharm中跑起来。
你不需要另外去准备数据集,当本地没有数据集运行代码就会自动下载
这是一个很小的项目,你不需要准备GPU
mnist_train.py
1 import torch
2
3 # nn包用来完成神经网络的搭建
4 from torch import nn
5
6 # functional包含常用的函数
7 from torch.nn import functional as F
8
9 # optim优化数据包,用来更新权重
10 from torch import optim
11
12 # 视觉相关的工具包
13 import torchvision
14
15 # 导入画图工具包
16 from matplotlib import pyplot as plt
17
18 # 从utils 包里导入所需工具
19 from utils import plot_image, plot_curve, one_hot
20
21 # step1. load dataset 加载数据集
22
23 # 这里设定一次处理多少张图片
24 batch_size = 512
25
26 # 加载训练集
27 train_loader = torch.utils.data.DataLoader(
28 # 加载MNIST数据集(1.图片路径,2.指定下载的图片为text还是train,3.download若1本地没有则去网上下载,
29 # 4.transform格式转换,网上图片一般为numpy格式,转为totensor格式)
30 torchvision.datasets.MNIST('mnist_data', train=True, download=True,
31 transform=torchvision.transforms.Compose([
32 torchvision.transforms.ToTensor(),
33 torchvision.transforms.Normalize(
34 (0.1307,), (0.3081,))
35 # 这个参数是正则化,防止过拟合,防止参数过多或过大,避免模型过复杂。有L1正则化和L2正则化,这里是让参数维持在0的附近均匀的分配
36 ])), # 0.3081是均差
37 batch_size=batch_size, shuffle=True) # 加载数据并随机打散数据
38
39 # 加载测试集
40 test_loader = torch.utils.data.DataLoader(
41 torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
42 transform=torchvision.transforms.Compose([
43 torchvision.transforms.ToTensor(),
44 torchvision.transforms.Normalize(
45 (0.1307,), (0.3081,))
46 ])),
47 batch_size=batch_size, shuffle=False)
48
49
50 # # 查看图片
51 x, y = next(iter(train_loader))
52 print(x.shape, y.shape, x.min(), x.max())
53 plot_image(x, y, 'image sample')
54
55
56 # 设置网络层
57
58 class Net(nn.Module):
59
60 def __init__(self):
61 super(Net, self).__init__()
62
63 # xw + b
64 # 第一层,第一个参数为图像大小,第二个参数根据经验值设置输出层大小
65 self.fc1 = nn.Linear(28 * 28, 256)
66 # 第二层,第一个个参数为上一层的输出大小,第二个大小根据经验设置输出层大小
67 self.fc2 = nn.Linear(256, 64)
68 # 最后一层,第一个值为上一层输出大小,第二个参数为输出的种类数
69 self.fc3 = nn.Linear(64, 10)
70
71 # 计算函数
72 def forward(self, x):
73 # x:[b,1,28,28] #relu将线性函数调整变种为非线性函数
74 # h1=relu(xw1 +b1)
75 x = F.relu(self.fc1(x))
76 # h2=relu(h1w2+b20
77 x = F.relu(self.fc2(x))
78 # 第三层为输出层,一般输出概率值
79 x = self.fc3(x)
80
81 return x
82
83
84 # 对创建的神经网络进行初始化
85 net = Net()
86
87 # 设置对计算后的梯度进行梯度更新方法,这里采用SGD随机梯度下降,lr是学习率
88 optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
89
90 train_loss = []
91
92 for epoch in range(3):
93 # 对整个数据集迭代三次
94 for batch_idx, (x, y) in enumerate(train_loader):
95 # 对整个数据集迭代一次
96
97 # x :
98 # print(x.shape,y.shape)
99
100 # 输入
101 x = x.view(x.size(0), 28 * 28)
102
103 # 输出
104 out = net(x) # 我们的目的是将输出更加接近于y
105
106 # 将真实的y转为独热编码
107 y_onehot = one_hot(y)
108
109 # 通过mse_loss计算误差值,也就是均方差
110 loss = F.mse_loss(out, y_onehot)
111
112 # 清零梯度
113 optimizer.zero_grad()
114 # 计算梯度
115 loss.backward()
116 # 更新梯度
117 optimizer.step()
118
119 # 最后我们会得到较为合适的[w1,b1,w2,b2,w3,b3]
120
121 # 将loss数据收集,以便用matplotlib将其变化图示化
122 train_loss.append(loss.item())
123
124
125 # 查看loss下降的变化
126 if batch_idx % 10 == 0:
127
128 print(epoch, batch_idx, loss.item())
129
130
131 plot_curve(train_loss)
132
133 # 我们最终想要看到的并不是loss而是准确率
134 # 准确度的测试
135 # 在test测试集取数据然后进行测试
136 total_correct = 0
137 for x,y in test_loader:
138 x = x.view(x.size(0), 28*28)
139 out = net(x)
140 # out: [b, 10] => pred: [b]
141 pred = out.argmax(dim=1)
142 correct = pred.eq(y).sum().float().item()
143 total_correct += correct
144 145 total_num = len(test_loader.dataset)
146 acc = total_correct / total_num
147 print('test acc:', acc)
148 149 x, y = next(iter(test_loader))
150 out = net(x.view(x.size(0), 28*28))
151 pred = out.argmax(dim=1)
152 plot_image(x, pred, 'test')
utils.py文件中包含的是画图函数,和独热编码的函数,可以直接调用,比如上面的代码就调用了它。将它一并放入你的pycharm中
import torch
from matplotlib import pyplot as plt
# 画一条曲线
def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)), data, color='blue')
plt.legend(['value'], loc='upper right')
plt.xlabel('step')
plt.ylabel('value')
plt.show()
# 可视化查看识别结果
def plot_image(img, label, name):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label, depth=10):
out = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1, 1)
out.scatter_(dim=1, index=idx, value=1) # 生成独热编码
return out