https://pytorch123.com/ThirdSection/SaveModel/ 这个链接非常的详细!
1、#保存整个网络 torch.save(net, PATH)
# 保存网络中的参数, 速度快,占空间少 torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))
2、然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')