• PyTorch模型加载与保存的最佳实践


    一般来说PyTorch有两种保存和读取模型参数的方法。但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题。

    第一种方案是保存整个模型:

    1
    torch.save(model_object, 'model.pth')

    第二种方法是保存模型网络参数:

    1
    torch.save(model_object.state_dict(), 'params.pth')

    加载的时候分别这样加载:

    1
    model = torch.load('model.pth')

    以及:

    1
    model_object.load_state_dict(torch.load('params.pth'))

    改进的方案

    注意到这个方案是因为模型在加载之后,loss会飙升之后再慢慢降回来。查阅有关分析之后,判定是优化器optimizer的问题。

    如果模型的保存是大专栏  PyTorch模型加载与保存的最佳实践g>为了恢复训练状态,那么可以考虑同时保存优化器optimizer的参数:

    1
    2
    3
    4
    5
    6
    7
    state = {
    'epoch': epoch,
    'net': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
    }
    torch.save(state, filepath)

    然后这样加载:

    1
    2
    3
    4
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1

    如果模型的保存是为了方便以后进行validation和test,可以在加载完之后制定model.eval()固定dropout和BN层。

  • 相关阅读:
    第一个java程序和注释
    hadoop map端join
    hadoop wordcount入门
    hadoop reduce端联结
    hadoop streaming的使用
    HDU5752 Sqrt Bo
    L2-008 manacher 的应用
    L3-001 凑零钱
    L2-001 紧急救援
    如何在ubuntu下安装go开发环境
  • 原文地址:https://www.cnblogs.com/lijianming180/p/12370491.html
Copyright © 2020-2023  润新知