• 加载预训练模型修改类别数与不修改类别数


    #不修改类别数
    checkpoint = torch.load("./save_model_other/best_model_pre.pth") print(checkpoint.state_dict().keys()) model.load_state_dict(checkpoint.state_dict())

    #修改类别数
    ''' # 模型参数加载函数 def transfer_state_dict(pretrained_dict, model_dict): state_dict = {} for k, v in pretrained_dict.state_dict().items(): if k in model_dict.state_dict().keys(): state_dict[k] = v else: print("Missing keys in state_dict: {}".format(k)) return state_dict checkpoint = torch.load("./save_model_other/best_model_pre.pth") state_dict = transfer_state_dict(checkpoint, model) del state_dict["model2.2.weight"] del state_dict["model2.2.bias"] model_dict = model.state_dict() model_dict.update(state_dict) model.load_state_dict(model_dict) ''' #print(checkpoint.state_dict().keys()) #model.load_state_dict(checkpoint.state_dict())
  • 相关阅读:
    异或运算用途
    js正则表达式子校验
    SMART原则
    边际成本,机会成本,沉默成本
    cxf 例子
    CXF使用JMS作为传输协议的配置
    js验证手机号,身份证,车牌号验证
    redis应用
    list集合去重复元素
    lodop
  • 原文地址:https://www.cnblogs.com/xiaochouk/p/16591692.html
Copyright © 2020-2023  润新知