只保存模型参数
# 保存 torch.save(model.state_dict(), 'parameter.pkl') # 加载 model = TheModelClass(...) model.load_state_dict(torch.load('parameter.pkl'))
保存完整模型
# 保存 torch.save(model, 'model.pkl') # 加载 model = torch.load('model.pkl')
只保存模型参数
# 保存 torch.save(model.state_dict(), 'parameter.pkl') # 加载 model = TheModelClass(...) model.load_state_dict(torch.load('parameter.pkl'))
保存完整模型
# 保存 torch.save(model, 'model.pkl') # 加载 model = torch.load('model.pkl')