• PyTorch基础内容


    修改模型

    import torchvision.models as models
    net = models.resnet50()
    # 查看模型定义
    print(net)
    
    # output
    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        ......
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    )
    

    可以看到,Resnet50最后一层(fc)默认输出1000个节点。
    若想将该模型应用于10分类任务中,则需要将最后的输出节点数修改为10。

    import torch.nn as nn
    from collections import OrderedDict
    # 一层全连接层可能太少,可以再加一层。
    classifier = nn.Sequential(OrderedDict([('fc1',nn.Linear(2048,128)),
                              ('relu',nn.ReLU()),
                              ('dropout',nn.Dropout(0.5)),
                              ('fc2',nn.Linear(128,10)),
                              ('output',nn.Softmax(dim=1))]))
    # 将net的fc层替换为自定义的classifier
    net.fc = classifier
    

    再输出net,可以看到最后一层的fc已修改为定义的内容

      (fc): Sequential(
        (fc1): Linear(in_features=2048, out_features=128, bias=True)
        (relu): ReLU()
        (dropout): Dropout(p=0.5, inplace=False)
        (fc2): Linear(in_features=128, out_features=10, bias=True)
        (output): Softmax(dim=1)
      )
    

    PyTorch模型保存与读取

    模型保存

    • 模型存储数据格式:pt, pth, pkl
    import os
    import torch
    
    # 希望使用的GPU编号
    os.environ['CUDA_CISIBLE_DEVICES'] = '0'
    net.cuda()
    # 保存模型,数据格式可以为 pt, pth, pkl
    torch.save(net, './model.pt')
    # 保存权重
    torch.save(net.state_dict(), './weight.pt')
    

    模型加载

    # 读取模型
    loaded_model = torch.load('./model.pt')
    # 将权重加载到模型上,也可先读取到一个变量中,再为loaded_model赋值,分两步进行
    loaded_model.state_dict = torch.load('./weight.pt')
    loaded_model.cuda()
    
    loaded_dict = torch.load('./weight.pt')
    print(loaded_dict.keys())
    # odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 
    ......
    'fc.fc1.weight', 'fc.fc1.bias', 'fc.fc2.weight', 'fc.fc2.bias'])
    
  • 相关阅读:
    css滚动条设置
    动态背景插件three.min.ts
    富文本编辑器ckeditor使用(angular中)
    angular接口传参
    angular组件图标无法显示的问题
    angular项目搭建
    使用Flume
    centos7 安装Flume
    centos7 安装kubernetes
    Nginx的rewrite对地址进行重写
  • 原文地址:https://www.cnblogs.com/ArdenWang/p/16108400.html
Copyright © 2020-2023  润新知