import torchvision import torch import torch.utils.data.dataloader as Data from torch.autograd import Variable import numpy as np import torch.nn as nn import torch.nn.functional as F from PIL import Image import matplotlib.pyplot as plt #残差块 if_use_gpu=0 class ResidualBlock(nn.Module): def __init__(self, inchannel, outchannel, stride=1): super(ResidualBlock, self).__init__() self.left = nn.Sequential( nn.Conv2d(inchannel,outchannel,kernel_size=3,padding=1,stride=stride,bias=False), nn.BatchNorm2d(outchannel), nn.ReLU(), nn.Conv2d(outchannel, outchannel, kernel_size=3, padding=1, stride=stride, bias=False), nn.BatchNorm2d(outchannel) ) self.right = nn.Sequential() #输入输出信道数不一样,把残差块的信道卷积到和输出一样 if(inchannel != outchannel): self.right = nn.Sequential( nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=1, stride=stride, bias=False), nn.BatchNorm2d(outchannel), ) def forward(self, x): out = self.left(x) out += self.right(x) out =F.relu(out) return out class ResNet(nn.Module): def __init__(self, ResidualBlock, num_classes=10): super(ResNet, self).__init__() self.inchannel = 64 self.conv1 = nn.Sequential( nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1,bias=False), nn.BatchNorm2d(64), nn.ReLU(), ) self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1) self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=1) self.conv2 = nn.Conv2d(128,128,3,stride=2) self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=1) self.conv3 = nn.Conv2d(256, 256, 3, stride=2) #self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=1) self.conv4 = nn.Conv2d(256,256,6) self.fc = nn.Linear(256, num_classes) def make_layer(self, block, channels, num_blocks, stride): layer = [] for i in range(num_blocks): layer.append(block(self.inchannel,channels,stride)) self.inchannel = channels #对layer拆包 return nn.Sequential(*layer) def forward(self, x): out = self.conv1(x) out = self.layer1(out) out = self.layer2(out) out = self.conv2(out) out = self.layer3(out) out = self.conv3(out) #out = self.layer4(out) out = self.conv4(out) #out = F.avg_pool2d(out,4) out = out.view(out.size(0), -1) out = self.fc(out) return out def ResNet18(): return ResNet(ResidualBlock) train_data = torchvision.datasets.MNIST( './mnist', train=True,transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), ]), download=True ) train_data.data = train_data.data[:10000] train_data.targets = train_data.targets[:10000] test_data = torchvision.datasets.MNIST( './mnist', train=False, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), ]) ) print("train_data:", train_data.train_data.size()) print("train_labels:", train_data.train_labels.size()) print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True) test_loader = Data.DataLoader(dataset=test_data, batch_size=32) model = ResNet18() if if_use_gpu: model = model.cuda() print(model) optimizer = torch.optim.Adam(model.parameters()) loss_func = torch.nn.CrossEntropyLoss() for epoch in range(1): print('epoch {}'.format(epoch + 1)) for i, data in enumerate(train_loader, 0): # get the inputs inputs, labels = data batch_x, batch_y = Variable(inputs), Variable(labels) if if_use_gpu: batch_x = batch_x.cuda() batch_y = batch_y.cuda() out = model(batch_x) batch_y = batch_y.long() loss = loss_func(out, batch_y) optimizer.zero_grad() loss.backward() optimizer.step() # 返回每行元素最大值 pred = torch.max(out, 1)[1] train_correct = (pred == batch_y).sum() train_correct = train_correct.item() train_loss = loss.item() print('batch:{},Train Loss: {:.6f}, Acc: {:.6f}'.format(i+1,train_loss , train_correct /32)) # evaluation-------------------------------- model.eval() eval_loss = 0. eval_acc = 0. for batch_x, batch_y in test_loader: batch_x, batch_y = Variable(batch_x, requires_grad=False), Variable(batch_y,requires_grad=False) if if_use_gpu: batch_x = batch_x.cuda() batch_y = batch_y.cuda() out = model(batch_x) loss = loss_func(out, batch_y) eval_loss += loss.item() pred = torch.max(out, 1)[1] num_correct = (pred == batch_y).sum() eval_acc += num_correct.item() print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len( test_data)), eval_acc / (len(test_data))))