今天跑wav2vec的预训练模型:
import torch from fairseq.models.wav2vec import Wav2VecModel import librosa cp = torch.load('../models/wav2vec_large.pt') model = Wav2VecModel.build_model(cp['args'], task=None) model.load_state_dict(cp['model']) signal, sr = librosa.load('../static/test.wav') tensors = torch.from_numpy(signal).unsqueeze(0) z = model.feature_extractor(tensors) c = model.feature_aggregator(z) #print('c:', c) print(c.shape)
但是遇到一个非常恶心的问题,截图如下:
分明是按照网上的代码一步一步来的,就是报错,困扰的很长时间,最后发现是fairseq安装的版本不对。最开始的时候安装的版本是1.0,但是fairseq是一个更新非常快的库,但是代码中加载的模型已经提出来有一段时间了,所以会出现参数不匹配的问题,将fairseq版本改为0.9.0版本就可以运行出来了。
完整的代码见github:https://github.com/SolbiatiAlessandro/wav2vec.git
如果变换版本之后还是不行的话,建议参考https://blog.csdn.net/starinline/article/details/109944198这篇博客,里边的博主也遇到了相同的问题,但是他改变的是hydra/_internal/utils.py中的参数,细节请到该博客阅读
本人在尝试了上边的方法之后,问题仍然没有解决,所以我建议,大家如果也遇到了相同的问题,先尝试一下上边博主的方法,如果尝试无果的话,再尝试更换一个fairseq的版本。