• Migrating TensorFlow MobileNetV1 parameters to PyTorch and saving them


    I gave up on TensorFlow a long time ago and don't want to go back to it: something that takes only a dozen lines in PyTorch is completely unreadable to me in its TensorFlow version, so a scrub like me will just keep plodding along with PyTorch. There are plenty of MobileNet implementations online, but I'll post the one I wrote. The paper is short and easy to read; even with my poor English it only took two days to get through.

    The MobileNet code is as follows.

    import torch.nn as nn
    import torch
    # standard conv block: 3x3 convolution -> BatchNorm -> ReLU
    class Conv_bn(nn.Module):
        def __init__(self,inp,oup,stride):
            super(Conv_bn, self).__init__()
            self.convBn=nn.Sequential(
                nn.Conv2d(inp,oup,3,stride,1,bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )
        def forward(self,x):
            out=self.convBn(x)
            return out
    
    # depthwise separable block: 3x3 depthwise conv followed by a 1x1 pointwise conv,
    # each with its own BatchNorm and ReLU
    class Conv_depth(nn.Module):
        def __init__(self,inp,oup,stride):
            super(Conv_depth, self).__init__()
            self.convDepthwise=nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
    
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )
        def forward(self,x):
            out=self.convDepthwise(x)
            return out
    
    
    class MobileNet(nn.Module):
        def __init__(self):
            super(MobileNet, self).__init__()
            # the whole MobileNetV1 body as one nn.Sequential; the indices 0, 1, 2, ...
            # are the layer numbers used in the conversion table below
            self.mobelnet=nn.Sequential(
                Conv_bn(3, 32, 2),
                Conv_depth(32, 64, 1),
                Conv_depth(64, 128, 2),
                Conv_depth(128, 128, 1),
                Conv_depth(128, 256, 2),
                Conv_depth(256, 256, 1),
                Conv_depth(256, 512, 2),
                Conv_depth(512, 512, 1),
                Conv_depth(512, 512, 1),
                Conv_depth(512, 512, 1),
                Conv_depth(512, 512, 1),
                Conv_depth(512, 512, 1),
                Conv_depth(512, 1024, 2),
                Conv_depth(1024, 1024, 1),
                nn.AvgPool2d(7),)
    
            self.fc = nn.Linear(1024, 1000)
    
        # forward pass of the network
        def forward(self, x):
            x=self.mobelnet(x)
            x=x.view(-1, 1024)
            x=self.fc(x)
            return x
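
    A quick sanity check, as a minimal sketch: push a dummy 224x224 batch through the network and confirm that the output shape is [N, 1000]. Note that the hard-coded nn.AvgPool2d(7) only works for 224x224 inputs; the module path is assumed to be the same one used in the conversion script below.

    import torch
    from MobileNet.mobilenet_v1 import MobileNet  # assumed module path

    model = MobileNet()
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)   # a batch of 2 RGB images, 224x224
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)                       # expected: torch.Size([2, 1000])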

    My goodness, simple, right? You have no idea how long the TensorFlow version is.

    Then converting the parameters stumped me; I had never done it before. I referred to the MobileNetV3 conversion written up at https://www.jianshu.com/p/0a61caeb693b, but I really couldn't figure out how that conversion dictionary was defined: every time I typed model.<layer name> the editor underlined it in red and it errored out. My guess is that he wrapped the layers into objects of their own; if anyone understands it, please explain it to me. I'll paste my own code below.
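
    The whole conversion boils down to a lookup table from PyTorch state_dict keys to TensorFlow checkpoint variable names. As a minimal sketch (assuming the same checkpoint path as in the script below), the two namespaces can be printed side by side to see what has to be mapped to what:

    import tensorflow as tf
    from MobileNet.mobilenet_v1 import MobileNet

    CHECKPOINT_PATH = '/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'

    # TensorFlow side: every variable name and shape stored in the checkpoint
    reader = tf.compat.v1.train.NewCheckpointReader(CHECKPOINT_PATH)
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        print('tf   :', name, shape)

    # PyTorch side: the keys of a freshly built model's state_dict
    for key, tensor in MobileNet().state_dict().items():
        print('torch:', key, tuple(tensor.shape))

    The full conversion script: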

    import json
    import tensorflow as tf
    import os
    from MobileNet.mobilenet_v1 import MobileNet
    import numpy as np
    import torch
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    CHECKPOINT_PATH='/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'
    
    # dump the checkpoint's variable names into a json file (minus training-only variables)
    def new_dict(checkpoint_path,json_path):
        reader=tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
        weights_shape=reader.get_variable_to_shape_map()
        # quick check that the reader works: print the shape of one BatchNorm variable
        print('the layer',weights_shape['MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean'])
        if not os.path.exists(json_path):
            weights_small = {n: 1 for (n, _) in weights_shape.items()}
            keys_list=list(weights_small.keys())
            for key_ in keys_list:
                # drop optimizer / EMA variables, they are not model weights
                if "/ExponentialMovingAverage" in key_ or "/RMSProp" in key_:
                    del weights_small[key_]
            with open(json_path, 'w') as writer:
                json.dump(weights_small, fp=writer, sort_keys=True)
        else:
            print('the json file has already been written!')
    
    # mapping for the first standard conv block: PyTorch state_dict key -> TF checkpoint variable name.
    # In PyTorch BatchNorm the weight is the scale (gamma) and the bias is the shift (beta).
    def get_convbn_convert_dict(layer_num):
        convert_dict={
            'mobelnet.'+str(layer_num)+'.convBn.0.weight':'MobilenetV1/Conv2d_'+str(layer_num)+'/weights',
            'mobelnet.'+str(layer_num)+'.convBn.1.weight':'MobilenetV1/Conv2d_'+str(layer_num)+'/BatchNorm/gamma',
            'mobelnet.'+str(layer_num)+'.convBn.1.bias':'MobilenetV1/Conv2d_'+str(layer_num)+'/BatchNorm/beta',
            'mobelnet.'+str(layer_num)+'.convBn.1.running_mean':'MobilenetV1/Conv2d_'+str(layer_num)+'/BatchNorm/moving_mean',
            'mobelnet.'+str(layer_num)+'.convBn.1.running_var':'MobilenetV1/Conv2d_'+str(layer_num)+'/BatchNorm/moving_variance'
        }
        return convert_dict
    
    # mapping for one depthwise separable block (3x3 depthwise conv + 1x1 pointwise conv)
    def get_dpwise_convert_dict(layer_num):
        convert_dict={
            'mobelnet.'+str(layer_num)+'.convDepthwise.0.weight':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_depthwise/depthwise_weights',
            'mobelnet.'+str(layer_num)+'.convDepthwise.1.weight':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_depthwise/BatchNorm/gamma',
            'mobelnet.'+str(layer_num)+'.convDepthwise.1.bias':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_depthwise/BatchNorm/beta',
            'mobelnet.'+str(layer_num)+'.convDepthwise.1.running_mean':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_depthwise/BatchNorm/moving_mean',
            'mobelnet.'+str(layer_num)+'.convDepthwise.1.running_var':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_depthwise/BatchNorm/moving_variance',
            'mobelnet.'+str(layer_num)+'.convDepthwise.3.weight':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_pointwise/weights',
            'mobelnet.'+str(layer_num)+'.convDepthwise.4.weight':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_pointwise/BatchNorm/gamma',
            'mobelnet.'+str(layer_num)+'.convDepthwise.4.bias':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_pointwise/BatchNorm/beta',
            'mobelnet.'+str(layer_num)+'.convDepthwise.4.running_mean':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_pointwise/BatchNorm/moving_mean',
            'mobelnet.'+str(layer_num)+'.convDepthwise.4.running_var':
                'MobilenetV1/Conv2d_'+str(layer_num)+'_pointwise/BatchNorm/moving_variance'
        }
        return convert_dict
    
    # build the full conversion table: one standard conv block plus (layers_num - 1) depthwise blocks
    def get_model_dict(layers_num):
        merge = lambda dict1, dict2: {**dict1, **dict2}
        conversion_table = {}
        convBn_dict=get_convbn_convert_dict(0)
        conversion_table=merge(conversion_table,convBn_dict)
        for i in range(1,layers_num):
            dpWise_dict=get_dpwise_convert_dict(i)
            conversion_table=merge(conversion_table,dpWise_dict)
        return conversion_table
    
    def write_json(conversion_table,json_path):
        if not os.path.exists(json_path):
            with open(json_path, 'w') as writer:
                json.dump(conversion_table, fp=writer, sort_keys=True)
        else:
            print('the conversion table has already been written!')
    
    # copy the checkpoint values into the PyTorch state_dict and save it
    def load_parameter(conversion_table):
        module=MobileNet()
        original_model_dict=module.state_dict()
        pth_list=list(conversion_table.keys())
        ckpt_list=list(conversion_table.values())
        assert len(pth_list)==len(ckpt_list), 'the length is not right!'
        for i,ckpt_name in enumerate(ckpt_list):
            ckpt_name_value=tf.compat.v1.train.load_variable(CHECKPOINT_PATH,ckpt_name)
            if 'Conv2d' in ckpt_name and 'weights' in ckpt_name:
                # TF stores conv weights as (H, W, in, out); PyTorch wants (out, in, H, W)
                ckpt_name_value=np.transpose(ckpt_name_value,(3,2,0,1))
                if 'depthwise' in ckpt_name:
                    # depthwise weights are (H, W, in, multiplier) in TF, so after the transpose
                    # above they are (multiplier, in, H, W); swap the first two axes to get (in, 1, H, W)
                    ckpt_name_value=np.transpose(ckpt_name_value,(1,0,2,3))
            # the 1-D BatchNorm vectors (gamma, beta, moving_mean, moving_variance) need no transpose
            pytorch_dict_key=pth_list[i]
            original_model_dict[pytorch_dict_key]=torch.from_numpy(ckpt_name_value)
        # the final fc layer is not in the conversion table, so it keeps its random initialization
        torch.save(original_model_dict,'/Users/wenyu/Desktop/TorchProject/MobileNet/tf_to_torch.pth')
        return original_model_dict
    
    if __name__ == '__main__':
        conversion_table=get_model_dict(14)
        dic_mobel=load_parameter(conversion_table)
        print(dic_mobel['mobelnet.1.convDepthwise.0.weight'].shape)

    The core is really the last two functions. The code may look simple, but I spent a long time figuring out how to do it, since this was my first attempt and I'm not very practiced. Still, it firmed up a lot of basics about numpy, tensors, and dictionaries, which felt rewarding. If you have questions, leave a comment below the post.
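
    To double-check the result, here is a minimal sketch (using the same paths as above) that reloads the saved .pth file, compares one converted tensor with the value read straight from the TensorFlow checkpoint, and loads the dict back into the model:

    import numpy as np
    import tensorflow as tf
    import torch
    from MobileNet.mobilenet_v1 import MobileNet

    CHECKPOINT_PATH = '/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'
    state_dict = torch.load('/Users/wenyu/Desktop/TorchProject/MobileNet/tf_to_torch.pth')

    # a BatchNorm running mean is copied without any transpose, so it should match exactly
    torch_value = state_dict['mobelnet.9.convDepthwise.4.running_mean'].numpy()
    tf_value = tf.compat.v1.train.load_variable(
        CHECKPOINT_PATH, 'MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean')
    print(np.allclose(torch_value, tf_value))   # expected: True

    # the saved dict was built from model.state_dict(), so it loads back without key mismatches
    model = MobileNet()
    model.load_state_dict(state_dict)
    model.eval()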

  • Original post: https://www.cnblogs.com/daremosiranaihana/p/12833493.html