• 可视化分类网络的feature map


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torchvision import datasets, transforms
    import os
    
    # Pick the compute device. The second GPU ('cuda:1') is preferred when the
    # host actually has more than one; on single-GPU hosts that ordinal does
    # not exist, so fall back to 'cuda:0', and to the CPU when CUDA is absent.
    if torch.cuda.is_available():
        device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
    else:
        device = 'cpu'

    # Training settings
    batch_size = 64
    root = 'pytorch-master/mnist_data'

    # MNIST train / test splits, downloaded into `root` on first use and
    # converted to float tensors by ToTensor().
    train_dataset = datasets.MNIST(root=root,
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)

    test_dataset = datasets.MNIST(root=root,
                                  train=False,
                                  transform=transforms.ToTensor(),
                                  download=True)

    # Training batches are shuffled; the last incomplete batch is dropped.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)

    # Evaluation must see every test sample, so the trailing partial batch is
    # kept (dropping it would skew metrics averaged over the whole dataset).
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=False)

    # Directory that receives one checkpoint per epoch.
    save_path = os.path.join(root, 'savepath')
    os.makedirs(save_path, exist_ok=True)
    
    
    class Net(nn.Module):
        """Small convolutional classifier producing 10 log-probabilities.

        Three conv -> max-pool -> ReLU stages (1 -> 10 -> 20 -> 40 channels)
        are followed by two linear layers. The first linear layer expects 2560
        flattened features, which is what an 84x84 single-channel input yields
        (40 channels * 8 * 8).
        """

        def __init__(self):
            super().__init__()
            # Convolutional feature extractor.
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
            self.conv3 = nn.Conv2d(in_channels=20, out_channels=40, kernel_size=3)
            # `mp` is shared by the first two stages; `mp1` serves the last one.
            self.mp = nn.MaxPool2d(kernel_size=2)
            self.mp1 = nn.MaxPool2d(kernel_size=2)
            # Classifier head.
            self.fc1 = nn.Linear(in_features=2560, out_features=512)
            self.fc2 = nn.Linear(in_features=512, out_features=10)

        def forward(self, x):
            """Return per-class log-probabilities for the batch ``x``."""
            batch = x.size(0)
            stages = (
                (self.conv1, self.mp),
                (self.conv2, self.mp),
                (self.conv3, self.mp1),
            )
            for conv, pool in stages:
                x = F.relu(pool(conv(x)))
            flat = x.view(batch, -1)
            logits = self.fc2(self.fc1(flat))
            return F.log_softmax(logits, dim=1)
    
    
    # Build the classifier, then move its parameters to the selected device.
    model = Net()
    model = model.to(device)

    # SGD with momentum over all model parameters.
    optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)
    
    
    def data_enhance(data, batch_idx):
        """Place each 28x28 image into one cell of a noisy 84x84 canvas.

        The canvas is a 3x3 grid of 28-pixel cells. ``batch_idx % 9`` picks the
        cell: the position along the height axis cycles fastest (0, 1, 2) and
        the position along the width axis advances every three steps.

        Args:
            data: image batch of shape (N, C, 28, 28).
            batch_idx: integer used to choose the grid cell.

        Returns:
            A (N, C, 84, 84) tensor equal to ``0.7 * uniform_noise + 0.3 * canvas``.
        """
        cell = 28
        canvas = torch.zeros((data.size(0), data.size(1), cell * 3, cell * 3))
        noise = torch.rand(canvas.size())

        # divmod splits the cell number into (width-cell, height-cell).
        col, row = divmod(batch_idx % 9, 3)
        top = row * cell
        left = col * cell
        canvas[:, :, top:top + cell, left:left + cell] = data

        return noise*0.7 + canvas*0.3
    
    
    def train(epoch):
        """Train the global ``model`` for one pass over ``train_loader``.

        Each batch is first embedded in a noisy canvas by ``data_enhance``.
        Every 200 batches the current loss is printed and appended to
        ``log.txt`` under ``root``. After the pass the model's state dict is
        written to ``<save_path>/<epoch>.pth``.

        Args:
            epoch: epoch number, used in the log line and the checkpoint name.
        """
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data_enhance(data, batch_idx).to(device)
            output = model(data)
            loss = F.nll_loss(output, target.to(device))

            if batch_idx % 200 == 0:
                # The tab and trailing newline are written as escapes; a raw
                # line break inside a single-quoted literal is a syntax error.
                contest = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\n'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item())
                print(contest)
                with open(os.path.join(root, 'log.txt'), 'a') as f:
                    f.write(contest)

            # Zero gradients right before backprop so that nothing accumulated
            # earlier can leak into this step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.save(model.state_dict(), os.path.join(save_path, str(epoch) + '.pth'))
    
    
    def test():
        """Evaluate the global ``model`` on ``test_loader`` and log the result.

        Each batch goes through ``data_enhance`` exactly as in training. The
        summed NLL loss and the number of correct predictions are accumulated,
        then averaged over the samples that were actually evaluated. The
        summary line is printed and appended to ``log.txt`` under ``root``.
        """
        test_loss = 0.0
        correct = 0
        seen = 0
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for index, (data, target) in enumerate(test_loader):
                data = data_enhance(data, index).to(device)
                target = target.to(device)
                output = model(data)
                # reduction='sum' replaces the deprecated size_average=False.
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
                seen += target.size(0)

        # Average over the samples seen rather than len(dataset): a loader
        # configured with drop_last=True skips the trailing partial batch.
        # max(..., 1) avoids division by zero on an empty loader.
        denom = max(seen, 1)
        test_loss /= denom
        contest = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n\n'.format(
            test_loss, correct, seen,
            100. * correct / denom)
        print(contest)
        with open(os.path.join(root, 'log.txt'), 'a') as f:
            f.write(contest)
    
    from torchvision.utils import save_image
    
    # Outputs captured by `get_features_hook`, in the order the hook fired.
    feature = []


    def get_features_hook(self, input, output):
        """Forward hook that records a module's output in ``feature``.

        Tensor outputs are stored detached so that the module-level list does
        not keep autograd graphs alive; any non-tensor output is stored as-is.
        Returns ``None``, so the hooked module's output is left unchanged.

        Args:
            self: the module the hook is attached to (unused).
            input: the positional inputs given to the module (unused).
            output: the module's forward output.
        """
        captured = output.detach() if torch.is_tensor(output) else output
        feature.append(captured)
    
    
    def show(para_path):
        """Dump input images and ``mp1`` feature maps for the first five test batches.

        A fresh ``Net`` is loaded from ``para_path`` and moved to ``device``.
        For each batch, the augmented input is saved as ``<index>_img.jpg`` and
        the channel-wise maximum of the ``mp1`` output as ``<index>_feat.jpg``,
        both under ``<root>/show``.

        Args:
            para_path: path to a state-dict checkpoint for ``Net``.
        """
        print('device:{}'.format(device))
        show_path = os.path.join(root, 'show')
        os.makedirs(show_path, exist_ok=True)

        net = Net()
        net.load_state_dict(torch.load(para_path, map_location='cpu'))
        net = net.to(device)

        # Register the hook once for the whole run; try/finally guarantees it
        # is removed even if a batch fails.
        handle = net.mp1.register_forward_hook(get_features_hook)
        try:
            # Visualisation needs no gradients, so skip building autograd graphs.
            with torch.no_grad():
                for index, (data, _) in enumerate(test_loader):
                    print(index)
                    data = data_enhance(data, index).to(device)
                    save_image(data, os.path.join(show_path, str(index) + '_img.jpg'))
                    net(data)
                    # The hook appended this batch's mp1 output; collapse its
                    # channels to a single map by taking the maximum.
                    feat = torch.max(feature[-1], dim=1, keepdim=True)[0]
                    save_image(feat, os.path.join(show_path, str(index) + '_feat.jpg'))
                    if index > 3:
                        break
        finally:
            handle.remove()
    
    
    if __name__ == '__main__':
        # Mode switch: 1 trains and evaluates, any other value runs the
        # feature-map visualisation on a saved checkpoint.
        act = 2
        if act == 1:
            print('start training...')
            for epoch in range(1, 100):
                train(epoch)
                test()
        else:
            print('start show..')
            # Load the epoch-40 checkpoint from the directory train() writes to,
            # rather than a hard-coded absolute path that only matches when the
            # script is started from the filesystem root.
            show(os.path.join(save_path, '40.pth'))
    

     输入:(为了增加难度,先把每张 MNIST 图片(28×28)放进 84×84 画布的 3×3 网格中的某一格(相当于平移),再与随机噪声按 0.3 : 0.7 的比例叠加)

     可视化效果:(可以看出,网络确实学到了数字的特征(至少学到了位置信息),最终准确率可以达到 0.96)

     

  • 相关阅读:
    微信nickname乱码(emoji)及mysql编码格式设置(utf8mb4)解决的过程
    eclipse Specified VM install not found: type Standard VM, name
    eclipse中安装Open Explorer
    关于Java变量的可见性问题
    Win8&Win2012R2如何支持DOTA2输入法
    Adobe Flash player 10 提示:Error#2044:未处理的IOErrorEvent. text=Error#2036:加载未完成 的解决方法
    IntelliJ IDEA 12.1.4 解决中文乱码
    Win8.1RTM英文版安装中文语言包的两种方法
    在FlashDevelop里使用1.8版本的的TortoiseSVN
    [修复Win8.1 BUG] 解决Win8.1英文字体发虚不渲染问题
  • 原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/13325322.html
Copyright © 2020-2023  润新知