• [Pytorch]Pytorch 保存模型与加载模型(转)


    转自:知乎

    目录:

    • 保存模型与加载模型
    • 冻结一部分参数,训练另一部分参数
    • 采用不同的学习率进行训练

    1.保存模型与加载

    简单的保存与加载方法:

    # 保存整个网络
    torch.save(net, PATH) 
    # 保存网络中的参数, 速度快,占空间少
    torch.save(net.state_dict(),PATH)
    #--------------------------------------------------
    #针对上面一般的保存方法,加载的方法分别是:
    model_dict=torch.load(PATH)
    model_dict=model.load_state_dict(torch.load(PATH))
    


    然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

    torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
                                'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
                               checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
    

    以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

    加载的方式:

    def load_checkpoint(model, checkpoint_PATH, optimizer):
        if checkpoint != None:
            model_CKPT = torch.load(checkpoint_PATH)
            model.load_state_dict(model_CKPT['state_dict'])
            print('loading checkpoint!')
            optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer
    

    其他的参数可以通过以字典的方式获得

    但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

    def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
        if checkpoint != 'No':
            print("loading checkpoint...")
            model_dict = model.state_dict()
            modelCheckpoint = torch.load(checkpoint)
            pretrained_dict = modelCheckpoint['state_dict']
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
            model.load_state_dict(model_dict)
            print("loaded finished!")
            # 如果不需要更新优化器那么设置为false
            if loadOptimizer == True:
                optimizer.load_state_dict(modelCheckpoint['optimizer'])
                print('loaded! optimizer')
            else:
                print('not loaded optimizer')
        else:
            print('No checkpoint is included')
        return model, optimizer
    

    2.冻结部分参数,训练另一部分参数

    1)添加下面一句话到模型中

    for p in self.parameters():
        p.requires_grad = False
    

    比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

    class RESNET_MF(nn.Module):
        def __init__(self, model, pretrained):
            super(RESNET_MF, self).__init__()
            self.resnet = model(pretrained)
            for p in self.parameters():
                p.requires_grad = False
            self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
            ...
    

    同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
                                   eps=1e-08, weight_decay=1e-5)
    

    2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

    查找的代码:

        model_dict = torch.load('net.pth.tar').state_dict()
        dict_name = list(model_dict)
        for i, p in enumerate(dict_name):
            print(i, p)
    

    保存一下这个文件,可以看到大致是这个样子的:

    0 gamma
    1 resnet.conv1.weight
    2 resnet.bn1.weight
    3 resnet.bn1.bias
    4 resnet.bn1.running_mean
    5 resnet.bn1.running_var
    6 resnet.layer1.0.conv1.weight
    7 resnet.layer1.0.bn1.weight
    8 resnet.layer1.0.bn1.bias
    9 resnet.layer1.0.bn1.running_mean
    ....
    

    同样在模型中添加这样的代码:

    for i,p in enumerate(net.parameters()):
        if i < 165:
            p.requires_grad = False
    

    在优化器中添加上面的那句话可以实现参数的屏蔽

  • 相关阅读:
    MySQL修改时区的方法小结
    MYSQL日期 字符串 时间戳互转
    2017php经典面试题
    PHP获得真实客户端的真实IP REMOTE_ADDR,HTTP_CLIENT_IP,HTTP_X_FORWARDED_FOR
    开放api接口签名验证
    MySql之ALTER命令用法详细解读(转)
    easyUI datagrid 清空
    webApi文档好帮手-apidoc使用教程
    驼峰命名和下划线命名互转php实现
    SQL Server 数据导入Mysql详细教程
  • 原文地址:https://www.cnblogs.com/kk17/p/10074188.html
Copyright © 2020-2023  润新知