# 完整的PyTorch代码识别手写数字 — complete PyTorch script for handwritten-digit (MNIST) recognition
import torch
import torchvision
from torchvision import datasets, transforms
# Preprocessing: ToTensor yields float tensors in [0, 1]; Normalize with
# mean 0.5 / std 0.5 (one value each: single channel) rescales to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))])

# Mini-batch sizes. MNIST's training split has 60,000 images, so batches of
# 20000 meant only 3 optimizer steps per epoch -- far too few for Adam at a
# small learning rate -- and periodic logging (every 500 batches) fired just
# once per epoch. A conventional small batch gives hundreds of steps per epoch.
BATCH_SIZE = 64
# Test loading does no backprop, so a larger batch is fine here.
TEST_BATCH_SIZE = 1000

trainset = datasets.MNIST(root='./data', train=True,
                          download=True, transform=transform)
trainsetloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True)

# train=False selects MNIST's held-out test split. Using train=True here (as
# before) made the "test" loader serve training images, so any evaluation on
# it measured memorization rather than generalization.
testset = datasets.MNIST(root='./data', train=False,
                         download=True, transform=transform)
# shuffle=True is kept so each run of the demo below shows a different sample.
testsetloader = torch.utils.data.DataLoader(
    testset, batch_size=TEST_BATCH_SIZE, shuffle=True)
# Layer sizes for a two-layer MLP over flattened 28x28 images.
first_in = 28 * 28   # input features: one per pixel
first_out = 128      # hidden layer width
second_out = 10      # output logits: one per digit class

# Linear -> ReLU -> Linear; outputs raw logits (no softmax), which is what
# CrossEntropyLoss expects.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=first_in, out_features=first_out),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=first_out, out_features=second_out),
)

# Multi-class classification objective on the logits above.
loss_fn = torch.nn.CrossEntropyLoss()

# Adam over all model parameters with a fixed step size.
learning_rate = 0.0001
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Train for 10 epochs over the training loader, periodically logging the loss,
# then persist the trained model.
model.train()  # explicit training mode (no dropout/batchnorm here, but harmless)
for t in range(10):
    for i, (data, label) in enumerate(trainsetloader):
        # Flatten each image in the batch to a single row so it matches the
        # first Linear layer's input width.
        data = data.view(data.shape[0], -1)
        model_output = model(data)
        loss = loss_fn(model_output, label)
        if i % 500 == 0:
            # .item() gives a plain float instead of a tensor repr with grad_fn.
            print(f"epoch {t} batch {i} loss {loss.item():.4f}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Saves the whole pickled nn.Module, not just its state_dict. NOTE(review):
# reloading this file needs the module classes importable, and on recent
# PyTorch versions torch.load(..., weights_only=False) -- confirm the loader.
torch.save(model, './my_handwrite_recognize_model.pt')
# Demo: classify one image from the first test batch and print the predicted
# digit next to its ground-truth label.
testdataiter = iter(testsetloader)
# DataLoader iterators no longer expose a `.next()` method (a Python 2
# leftover removed from PyTorch); use the builtin next() instead.
testimages, testlabels = next(testdataiter)
# Flatten the first image into a one-row batch, as the model expects 2-D input.
img_vector = testimages[0].squeeze().view(1, -1)
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    result_digit = model(img_vector)
# argmax over the class dimension gives the predicted digit; .item() prints a
# plain int rather than a tensor repr.
print("该手写数字图片识别结果为:", result_digit.argmax(1).item(), "标签为:", testlabels[0].item())