• 【PyTorch】Module Vivid


    Model

    children

    parameters

    modules

    state_dict

    Container

     1 import torch
     2 import torch.nn as nn
     3 from torchvision.models.resnet import (_resnet, Bottleneck)
     4 
     5 
     6 class model_with_container(nn.Module):
     7     def __init__(self, type):
     8         super(model_with_container, self).__init__()
     9         self.type = type
    10         if self.type == 'ModuleList':
    11             self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)])
    12         elif self.type == 'Sequential':
    13             self.seq = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
    14 
    15     def forward(self, x):
    16         if self.type == 'ModuleList':
    17             for idx in [2, 3, 4, 0, 1]:
    18                 x = self.linears[idx](x)
    19         elif self.type == 'Sequential':
    20             x = self.seq(x)
    21         return self.type, x
    22 
    23 
    24 def creat_model():
    25     model = _resnet('resnet50', Bottleneck, [2, 0, 0, 0], False, False)
    26     model.layer1[0].__delattr__('conv2')
    27     model.layer1[0].__delattr__('bn2')
    28     model.layer1[0].__delattr__('conv3')
    29     model.layer1[0].__delattr__('bn3')
    30     model.layer1[1].__delattr__('conv2')
    31     model.layer1[1].__delattr__('bn2')
    32     model.layer1[1].__delattr__('conv3')
    33     model.layer1[1].__delattr__('bn3')
    34     model.__delattr__('layer2')
    35     model.__delattr__('layer3')
    36     model.__delattr__('layer4')
    37     return model
    38 
    39 
    40 def test_model():
    41     model = creat_model()
    42     for names_and_children, children in zip(model.named_children(), model.children()):
    43         i, j = names_and_children
    44         k = children
    45         print(i, id(j) == id(k))
    46 
    47     for names_and_parameters, parameters in zip(model.named_parameters(), model.parameters()):
    48         i, j = names_and_parameters
    49         k = parameters
    50         print(i, id(j) == id(k))
    51 
    52     for names_and_modules, modules in zip(model.named_modules(), model.modules()):
    53         i, j = names_and_modules
    54         k = modules
    55         print(i, id(j) == id(k))
    56 
    57     for name, parameter in model.state_dict().items():
    58         print(name)
    59 
    60 
    61 def test_container():
    62     x = torch.randn([2, 10])
    63     model = model_with_container('ModuleList')
    64     # model = model_with_container('Sequential')
    65     print(model, model(x), sep='\n')
    66 
    67 
    68 if __name__ == '__main__':
    69     test_model()
    70     test_container()
  • 相关阅读:
    PTA(Advanced Level)1063.Set Similarity
    PTA(Advanced Level)1047.Student List for Course
    PTA(Advanced Level)1023.Palindromic Number
    PTA(Advanced Level)1023.Have Fun with Numbers
    PTA(Basic Level)1017.A除以B
    PTA(Advanced Level)1059.Prime Factors
    PTA(Advanced Level)1096.Consecutive Factors
    expected primary-expression before xx token错误处理
    PTA(Advanced Level)1078.Hashing
    PTA(Advanced Level)1015.Reversible Primes
  • 原文地址:https://www.cnblogs.com/VividBinGo/p/15990830.html
Copyright © 2020-2023  润新知