• 【PyTorch】使用笔记


     检查使用的GPU的索引

    • torch.cuda.is_available()   cuda是否可用
    • torch.cuda.device_count()   返回gpu数量
    • torch.cuda.get_device_name(0)  返回gpu名字,设备索引默认从0开始;
    • torch.cuda.current_device()  返回当前设备索引;

    如果使用GPU时,出现找不到cudnn可用,可能是因为GPU卡太老,pytorch不支持。

    卷积Conv2d

    示例:nn.Conv2d(1, 64, 2, 1, 1),分别对应conv2d( in_channels, out_channels, kernel_size, stide, padding)

    torch.nn.Conv2d( in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

    • dilation ( int )——控制卷积核元素之间的距离,默认1
    • groups ( int )——控制输入通道和输出通道的阻塞连接数,默认1
      • group=1,输出是所有输入的卷积;
      • group=2,相当于有两个并排的卷积层,每个卷积层计算输入通道的一半,输出也是输出通道的一半,随后将两个输出连接起来得到最后结果;
      • group=in_channels,每一个输入通道分别和它对应的卷积核进行卷积
    • kernel_size, stride, padding, dilation
      • 是int数时,表示height和width值相同
      • 是tuple数组时,则分别表示height和weight
    • bias ( bool ) —— 是否添加可学习的偏置到输出中

    二维卷积层,输入到输出尺寸的计算:

    Wout​ = (Win + 2*padding - k) / stride + 1

    其中 Win 是输入图像尺寸,padding补0数目,k 是卷积核尺寸,stride是步长

    如果dilation 大于 1, 代表是空洞卷积,则需计算 空洞卷积 的等效卷积核 尺寸K,带入上式 k 中

    k_d k_d − ∗ ( d − 1)

    其中 K 代表等效卷积核尺寸,k_d 代表实际卷积核尺寸, d 代表 dilation--空洞卷积的参数

    参数变量:

    • weight(tensor)——卷积的权重,(out_channels, in_channels, kernel_size)
    • bias(tensor)——卷积的偏置,(out_channels)

    BatchNorm2d

    示例,nn.BatchNorm2d(64),

    括号内写传入的channel数

    MaxPool2d

    示例,nn.MaxPool2d(2, 2, 0),分别对应MaxPool2d(kernel_size, stride, padding)

    torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

    Transforms

    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0],自动每个数除以255。相当于归一化到 [0, 1]之间。

    transforms.Normalize() 使用如下公式进行归一化:

    channel=(channel-mean)/std (因为transforms.ToTensor()已经把数据处理成[0,1],那么(x-0.5)/0.5就是[-1.0, 1.0])

    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))

    transforms完整使用方法:

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(), # 将图片转成tensor,并把数值normalize到[0,1]
        transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)),
    ])
    View Code

    然后在DataSet函数

    __init__时,self.transform=transform

     __getitem__(index)时, x = self.transform(X[index])

    那么问题来了,一般是normalize到哪个区间中呢??

    打印网络各层参数及图形输出尺寸

      1 def show_summary():
      2     from collections import OrderedDict
      3     import pandas as pd
      4     import numpy as np
      5 
      6     import torch
      7     from torch.autograd import Variable
      8     import torch.nn.functional as F
      9     from torch import nn
     10 
     11 
     12     def get_names_dict(model):
     13         """
     14         Recursive walk to get names including path
     15         """
     16         names = {}
     17         def _get_names(module, parent_name=''):
     18             for key, module in module.named_children():
     19                 name = parent_name + '.' + key if parent_name else key
     20                 names[name]=module
     21                 if isinstance(module, torch.nn.Module):
     22                     _get_names(module, parent_name=name)
     23         _get_names(model)
     24         return names
     25 
     26 
     27     def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
     28         """
     29         Summarizes torch model by showing trainable parameters and weights.
     30         
     31         author: wassname
     32         url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
     33         license: MIT
     34         
     35         Modified from:
     36         - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
     37         - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/
     38         
     39         Usage:
     40             import torchvision.models as models
     41             model = models.alexnet()
     42             df = torch_summarize_df(input_size=(3, 224,224), model=model)
     43             print(df)
     44             
     45             #              name class_name        input_shape       output_shape  nb_params
     46             # 1     features=>0     Conv2d  (-1, 3, 224, 224)   (-1, 64, 55, 55)      23296#(3*11*11+1)*64
     47             # 2     features=>1       ReLU   (-1, 64, 55, 55)   (-1, 64, 55, 55)          0
     48             # ...
     49         """
     50 
     51         def register_hook(module):
     52             def hook(module, input, output):
     53                 name = ''
     54                 for key, item in names.items():
     55                     if item == module:
     56                         name = key
     57                 #<class 'torch.nn.modules.conv.Conv2d'>
     58                 class_name = str(module.__class__).split('.')[-1].split("'")[0]
     59                 module_idx = len(summary)
     60 
     61                 m_key = module_idx + 1
     62 
     63                 summary[m_key] = OrderedDict()
     64                 summary[m_key]['name'] = name
     65                 summary[m_key]['class_name'] = class_name
     66                 if input_shape:
     67                     summary[m_key][
     68                         'input_shape'] = (-1, ) + tuple(input[0].size())[1:]
     69                 summary[m_key]['output_shape'] = (-1, ) + tuple(output.size())[1:]
     70                 if weights:
     71                     summary[m_key]['weights'] = list(
     72                         [tuple(p.size()) for p in module.parameters()])
     73 
     74     #             summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()])
     75                 if nb_trainable:
     76                     params_trainable = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
     77                     summary[m_key]['nb_trainable'] = params_trainable
     78                 params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
     79                 summary[m_key]['nb_params'] = params
     80                 
     81 
     82             if  not isinstance(module, nn.Sequential) and 
     83                 not isinstance(module, nn.ModuleList) and 
     84                 not (module == model):
     85                 hooks.append(module.register_forward_hook(hook))
     86                 
     87         # Names are stored in parent and path+name is unique not the name
     88         names = get_names_dict(model)
     89 
     90         # check if there are multiple inputs to the network
     91         if isinstance(input_size[0], (list, tuple)):
     92             x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
     93         else:
     94             x = Variable(torch.rand(1, *input_size))
     95 
     96         if next(model.parameters()).is_cuda:
     97             x = x.cuda()
     98 
     99         # create properties
    100         summary = OrderedDict()
    101         hooks = []
    102 
    103         # register hook
    104         model.apply(register_hook)
    105 
    106         # make a forward pass
    107         model(x)
    108 
    109         # remove these hooks
    110         for h in hooks:
    111             h.remove()
    112 
    113         # make dataframe
    114         df_summary = pd.DataFrame.from_dict(summary, orient='index')
    115 
    116         return df_summary
    117 
    118 
    119     # Test on alexnet
    120     import torchvision.models as models
    121     model = Classifier_1()
    122     df = torch_summarize_df(input_size=(1, 48, 48), model=model)
    123     print(df)
    124 
    125 show_summary()
    View Code

    效果如下:

          name   class_name        input_shape       output_shape        nb_params
    1    cnn.0       Conv2d    (-1, 1, 48, 48)   (-1, 64, 49, 49)      tensor(320)
    2    cnn.1  BatchNorm2d   (-1, 64, 49, 49)   (-1, 64, 49, 49)      tensor(128)
    3    cnn.2         ReLU   (-1, 64, 49, 49)   (-1, 64, 49, 49)                0
    4    cnn.3    MaxPool2d   (-1, 64, 49, 49)   (-1, 64, 24, 24)                0
    5    cnn.4       Conv2d   (-1, 64, 24, 24)  (-1, 128, 25, 25)    tensor(32896)
    6    cnn.5  BatchNorm2d  (-1, 128, 25, 25)  (-1, 128, 25, 25)      tensor(256)
    7    cnn.6         ReLU  (-1, 128, 25, 25)  (-1, 128, 25, 25)                0
    8    cnn.7    MaxPool2d  (-1, 128, 25, 25)  (-1, 128, 12, 12)                0
    9    cnn.8       Conv2d  (-1, 128, 12, 12)  (-1, 256, 12, 12)   tensor(295168)
    10   cnn.9  BatchNorm2d  (-1, 256, 12, 12)  (-1, 256, 12, 12)      tensor(512)
    11  cnn.10         ReLU  (-1, 256, 12, 12)  (-1, 256, 12, 12)                0
    12  cnn.11    MaxPool2d  (-1, 256, 12, 12)    (-1, 256, 6, 6)                0
    13  cnn.12       Conv2d    (-1, 256, 6, 6)    (-1, 256, 6, 6)   tensor(590080)
    14  cnn.13  BatchNorm2d    (-1, 256, 6, 6)    (-1, 256, 6, 6)      tensor(512)
    15  cnn.14         ReLU    (-1, 256, 6, 6)    (-1, 256, 6, 6)                0
    16  cnn.15    MaxPool2d    (-1, 256, 6, 6)    (-1, 256, 3, 3)                0
    17    fc.0       Linear         (-1, 2304)          (-1, 512)  tensor(1180160)
    18    fc.1      Dropout          (-1, 512)          (-1, 512)                0
    19    fc.2         ReLU          (-1, 512)          (-1, 512)                0
    20    fc.3       Linear          (-1, 512)          (-1, 256)   tensor(131328)
    21    fc.4      Dropout          (-1, 256)          (-1, 256)                0
    22    fc.5       Linear          (-1, 256)            (-1, 7)     tensor(1799)
    View Code

    参考:https://zhuanlan.zhihu.com/p/33992733

    保存模型,读取模型

    只保存、读取模型参数

    # 保存
    torch.save(model.state_dict(), 'parameter.pkl')
    # 加载
    model = TheModelClass(...)
    model.load_state_dict(torch.load('parameter.pkl'))
    View Code

    保存、读取完整模型

    # 保存
    torch.save(model, 'model.pkl')
    # 加载
    model = torch.load('model.pkl')
    View Code



    个人学习记录,如有描述欠妥之处,欢迎大家指出交流~(*^__^*) 

    参考:

    https://www.jianshu.com/p/6ba95579082c

  • 相关阅读:
    SpringBootMybatis 关于Mybatis-generator-gui的使用|数据库的编码注意点|各项复制模板
    SpringBootMVC04——Mybatis
    SpringBootMVC02——SpringDataJpa与ThymeLeaf
    Big Data(六)用户权限实操&HDFS-API实操
    Big Data(五)关于Hadoop的HA的实践搭建
    Big Data(四)关于Hadoop的HA&CAP理论详解
    Big Data(三)伪分布式和完全分布式的搭建
    SpringBootMVC02——Spring Data JPA的使用&JSP的使用
    SpringBootMVC01——A simple SpringBootMVC Sample
    yum安装mysql
  • 原文地址:https://www.cnblogs.com/YeZzz/p/13041650.html
Copyright © 2020-2023  润新知