Background
Loading the albert-base-chinese model downloaded from Hugging Face with PyTorch fails with the following error:
Exception has occurred: OSError
Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
Model page: https://huggingface.co/models?search=albert_chinese
Attempt 1:
Deleted the cache directory as suggested in the articles below, but the problem persisted:
https://blog.csdn.net/znsoft/article/details/107725285
https://github.com/huggingface/transformers/issues/6159
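For reference, a sketch of how to locate the cache directory those articles suggest deleting. The path below is the historical default for transformers (it can be overridden via the TRANSFORMERS_CACHE environment variable, and newer releases use ~/.cache/huggingface/hub instead):

```python
import os

# Default transformers cache location unless TRANSFORMERS_CACHE overrides it
cache_dir = os.environ.get(
    "TRANSFORMERS_CACHE",
    os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "transformers"),
)
print(cache_dir)

# Deleting it forces a re-download on the next from_pretrained() call:
# import shutil; shutil.rmtree(cache_dir, ignore_errors=True)
```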
Attempt 2:
Loading the same model on another machine succeeded. Comparing the torch and transformers versions on the two machines showed that one had torch 1.1 and the other torch 1.7.x.
According to the PyTorch documentation, the model serialization format changed in torch 1.6: checkpoints saved by newer versions cannot be loaded by older ones.
The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.
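The format difference is easy to verify: a checkpoint saved with the new format is a zip archive, while the legacy format is not. A minimal sketch (assumes torch >= 1.6):

```python
import zipfile
import torch

tensor = torch.ones(3)

# Default since torch 1.6: zip-based format
torch.save(tensor, "new_format.pt")
# Legacy format, readable by torch < 1.6
torch.save(tensor, "old_format.pt", _use_new_zipfile_serialization=False)

print(zipfile.is_zipfile("new_format.pt"))  # True
print(zipfile.is_zipfile("old_format.pt"))  # False
```

This check is also a quick way to tell whether a given pytorch_model.bin was saved in the new or the old format.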
Solution:
- Upgrade torch to a newer version.
- If upgrading is not an option (e.g. because of CUDA compatibility), load the model on the machine with the newer torch, then re-save it with _use_new_zipfile_serialization=False:
```python
from transformers import BertTokenizer, AlbertForMaskedLM
import torch

pretrained = 'D:/07_data/albert_base_chinese'
# The albert_chinese checkpoints use BertTokenizer rather than AlbertTokenizer
tokenizer = BertTokenizer.from_pretrained(pretrained)
model = AlbertForMaskedLM.from_pretrained(pretrained)
# Unwrap the model if it is wrapped in DistributedDataParallel or DataParallel
model_to_save = model.module if hasattr(model, 'module') else model
# Save in the legacy (pre-1.6) format so older torch versions can load it
torch.save(model_to_save.state_dict(), 'pytorch_model_unzip.bin', _use_new_zipfile_serialization=False)
```
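On the torch 1.1 machine, the re-saved file can then be loaded with a plain torch.load / load_state_dict. A self-contained round-trip sketch, using a toy nn.Linear in place of the ALBERT model so it runs anywhere:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Save in the legacy format (what the snippet above does for the real model)
torch.save(model.state_dict(), 'toy_unzip.bin', _use_new_zipfile_serialization=False)

# On the old-torch machine: load the legacy-format state dict as usual
reloaded = nn.Linear(4, 2)
reloaded.load_state_dict(torch.load('toy_unzip.bin', map_location='cpu'))
```

To have from_pretrained pick up the re-saved weights, rename the file to pytorch_model.bin inside the model directory.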
For other ways to save models, see:
https://blog.csdn.net/fendouaini/article/details/105322537
https://huggingface.co/transformers/serialization.html#serialization-best-practices