• PyTorch state dictionaries: saving models and parameters with state_dict


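    In PyTorch, a state_dict is simply a Python dictionary object that maps each layer's parameter names to the corresponding tensors (for example, the weight and bias of every layer in the model).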

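    (Note: only layers with learnable parameters, such as convolutional and linear layers, and registered buffers, such as BatchNorm's running_mean, have entries in the model's state_dict.)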
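    The note above can be checked by listing the keys of a small model. The nn.Sequential model below is a hypothetical example and not part of the original post. Its Conv2d and BatchNorm2d layers contribute entries, and BatchNorm2d also contributes its buffers. ReLU and MaxPool2d contribute nothing.

```python
import torch.nn as nn

# Hypothetical illustrative model (not from the original post).
net = nn.Sequential(
    nn.Conv2d(3, 6, 5),   # index 0: learnable weight and bias
    nn.BatchNorm2d(6),    # index 1: learnable weight/bias plus running-stat buffers
    nn.ReLU(),            # index 2: holds no tensors, so no entries
    nn.MaxPool2d(2, 2),   # index 3: holds no tensors, so no entries
)

# Print every state_dict key together with the shape of its tensor.
for key, tensor in net.state_dict().items():
    print(key, tuple(tensor.shape))

# Learnable parameters are a subset of the state_dict keys.
print([name for name, _ in net.named_parameters()])
```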

    The optimizer object (torch.optim.Optimizer) also has a state_dict. It contains the optimizer's state and the hyperparameters in use (lr, momentum, weight_decay, etc.).
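    The following sketch shows the structure of an optimizer's state_dict. A single nn.Linear layer is used as a stand-in model, which is an assumption made for brevity. One training step is run so that SGD creates its momentum buffers. The "state" entry holds those buffers, keyed by parameter index, and "param_groups" holds the hyperparameters.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in model; any nn.Module works the same way.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One training step so that SGD creates its per-parameter momentum buffers.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

opt_sd = optimizer.state_dict()
print(list(opt_sd.keys()))              # ['state', 'param_groups']
print(opt_sd["param_groups"][0]["lr"])  # 0.001: hyperparameters live here
print(list(opt_sd["state"].keys()))     # [0, 1]: parameters are referenced by index

# Loading restores both the buffers and the saved hyperparameters.
new_optimizer = optim.SGD(model.parameters(), lr=0.1)
new_optimizer.load_state_dict(opt_sd)
print(new_optimizer.param_groups[0]["lr"])  # 0.001
```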

    Notes:

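    1) state_dict() is a method that every model (nn.Module) and optimizer provides as soon as it is defined, so it can be called directly. A state_dict is usually saved to a file with the ".pt" or ".pth" extension, i.e. PATH="./***.pt" in the command below: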

    torch.save(model.state_dict(), PATH)
    2) load_state_dict() is likewise a method that every model and optimizer provides, so it can also be called directly:

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
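    Note the importance of model.eval(). It is called at the end of 2) because only after it runs do dropout layers and batch normalization layers switch to evaluation mode. These two kinds of layers behave differently in training mode and in evaluation mode: in evaluation mode, dropout stops zeroing activations, and batch normalization uses its running statistics instead of the statistics of the current batch.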
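    The dropout half of this note can be illustrated with a standalone nn.Dropout layer, used here as a minimal sketch instead of a full model. In training mode, roughly half of the entries are zeroed and the surviving entries are scaled by 1/(1-p). After eval(), the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

drop.train()       # training mode (the default for a freshly built module)
y_train = drop(x)  # about half the entries zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()        # evaluation mode
y_eval = drop(x)   # dropout becomes the identity

print((y_train == 0).float().mean().item())  # roughly 0.5
print(torch.equal(y_eval, x))                # True
```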

    -------------------------------------------------------------------------------------------------------------------------------

    Saving and loading the state dictionary (state_dict) (here model is an instance of a network class)

    1.1) To save only the learned parameters, use:

        torch.save(model.state_dict(), PATH)

    1.2) To load model.state_dict, use:

        model = TheModelClass(*args, **kwargs)
        model.load_state_dict(torch.load(PATH))
        model.eval()

        Note: model.load_state_dict takes a deserialized dictionary object, not a file name. Call torch.load(PATH) first, then pass the result to it.
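    Steps 1.1 and 1.2 are combined below into one runnable round trip. It uses a small hypothetical nn.Sequential model in place of TheModelClass and a temporary directory in place of PATH. Note how torch.load turns the file back into a dictionary and load_state_dict consumes that dictionary.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small model standing in for TheModelClass.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state_dict.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)             # 1.1) save only the parameters

    model_2 = make_model()                           # 1.2) rebuild the architecture first
    state = torch.load(path)                         # file -> dictionary of tensors
    model_2.load_state_dict(state)                   # dictionary -> model (not a path)
    model_2.eval()

x = torch.randn(5, 4)
print(torch.equal(model(x), model_2(x)))  # True: identical weights, identical outputs
```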

    -----------

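    2.1) To save the entire model (the pickled module object), use:

        torch.save(model, PATH)

    2.2) To load the entire model, use:

        # The model class must be defined (importable) somewhere
        # PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects full modules
        model = torch.load(PATH, weights_only=False)

        model.eval()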
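    Here is a minimal sketch of 2.1 and 2.2 that uses an in-memory io.BytesIO buffer instead of a file path. The nn.Sequential model is hypothetical. Passing weights_only=False assumes PyTorch 1.13 or later, where the argument exists. Because this path unpickles arbitrary Python objects, it should only be used on files you trust.

```python
import io

import torch
import torch.nn as nn

# Hypothetical model built only from built-in layers.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

buffer = io.BytesIO()      # in-memory file, used here instead of PATH
torch.save(model, buffer)  # 2.1) pickle the whole module object
buffer.seek(0)             # rewind before reading

# 2.2) weights_only=False is needed to unpickle a full module on PyTorch 2.6+.
model_2 = torch.load(buffer, weights_only=False)
model_2.eval()

x = torch.randn(5, 4)
print(type(model_2).__name__)             # Sequential
print(torch.equal(model(x), model_2(x)))  # True
```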

    --------------------------------------------------------------------------------------------------------------------------------------

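    A state_dict is a Python dictionary: it is stored as a dictionary and loaded back as a dictionary. By default, load_state_dict() uses strict=True and raises an error unless the keys match the model's keys exactly. With strict=False, only the entries whose keys match are loaded, and the rest are ignored.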
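    The following sketch illustrates the key-matching rules with a hypothetical partial dictionary. The dictionary contains only the first layer's entries plus one key that the model does not have. The default strict load raises RuntimeError. With strict=False, the load succeeds and returns the missing and unexpected keys.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
full = model.state_dict()  # keys: 0.weight, 0.bias, 2.weight, 2.bias (ReLU has none)

# Hypothetical partial dict: the first layer only, plus a key the model lacks.
partial = {
    "0.weight": full["0.weight"],
    "0.bias": full["0.bias"],
    "extra.weight": full["2.weight"],
}

target = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

strict_failed = False
try:
    target.load_state_dict(partial)  # strict=True is the default
except RuntimeError:
    strict_failed = True             # missing and unexpected keys are errors
print("strict load raised RuntimeError:", strict_failed)

result = target.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # ['extra.weight']
```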

    ----------------------------------------------------------------------------------------------------------------------

    How to load only the trained parameters of a single layer (the state of one layer)

    If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

    conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
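    The renaming advice quoted above can be sketched as follows. Source and Target are hypothetical classes whose first convolution has the same shape but a different attribute name. The sketch keeps only the conv1.* entries, renames them to encoder.*, and loads them with strict=False so that the target's other layers are left alone.

```python
import torch
import torch.nn as nn


class Source(nn.Module):
    # Hypothetical model whose trained layer is called conv1.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.fc = nn.Linear(6, 10)


class Target(nn.Module):
    # Hypothetical model with a same-shaped layer under a different name.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 6, 5)  # same shape as Source.conv1
        self.head = nn.Linear(6, 2)        # different layer, not loaded


src_sd = Source().state_dict()

# Keep only the conv1 entries and rename "conv1.*" to "encoder.*".
renamed = {k.replace("conv1.", "encoder.", 1): v
           for k, v in src_sd.items() if k.startswith("conv1.")}

tgt = Target()
result = tgt.load_state_dict(renamed, strict=False)
print(sorted(renamed))      # ['encoder.bias', 'encoder.weight']
print(result.missing_keys)  # ['head.weight', 'head.bias']
```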
    --------------------------------------------------------------------------------------------

    After loading model parameters, how to control whether a given layer's parameters are trained (param.requires_grad)

    for param in model.pretrained.parameters():
        param.requires_grad = False
    Note: requires_grad is an attribute of tensors, so it has to be set on each parameter tensor individually.
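    The freezing loop above is placed below in a complete, runnable context. FineTuneNet is a hypothetical model with a pretrained submodule, mirroring the model.pretrained attribute used in the snippet. Only the tensors that are still trainable are handed to the optimizer. After a backward pass, the frozen tensors have no gradient.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class FineTuneNet(nn.Module):
    # Hypothetical model: a "pretrained" backbone plus a new "head".
    def __init__(self):
        super().__init__()
        self.pretrained = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.pretrained(x))


model = FineTuneNet()

# requires_grad is set on each parameter tensor, one by one.
for param in model.pretrained.parameters():
    param.requires_grad = False

# Hand only the still-trainable tensors to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.001, momentum=0.9)

loss = model(torch.randn(3, 8)).sum()
loss.backward()

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['head.weight', 'head.bias']
print(model.pretrained[0].weight.grad)  # None: frozen tensors receive no gradient
```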

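    Question: can requires_grad be set directly on a layer, e.g. model.conv1.requires_grad=False?

    Answer: tested, this does not work. model.conv1 has no requires_grad attribute, so the assignment merely attaches a new plain Python attribute to the module and leaves its weight and bias trainable. To freeze a whole layer, either loop over model.conv1.parameters() as shown above or call model.conv1.requires_grad_(False), which sets the flag on every parameter of the module.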
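    The following sketch uses a standalone nn.Conv2d layer to contrast the ineffective attribute assignment with Module.requires_grad_(). The attribute assignment succeeds silently but does not freeze anything. requires_grad_() actually updates every parameter of the module.

```python
import torch.nn as nn

conv1 = nn.Conv2d(3, 6, 5)

# Pitfall: nn.Module has no requires_grad attribute, so this line only
# attaches a new plain Python attribute; the parameters are untouched.
conv1.requires_grad = False
print(conv1.requires_grad)         # False (the new attribute)
print(conv1.weight.requires_grad)  # True (still trainable)

# Module.requires_grad_() sets the flag on every parameter of the module.
conv1.requires_grad_(False)
print(conv1.weight.requires_grad, conv1.bias.requires_grad)  # False False
```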

    ---------------------------------------------------------------------------------------------

    Full test code:

    # -*- coding: utf-8 -*-
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim


    # define the model
    class TheModelClass(nn.Module):
        def __init__(self):
            super(TheModelClass, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x


    # initialize the model
    model = TheModelClass()

    # initialize the optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # print the model's state_dict (key -> tensor shape)
    print("model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, ' ', model.state_dict()[param_tensor].size())

    # print the optimizer's state_dict ('state' and 'param_groups')
    print("optimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(var_name, ' ', optimizer.state_dict()[var_name])

    print("print particular param")
    print(' ', model.conv1.weight.size())
    print(' ', model.conv1.weight)

    print("------------------------------------")
    torch.save(model.state_dict(), './model_state_dict.pt')
    # model_2 = TheModelClass()
    # model_2.load_state_dict(torch.load('./model_state_dict.pt'))
    # model_2.eval()
    # print(' ', model_2.conv1.weight)
    # print((model_2.conv1.weight == model.conv1.weight).size())

    ## load the parameters of a single layer only
    conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
    print(conv1_weight_state == model.conv1.weight)  # element-wise comparison, all True

    model_2 = TheModelClass()
    model_2.load_state_dict(torch.load('./model_state_dict.pt'))
    # This only adds a plain attribute to the module; conv1's parameters
    # are NOT frozen (see the question above).
    model_2.conv1.requires_grad = False
    print(model_2.conv1.requires_grad)       # False (the new attribute)
    print(model_2.conv1.bias.requires_grad)  # True (bias is still trainable)
    ---------------------
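    Author: wzg2016
    Source: CSDN
    Original: https://blog.csdn.net/strive_for_future/article/details/83240081
    Copyright notice: This is an original article by the blogger; please include a link to the post when reposting!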

  • Original URL: https://www.cnblogs.com/jfdwd/p/11194378.html