• 龙良曲pytorch学习笔记_ResNet18


    流程：main → dataloader → train → test

    与 LeNet5 的主函数相比，这里仅更换了模型名称，其他部分没有变化。

    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision import transforms
    from torch import nn,optim
    from resnet import ResNet18
    
    def main():
        """Train ResNet18 on CIFAR-10 and print the test accuracy after every epoch."""
        batch_size = 32

        # Train and test images must be preprocessed identically. The original
        # normalized only the test set, so at evaluation time the model saw
        # inputs on a different scale than during training.
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        cifar_train = datasets.CIFAR10('cifar', train=True, transform=transform, download=True)
        # DataLoader groups samples into batches so several images are processed at once.
        cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

        cifar_test = datasets.CIFAR10('cifar', train=False, transform=transform, download=True)
        # Evaluation order does not affect accuracy, so the test set is not shuffled.
        cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

        # Sanity-check the batch shapes once the data has loaded.
        # Iterators have no `.next()` method in Python 3; use the builtin next().
        x, label = next(iter(cifar_train))
        print('x:', x.shape, 'label:', label.shape)

        # Fall back to the CPU so the script still runs on machines without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = ResNet18().to(device)
        criteon = nn.CrossEntropyLoss().to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        print(model)

        for epoch in range(1000):

            model.train()
            for x, label in cifar_train:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # logits: [b, 10]
                logits = model(x)
                loss = criteon(logits, label)

                # Backpropagation.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Loss of the last training batch of this epoch.
            print(epoch, loss.item())

            model.eval()
            # Evaluation needs no gradients; disabling them saves memory and time.
            with torch.no_grad():
                total_correct = 0
                total_num = 0
                for x, label in cifar_test:
                    x, label = x.to(device), label.to(device)
                    # logits: [b, 10]
                    logits = model(x)
                    pred = logits.argmax(dim=1)
                    # Accumulate the number of correct predictions batch by batch.
                    total_correct += torch.eq(pred, label).float().sum().item()
                    # x.size(0) is the actual batch size (the last batch may be smaller).
                    total_num += x.size(0)

                acc = total_correct / total_num
                print(epoch, acc)


    if __name__ == '__main__':
        main()

    ResNet18（resnet.py）

     1 import torch
     2 from torch import nn
     3 from torch.nn import functional as F
     4 
     5 class ResBlk(nn.Module):
     6 
     7     def __init__(self,ch_in,ch_out,stride = 1):
     8         super(ResBlk,self).__init__()
     9         
    10         # 改变stride是为了使得图片的size变小,以避免占用过多内存
    11         self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size = 3,stride = stride,padding = 1)
    12         self.bn1 = nn.BatchNorm2d(ch_out)
    13         self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size = 3,stride = 1,padding = 1)
    14         self.bn2 = nn.BatchNorm2d(ch_out)
    15         
    16         self.extra = nn.Squential()
    17         if ch_out != ch_in:
    18             # [b,ch_in,h,w] -->  [b,ch_in,h,w]
    19             self.extra = nn.Squential(
    20                 # x要和f(x)的size也一样,所以也要设置stride
    21                 # 而channel通过一个卷积层来使得他们一致
    22                 nn.Conv2d(ch_in,ch_out,kernel_size = 1,stride = stride)
    23                 nn.BatchNorm2d(ch_out)
    24             )
    25         
    26     def forward(self,x):
    27         out = F.relu(self.bn1(self.conv1(x)))
    28         # 这里的relu取决于自己
    29         out = F.relu(self.bn2(self.conv2(out)))
    30         # short cut
    31         # extra module: [b,ch_in,h,w] -->  [b,ch_in,h,w]
    32         # element-wise add 需要ch_in和ch_out相等
    33         # 由于是残差网络,所以要把f(x)和短路的x相加
    34         out = self.extra(x) + out
    35         
    36         return out
    37         
    class ResNet18(nn.Module):
        """ResNet-style classifier for 3-channel images with 10 output classes.

        The network has four parts:
        - a 3x3 stem convolution with 64 channels;
        - four ResBlk stages, each with stride 2;
        - global average pooling;
        - a linear layer.

        For a [b, 3, 32, 32] input, the main path leaves a [b, 512, 2, 2]
        feature map before pooling. The output is [b, 10] logits.
        """

        def __init__(self):
            super().__init__()

            # Stem: [b, 3, h, w] --> [b, 64, h, w]; its relu is applied in forward().
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
            )

            # Four residual stages. Each uses stride 2, so each roughly halves h and w.
            self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w]  --> [b, 128, h1, w1]
            self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h1, w1] --> [b, 256, h2, w2]
            self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h2, w2] --> [b, 512, h3, w3]
            self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h3, w3] --> [b, 512, h4, w4]

            # Global average pooling leaves 512 * 1 * 1 features per sample.
            self.outlayer = nn.Linear(512 * 1 * 1, 10)

        def forward(self, x):
            """Map a batch of images [b, 3, h, w] to class logits [b, 10]."""
            feat = F.relu(self.conv1(x))
            for stage in (self.blk1, self.blk2, self.blk3, self.blk4):
                feat = stage(feat)

            # [b, 512, h4, w4] --> [b, 512, 1, 1] --> [b, 512]
            pooled = F.adaptive_avg_pool2d(feat, (1, 1))
            flat = torch.flatten(pooled, 1)
            return self.outlayer(flat)
    def main():
        """Smoke-test the output shapes of ResBlk and ResNet18 on random input."""
        # A single block with stride 4: [2, 64, 32, 32] --> [2, 128, 8, 8]
        block = ResBlk(64, 128, stride=4)
        block_out = block(torch.randn(2, 64, 32, 32))
        print('block:', block_out.shape)

        # The whole network on a CIFAR-sized batch: [2, 3, 32, 32] --> [2, 10]
        images = torch.randn(2, 3, 32, 32)
        net = ResNet18()
        print('resnet:', net(images).shape)
  • 相关阅读:
    基本的CRUD操作
    java.lang.IllegalStateException: Cannot forward after response has been committed的一个情况解决方法
    一个解决过程:Servlet [某路径xxx] in web application [/项目xxx] threw load() exception和java.lang.ClassNotFoundException XXX
    卸载时候出现: windows installer 程序有问题。此安装需要的dll不能运行 的一个解决方法
    jdk各版本特性
    抽象类与接口
    Integert 与 int例子详解
    Spring(mvc)思维导图
    关于存储数组有序无序
    遍历回顾(手稿)-先序中序求后序----和----中序后序求先序
  • 原文地址:https://www.cnblogs.com/fxw-learning/p/12318422.html
Copyright © 2020-2023  润新知