# Case 1: the number of classes is unchanged -> load every pretrained weight.
# The checkpoint object has a .state_dict() method, so the file presumably
# holds a whole pickled model (torch.save(model)) rather than a bare state dict.
# weights_only=False is required to unpickle such an object: PyTorch >= 2.6
# defaults torch.load to weights_only=True and would reject this file. Full
# unpickling can execute arbitrary code, so only load checkpoints you trust.
# (The weights_only keyword requires PyTorch >= 1.13.)
checkpoint = torch.load(
    "./save_model_other/best_model_pre.pth",
    weights_only=False,
)
pretrained_state = checkpoint.state_dict()  # build once, reused below
print(pretrained_state.keys())
model.load_state_dict(pretrained_state)
# Case 2: the number of classes was changed. Copy only the pretrained weights
# whose names also exist in the current model, and drop the old
# classification head (its shape depends on the class count). To use this
# path, disable the plain load above and uncomment the invocation below.


def transfer_state_dict(pretrained_dict, model_dict):
    """Collect the pretrained weights that the target model accepts by name.

    Args:
        pretrained_dict: source model; only its ``state_dict()`` is read.
        model_dict: target model; only the keys of its ``state_dict()`` are read.

    Returns:
        A new dict mapping every state-dict key present in both models to the
        source tensor. Keys that exist only in the source are printed and
        skipped. Shapes are NOT compared, so a same-named layer whose shape
        changed (e.g. the classifier head) must still be removed by the caller
        before ``load_state_dict``.
    """
    # Materialize the target keys once instead of rebuilding the whole target
    # state dict for every source entry.
    target_keys = set(model_dict.state_dict())
    state_dict = {}
    for k, v in pretrained_dict.state_dict().items():
        if k in target_keys:
            state_dict[k] = v
        else:
            # The key is in the checkpoint but the current model has no such entry.
            print("Key not found in model, skipped: {}".format(k))
    return state_dict


# Invocation (disabled). The checkpoint appears to be a whole pickled model, so
# PyTorch >= 2.6 needs weights_only=False; only load files you trust.
# checkpoint = torch.load("./save_model_other/best_model_pre.pth", weights_only=False)
# state_dict = transfer_state_dict(checkpoint, model)
# # Drop the old head; "model2.2" is presumably the final classifier -- confirm.
# # pop(..., None) tolerates the key already being absent.
# state_dict.pop("model2.2.weight", None)
# state_dict.pop("model2.2.bias", None)
# model_dict = model.state_dict()
# model_dict.update(state_dict)
# model.load_state_dict(model_dict)