保存模型:
def save(model, model_path): torch.save(model.state_dict(), model_path)
加载模型:
def load(model, model_path): model.load_state_dict(torch.load(model_path))
这样会出现一个问题,即明明指定了某张卡,但总有一个模型的显存多出来,占到另一张卡上,很烦人,看到知乎有个方法可以解决
https://www.zhihu.com/question/67209417/answer/355059967
说是把模型的数据放在CPU上就可以解决,等试一下效果
def load(model, model_path):
model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))