• 权重pth读取更换其值重新保存方法(附带参数计算)


    本文主要解决模型权重迁移,主要使用pytorch读取某个权重,将其赋值给新权重格式,以下为原始代码:

    顺带参数计算函数代码:

    参数计算:

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

     权重更改代码如下:

    if __name__ == '__main__':
        train_pth_root=r'D:\Users\User\Desktop\mask-try\epoch_100.pth'  # 模型训练后得到的权重 
        pre_pth_root = r'D:\Users\User\Desktop\mask-try\resnet50-19c8e357.pth'  # 原始预训练权重,如mmdet的resnet预训练权重
        train_net=torch.load(train_pth_root)
        net_state_dict = train_net['state_dict']  # 训练模型权重保存字典键值
        pre_net=torch.load(pre_pth_root)
        # 以下替换和更改成预训练权重格式,这里需要根据具体情况决定,本代码是基于mmdection修改的
        keys_lst=[k.replace('backbone.','') for k in net_state_dict.keys() if 'backbone.' in k]
        for k,v in pre_net.items():
            if k in keys_lst:
                k_new='backbone.'+k
                pre_net[k]=net_state_dict[k_new]
        # 保存新权重
        torch.save(pre_net,'D:/Users/User/Desktop/mask-try/fasterrcnn_adaw.pth')

  • 相关阅读:
    爱上你的一百个理由 (网摘)
    梦想向右,沉默向左
    明夕何夕,君已陌路。
    不肯嫁的几种男人(转)
    一剪梅
    C# preprocessor Directives
    Language
    C# Language Tour
    Web application
    Unsafe code
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/15820396.html
Copyright © 2020-2023  润新知