# Epoch range for training; start_epoch may be overridden below when resuming.
start_epoch = params.start_epoch
stop_epoch = params.stop_epoch
if params.resume != '':
    # get_resume_file locates the <epoch>.tar checkpoint in the given checkpoint
    # directory (presumably selected by params.resume_epoch — confirm in its
    # definition); it returns None when no checkpoint is found.
    resume_file = get_resume_file('%s/checkpoints/%s' % (params.save_dir, params.resume), params.resume_epoch)
    if resume_file is not None:
        tmp = torch.load(resume_file)
        # Continue from the epoch after the one stored in the checkpoint.
        start_epoch = tmp['epoch'] + 1
        model.load_state_dict(tmp['state'])
        # Report the checkpoint file actually loaded (params.resume is only the
        # directory name, not the model file).
        print(' resume the training at epoch {} (model file {})'.format(start_epoch, resume_file))
加载 .tar 检查点文件，并从其 state 字典中选取所需的权重：只保留键名中含有 “feature.” 的参数（同时去掉键名中的 “feature.”），其余参数以及键名含 gamma 或 beta 的参数均被丢弃。
# Load a checkpoint and keep only the weights whose keys contain "feature.",
# with "feature." stripped from the key, then load them into `model`.
tmp = torch.load(modelfile)  # parameter file, e.g. 400.tar
# The weights are stored under 'state'; fall back to 'model_state' if that
# key is missing (a KeyError propagates if neither exists).
try:
    state = tmp['state']
except KeyError:
    state = tmp['model_state']

# Build a new dict rather than renaming keys in place while iterating: in-place
# renaming could map 'feature.x' onto an existing key 'x' that is popped later
# in the loop, silently dropping that weight.
# Note: str.replace strips every occurrence of "feature.", not just a prefix.
# NOTE(review): keys containing 'gamma'/'beta' are excluded — presumably
# parameters of extra layers that `model` does not define; confirm.
state = {
    key.replace("feature.", ""): value
    for key, value in state.items()
    if "feature." in key and "gamma" not in key and "beta" not in key
}

model.load_state_dict(state)