• 基于pytorch神经网络模型参数的加载及自定义


    最近在训练MobileNet时经常会对其模型参数进行各种操作,或者替换其中的几层之类的,故总结一下用到的对神经网络参数的各种操作方法。

    1.将matlab的.mat格式参数整理转换为tensor类型的模型参数

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import scipy.io as scio
    import os
    import numpy as np
    from config import Config
    import json
    config = Config()
    
    Mul = Config.MUL.astype('float32')
    Shift = Config.SHIFT.astype('float32')
    
    def load_json(j_fn):
        with open(j_fn,'r') as f:
            data = json.load(f)
        return data
    
    def save_json(dic,j_fn):
        json_str = json.dumps(dic)
        with open(j_fn,'w') as json_file:
            json_file.write(json_str)
    
    w_dic = {}
    b_dic = {}
    dic_all = {}
    for i in range(1,28,2):
        a = 'w'+str(i)    #按顺序命名
        b = 'b'+str(i)
        dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i)+'.mat')['wei'] * Mul[i-1]/(2**Shift[i-1])).permute(3, 2, 0, 1)
        dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i)+'.mat')['bias'] * Mul[i-1]/(2**Shift[i-1])))
        # print(a, 'Mul'+str(i-1))
        if i == 27:
            break
        a = 'w'+str(i+1)
        b = 'b'+str(i+1)
        dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i+1)+'.mat')['wei'] * Mul[i]/(2**Shift[i])).permute(2, 0, 1).unsqueeze(1)
        dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i+1)+'.mat')['bias'] * Mul[i]/(2**Shift[i])))
    #此处由于自己之前的命名问题,中间跳过了28层(池化层),直接按照有参数的层存储了参数,故27后的文件名变成了29
    dic_all['w29'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + '29.mat')['wei'] * Mul[28]/(2**Shift[28])).permute(3, 2, 0, 1)[1:, :])
    dic_all['b29'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + '29.mat')['bias'] * Mul[28]/(2**Shift[28])))[1:]
    #存为.pth文件
    param_fn = 'mobilenet_param_float.pth'
    torch.save(dic_all,param_fn)

    其中,mul和shift为量化后的乘子和移位参数(如果参数是浮点的则可以忽略这部分),另外,我的weight和bias是按照每层单独存在一个按照层序号命名的.mat文件中。且由于是从matlab的程序得到的,需要对参数的维度进行一下转换(permute()方法),同时对需要增加或减少维度的用unsqueeze()或torch.squeeze()方法进行改变(注意一定要和网络需要的输入维度相同才行)。最后按照原来对参数文件命名的方式保存成字典存成.pth文件(此时的字典还不能直接使用,需要在具体定义的网络中更换想应的key值)。

    *另外,代码中用来读取和存储.json文件的函数可以忽略,在这里没有用到

    2.将自定义网络的参数替换成自己需要的(DIY模型参数)

    from mobilenet_v1 import MobileNet_v1
    import torch
    from config import Config
    from load_data import loadtestdata
    from torch.autograd import Variable
    import numpy as np
    from Mobilenetv1_quantified import MobileNet, MobileNet_Bayer
    import json
    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision
    
    param_keys = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5', 'w6', 'b6', 'w7', 'b7', 'w8', 'b8', 'w9', 'b9', 'w10', 'b10', 'w11', 'b11', 'w12', 'b12', 'w13', 'b13', 'w14', 'b14', 'w15', 'b15', 'w16', 'b16', 'w17', 'b17', 'w18', 'b18', 'w19', 'b19', 'w20', 'b20', 'w21', 'b21', 'w22', 'b22', 'w23', 'b23', 'w24', 'b24', 'w25', 'b25', 'w26', 'b26', 'w27', 'b27', 'w29', 'b29']
    file_name = '/home/wangshuyu/MobileNet_v1/mobilenet_param_float.pth'
    dic_param = torch.load(file_name)      # 此处打开上一步存成的参数字典(按照每一层的权重、偏置的顺序)
    Model = MobileNet()                    # 实例化预定义的MobileNet网络(网络结构将在其他文中给出)
    net_dic = Model.state_dict()           # 加载预定义网络的参数字典,用来获取网络的键值
    for i, param_tensor in enumerate(net_dic ,0):
        net_dic[param_tensor] = dic_param[param_keys[i]]
        # print(i,'	',param_tensor ,net_dic[param_tensor].shape)   #可以用来查看参数的维度
    param_fn = 'MobileNet_float.pth'
    torch.save(net_dic,param_fn)
    # 下面开始是自己定义的另一个网络,只需要固定MobileNet其中一部分参数,剩下的部分参数用来训练,因此只从第11个之后的开始取参数
    model2 = MobileNet_Bayer()
    dic2 = model2.state_dict()
    key_list = list(net_dic.keys())
    for i, param_tensor in enumerate(dic2 ,0):
        if i > 11:
            dic2[param_tensor] = (net_dic[key_list[i - 2]])
        print(i, '	', param_tensor, dic2[param_tensor].shape)
    param_fn2 = 'MobileNet_Bayer.pth'
    torch.save(dic2,param_fn2)

    这里主要实现了将之前存好的量化后的mobilenet每层参数根据自己定义的网络构建了参数字典,在训练或测试的时候,只需要加载之前存好的预训练参数就可以了:

    from Mobilenetv1_quantified import MobileNet
    import torch
    from load_data import loadtestdata Net
    = MobileNet() param_dic = torch.load('MobileNet_float.pth') Net.load_state_dict(param_dic)
    classes = range(0,1000)
    test_data = loadtestdate()
    test(test_data, Net, classes)
  • 相关阅读:
    文件搜索和图像裁剪
    Mat的复制
    map
    substr
    cin,scanf
    strstr
    Applying vector median filter on RGB image based on matlab
    sobel算子的一些细节
    matlab 有趣小细节
    高斯混合模型(GMM)
  • 原文地址:https://www.cnblogs.com/nekoneko-15/p/13617633.html
Copyright © 2020-2023  润新知