• pytorch入门1——简单的网络搭建


    代码如下:

    %matplotlib inline
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchsummary import summary
    from torchvision import models
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            #此处的16*5*5为conv2经过pooling之后的尺寸,即为fc1的输入尺寸,在这里写死了,因此后面的输入图片大小不能任意调整
            self.fc1 = nn.Linear(16*5*5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
        def num_flat_features(self, x):
            size = x.size()[1:]
            num_features = 1
            for s in size:
                num_features *= s
            return num_features
    net = Net()
    print(net)
    
    params = list(net.parameters())
    print (len(params))
    print(params[0].size())
    print(params[1].size())
    print(params[2].size())
    print(params[3].size())
    print(params[4].size())
    print(params[5].size())
    print(params[6].size())
    print(params[7].size())
    print(params[8].size())
    print(params[9].size())
    
    input = torch.randn(1, 1, 32, 32)
    out = net(input)
    print(out)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg = net.to(device)
    summary(vgg, (1, 32, 32))

    上述代码完成了以下功能:

    1、建立一个简单的网络,并给各个网络层的参数size进行赋值;

    2、查看各个网络层参数量;

    3、给网路一个随机的输入,查看网络输出;

    4、查看网络每一层的额输出blob的大小;

    这里需要注意的是,在进行第一个全连接层的定义时,self.fc1 = nn.Linear(16*5*5, 120)

    第一个参数是根据网络结构计算出来的到达该层的feature map的尺寸,因此后面在给定网络输入的时候,不能任意调整网络的输入尺寸,该尺寸经过conv1+pooling+conv2+pooling之后的尺寸必须要为5*5才可以;

  • 相关阅读:
    iOS 获取内外网ip
    iOS 查看层级关系以及调用堆栈
    CoreML Use of undeclared type & Use of unresolved identifier
    AFN的实时网络监控 但是block连续调用了两次
    iOS 11 偏好设置(NSUserDefaults)无效了?
    iOS 11 UIScrollView的新特性(automaticallyAdjustsScrollViewInsets 不起作用了)
    Xcode9~iOS11初体验 无线调试
    Hook~iOS用钩子实现代码注入(埋点方案)
    tomcat启动时端口占用的问题怎么解决
    Memcached在Linux环境下的使用详解http://blog.51cto.com/soysauce93/1737161
  • 原文地址:https://www.cnblogs.com/rainsoul/p/11272554.html
Copyright © 2020-2023  润新知