• pytorchnum_flat_features(x)


    #coding=utf-8
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    class Net(nn.Module):
    #定义Net的初始化函数,这个函数定义了该神经网络的基本结构
    def __init__(self):
    super(Net, self).__init__() #复制并使用Net的父类的初始化方法,即先运行nn.Module的初始化函数
    self.conv1 = nn.Conv2d(1, 6, 5) # 定义conv1函数的是图像卷积函数:输入为图像(1个频道,即灰度图),输出为 6张特征图, 卷积核为5x5正方形
    self.conv2 = nn.Conv2d(6, 16, 5)# 定义conv2函数的是图像卷积函数:输入为6张特征图,输出为16张特征图, 卷积核为5x5正方形
    self.fc1 = nn.Linear(16*5*5, 120) # 定义fc1(fullconnect)全连接函数1为线性函数:y = Wx + b,并将16*5*5个节点连接到120个节点上。
    self.fc2 = nn.Linear(120, 84)#定义fc2(fullconnect)全连接函数2为线性函数:y = Wx + b,并将120个节点连接到84个节点上。
    self.fc3 = nn.Linear(84, 10)#定义fc3(fullconnect)全连接函数3为线性函数:y = Wx + b,并将84个节点连接到10个节点上。

    #定义该神经网络的向前传播函数,该函数必须定义,一旦定义成功,向后传播函数也会自动生成(autograd)
    def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) #输入x经过卷积conv1之后,经过激活函数ReLU(原来这个词是激活函数的意思),使用2x2的窗口进行最大池化Max pooling,然后更新到x。
    x = F.max_pool2d(F.relu(self.conv2(x)), 2) #输入x经过卷积conv2之后,经过激活函数ReLU,使用2x2的窗口进行最大池化Max pooling,然后更新到x。
    x = x.view(-1, self.num_flat_features(x)) #view函数将张量x变形成一维的向量形式,总特征数并不改变,为接下来的全连接作准备。
    x = F.relu(self.fc1(x)) #输入x经过全连接1,再经过ReLU激活函数,然后更新x
    x = F.relu(self.fc2(x)) #输入x经过全连接2,再经过ReLU激活函数,然后更新x
    x = self.fc3(x) #输入x经过全连接3,然后更新x
    return x

    #使用num_flat_features函数计算张量x的总特征量(把每个数字都看出是一个特征,即特征总量),比如x是4*2*2的张量,那么它的特征总量就是16。
    def num_flat_features(self, x):

    size = x.size()[1:] # 这里为什么要使用[1:],是因为pytorch只接受批输入,也就是说一次性输入好几张图片,那么输入数据张量的维度自然上升到了4维。【1:】让我们把注意力放在后3维上面

    num_features = 1
    for s in size:
    num_features *= s
    return num_features


    net = Net()
    net

    # 以下代码是为了看一下我们需要训练的参数的数量
    print(net)
    params = list(net.parameters())

    k=0
    for i in params:
    l =1
    print ("该层的结构:"+str(list(i.size())))
    for j in i.size():
    l *= j
    print ("参数和:"+str(l))
    k = k+l

    print ("总参数和:"+ str(k))


    def num_flat_features(x):
    size = x.size()[1:] # 这里为什么要使用[1:],是因为pytorch只接受批输入,也就是说一次性输入好几张图片,那么输入数据张量的维度自然上升到了4维。【1:】让我们把注意力放在后3维上面

    num_features = 1
    for s in size:
    num_features *= s
    return num_features
    a=torch.arange(1,17).resize(2,2,2,2)
    print(a.size())
    print(a.size()[1:])
    size=num_flat_features(a)
    print(size)

    Net(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    该层的结构:[6, 1, 5, 5]
    参数和:150
    该层的结构:[6]
    参数和:6
    该层的结构:[16, 6, 5, 5]
    参数和:2400
    该层的结构:[16]
    参数和:16
    该层的结构:[120, 400]
    参数和:48000
    该层的结构:[120]
    参数和:120
    该层的结构:[84, 120]
    参数和:10080
    该层的结构:[84]
    参数和:84
    该层的结构:[10, 84]
    参数和:840
    该层的结构:[10]
    参数和:10
    总参数和:61706
    torch.Size([2, 2, 2, 2])
    torch.Size([2, 2, 2])
    8

  • 相关阅读:
    HTML有2种路径的写法:绝对路径和相对路径
    ZB本地设置
    java main函数
    java static 关键字
    03013_动态页面技术-JSP
    Oracle数据库的文件以及Oracle体系架构
    记录一次mybatis缓存和事务传播行为导致ut挂的排查过程
    rtmp规范1.0全面指南
    程序员小哥教你秋招拿大厂offer
    ubuntu配置samba解决linux的svn使用舒适问题
  • 原文地址:https://www.cnblogs.com/tianyudizhua/p/15505363.html
Copyright © 2020-2023  润新知