"""Train and visualize a small CNN on a harder, shifted-and-noisy MNIST.

Every 28x28 MNIST digit is pasted into one of nine cells of an 84x84 canvas
(the cell is chosen from the batch index) and blended with uniform noise by
``data_enhance``. Set ``act`` in the ``__main__`` section to 1 to train and
evaluate, or to any other value to visualize feature maps of a checkpoint.
"""
import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Prefer the second GPU (as originally configured), but fall back to the
# first GPU or the CPU instead of failing on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Training settings
batch_size = 64
root = 'pytorch-master/mnist_data'

train_dataset = datasets.MNIST(root=root, train=True,
                               transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root=root, train=False,
                              transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)
# Keep the final partial batch so every test image is evaluated; test()
# divides by the number of samples it actually saw.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=False)

# train() writes one checkpoint per epoch here as "<epoch>.pth".
save_path = os.path.join(root, 'savepath')
os.makedirs(save_path, exist_ok=True)


class Net(nn.Module):
    """Three conv + max-pool stages followed by two fully connected layers.

    Sized for single-channel 84x84 inputs (the output of ``data_enhance``):
    84 -conv5-> 80 -pool-> 40 -conv5-> 36 -pool-> 18 -conv3-> 16 -pool-> 8,
    so the flattened feature vector has 40 * 8 * 8 = 2560 elements.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.conv3 = nn.Conv2d(20, 40, 3)
        self.mp = nn.MaxPool2d(2)
        # Separate pooling module so show() can hook the last stage alone.
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(2560, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, shape (batch, 10)."""
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = F.relu(self.mp(self.conv2(x)))
        x = F.relu(self.mp1(self.conv3(x)))
        x = x.view(in_size, -1)
        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # two layers compose to a single linear map. Left unchanged so that
        # existing checkpoints keep producing the same outputs.
        x = self.fc1(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def data_enhance(data, batch_idx):
    """Shift a batch of images onto a 3x3-cell canvas and blend in noise.

    Args:
        data: image batch of shape (N, C, H, W); MNIST gives H = W = 28.
        batch_idx: batch index; ``batch_idx % 9`` selects the canvas cell,
            filled column by column (0-2: left column, top to bottom;
            3-5: middle column; 6-8: right column).

    Returns:
        Tensor of shape (N, C, 3H, 3W), created with torch's default dtype
        and device: ``0.7 * uniform noise in [0, 1) + 0.3 * shifted image``.
    """
    n, c, h, w = data.shape
    new_data = torch.zeros((n, c, 3 * h, 3 * w))
    noise = torch.rand(new_data.size())
    cell = batch_idx % 9
    top = (cell % 3) * h
    left = (cell // 3) * w
    new_data[:, :, top:top + h, left:left + w] = data
    return noise * 0.7 + new_data * 0.3


def train(epoch):
    """Run one epoch of SGD on the global ``model`` and save a checkpoint.

    Every 200 batches the current loss is printed and appended (one line per
    entry) to ``<root>/log.txt``. After the epoch the weights are written to
    ``<save_path>/<epoch>.pth``.
    """
    model.train()
    log_path = os.path.join(root, 'log.txt')
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data_enhance(data, batch_idx).to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            message = 'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} '.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            print(message)
            with open(log_path, 'a') as f:
                f.write(message + '\n')
    torch.save(model.state_dict(), os.path.join(save_path, str(epoch) + '.pth'))


def test():
    """Evaluate the global ``model`` on the augmented test set.

    Prints, and appends to ``<root>/log.txt``, the mean NLL loss and the
    accuracy, both computed over the samples actually evaluated.
    """
    model.eval()
    test_loss = 0.
    correct = 0
    total = 0
    with torch.no_grad():
        for index, (data, target) in enumerate(test_loader):
            data = data_enhance(data, index).to(device)
            target = target.to(device)
            output = model(data)
            # Sum per batch; dividing by ``total`` below gives the per-sample
            # mean even when the last batch is smaller.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    if total == 0:
        print('Test set: no samples evaluated')
        return

    test_loss /= total
    message = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) '.format(
        test_loss, correct, total, 100. * correct / total)
    print(message)
    with open(os.path.join(root, 'log.txt'), 'a') as f:
        f.write(message + '\n')


# Outputs captured by get_features_hook; show() clears it before each batch.
feature = []


def get_features_hook(self, input, output):
    """Forward hook: record the hooked module's output in ``feature``."""
    feature.append(output)


def show(para_path, num_batches=5):
    """Save augmented test inputs and their ``mp1`` feature maps as images.

    Loads the state dict at ``para_path`` into a fresh ``Net``. For the first
    ``num_batches`` test batches, writes ``<index>_img.jpg`` (the network
    input) and ``<index>_feat.jpg`` (channel-wise max of the ``mp1`` output)
    to ``<root>/show``.
    """
    print('device:{}'.format(device))
    show_path = os.path.join(root, 'show')
    os.makedirs(show_path, exist_ok=True)

    net = Net()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(para_path, map_location='cpu'))
    net = net.to(device)
    net.eval()

    handle = net.mp1.register_forward_hook(get_features_hook)
    try:
        with torch.no_grad():
            batches = itertools.islice(test_loader, num_batches)
            for index, (data, _) in enumerate(batches):
                print(index)
                data = data_enhance(data, index).to(device)
                save_image(data, os.path.join(show_path, str(index) + '_img.jpg'))
                # Drop earlier captures so the list does not grow without bound.
                feature.clear()
                net(data)
                feat = torch.max(feature[-1], dim=1, keepdim=True)[0]
                save_image(feat, os.path.join(show_path, str(index) + '_feat.jpg'))
    finally:
        handle.remove()


if __name__ == '__main__':
    act = 2  # 1: train and evaluate; any other value: visualize a checkpoint
    if act == 1:
        print('start training...')
        for epoch in range(1, 100):
            train(epoch)
            test()
    else:
        print('start show..')
        # Load from the same directory that train() saves checkpoints to.
        show(os.path.join(save_path, '40.pth'))
输入：（为了增加难度，对 MNIST 数据集的图片进行了平移和加噪声操作：每张 28×28 的图片按批次序号被放到 84×84 画布的 9 个位置之一，再与均匀噪声按 0.3 : 0.7 的比例混合。）
可视化效果：（可以看出，网络确实学习到了数字特征，至少学到了位置信息；最终能达到 0.96 的准确率。）