• 【pytorch】保存与加载模型(1.7.0官方教程)


    0 项目场景

    pytorch训练完模型后,如何保存与加载?保存/加载有两种方式:一是保存/加载模型参数,二是保存/加载整个模型。

    1 模型参数

    保存/加载模型参数,官方推荐用这种方式,原因也给了:说这种方式对于日后恢复模型更具灵活性。

    1.1 保存

    torch.save(model.state_dict(), PATH)
    

    state_dict里保存有模型的参数,PATH是保存路径,推荐.pt.pth作为文件拓展名。

    1.2 加载

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    

    TheModelClass是你定义的模型结构,PATH是保存路径。如果你的模型结构中含有dropoutbatch normalization层,在测试之前一定要加上model.eval()(如果没有可以不加),不然会产生错误的输出结果。

    2 整个模型

    保存/加载整个模型,官方不推荐这种方式,原因也给了:说是在其它项目中使用或重构后,代码可能会中断。

    2.1 保存

    torch.save(model, PATH)
    

    2.2 加载

    # Model class must be defined somewhere
    model = torch.load(PATH)
    model.eval()
    

    定义模型结构的类必须在代码中出现。这种保存/加载模型的方式从语法上来说更加简洁和直观,但是将模型引入其它项目中使用可能出错,所以只在自己的项目中使用应该没有问题,想将模型引入其它项目中使用还是推荐第一种保存/加载方式。

    3 断点续训

    顾名思义就是从上次没训练完的地方继续训练,这对高效训练来说具有重要意义。

    3.1 保存

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        ...
    }, PATH)
    

    其中model.state_dict()optimizer.state_dict()是必须要保存的,因为这两项会随着模型的训练而更新。epochloss等是作为记录用的,能让你直观的了解到目前训练到第几轮了,损失是多少。PATH是保存路径,建议以.tar为文件拓展名。

    3.2 加载

    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.train()
    # - or -
    model.eval()
    

    首先初始化模型和优化器,然后加载之前保存的模型和优化器参数。接着你可以选择从上一次结束的地方继续训练或者直接测试。继续训练的话加上model.train(),测试模型的话加上model.eval(),如果模型结构中没有dropoutbatch normalization层,可以不加。

    4 多个模型

    有时候你可能需要将多个模型保存到一个文件中,比如GAN

    4.1 保存

    torch.save({
        'modelA_state_dict': modelA.state_dict(),
        'modelB_state_dict': modelB.state_dict(),
        'optimizerA_state_dict': optimizerA.state_dict(),
        'optimizerB_state_dict': optimizerB.state_dict(),
        ...
    }, PATH)
    

    4.2 加载

    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelBClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict'])
    modelB.load_state_dict(checkpoint['modelB_state_dict'])
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
    
    modelA.train()
    modelB.train()
    # - or -
    modelA.eval()
    modelB.eval()
    

    5. 迁移学习

    有时候我们在训练一个新的模型B时可以用到已有的模型A的参数,比如迁移学习,这样就不用从头开始训了,模型可以很快的收敛,大大地提高了训练效率。

    5.1 保存

    torch.save(modelA.state_dict(), PATH)
    

    5.2 加载

    modelB = TheModelBClass(*args, **kwargs)
    modelB.load_state_dict(torch.load(PATH), strict=False)
    

    strict=False:模型A和模型B是不完全一样的,模型B训练的时候可能只需要A中一部分值,其它不要的值就丢掉,设置strict=False就是为了匹配需要的那部分值,忽略不需要的那部分值。

    6 关于设备

    如何在不同的设备,比如CPU或GPU上,保存与加载模型?

    6.1 GPU保存 & CPU加载

    模型在GPU上训练,但想把它加载到CPU上时,用这种方式

    6.1.1 GPU保存

    torch.save(model.state_dict(), PATH)
    

    6.1.2 CPU加载

    device = torch.device('cpu')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=device))
    

    6.2 GPU保存 & GPU加载

    模型在GPU上训练,想把它加载到GPU上时,用这种方式

    6.2.1 GPU保存

    torch.save(model.state_dict(), PATH)
    

    6.2.2 GPU加载

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.to(device)
    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
    

    6.3 CPU保存 & CPU加载

    模型在CPU上训练,想把它加载到CPU上时,用这种方式

    6.3.1 CPU保存

    torch.save(model.state_dict(), PATH)
    

    6.3.2 CPU加载

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    

    6.4 CPU保存 & GPU加载

    模型在CPU上训练,想把它加载到GPU上时,用这种方式

    6.4.1 CPU保存

    torch.save(model.state_dict(), PATH)
    

    6.4.2 GPU加载

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
    model.to(device)
    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
    

    7 引用参考

    https://pytorch.org/tutorials/beginner/saving_loading_models.html
    
  • 相关阅读:
    Django REST framework 1
    爬虫基本原理
    QueryDict对象
    Django组件ModelForm
    MongoDB
    Algorithm
    BOM
    CSS
    Vue
    AliPay
  • 原文地址:https://www.cnblogs.com/ghgxj/p/14219114.html
Copyright © 2020-2023  润新知