• Pytorch中自定义神经网络卷积核权重


    1. 自定义神经网络卷积核权重

           神经网络被深度学习者深深喜爱,究其原因之一是神经网络的便利性,使用者只需要根据自己的需求像搭积木一样搭建神经网络框架即可,搭建过程中我们只需要考虑卷积核的尺寸,输入输出通道数,卷积方式等等。

           我们使用惯了自带的参数后,当我们要自定义卷积核参数时,突然有种无从下手的感觉,哈哈哈哈哈哈哈哈~~,请允许我开心下,嘿嘿!因为笔者在初入神经网络时也遇到了同样的问题,当时踩了太多坑了,宝宝想哭(灬ꈍ ꈍ灬)!让我悲伤的是,找遍了各个资源区,也没有找到大家的分享。因此,我想把我的方法写出来,希望能帮助到各位宝宝,开心(*^▽^*)。

      话不多说,正文开始......

    2. 定义卷积核权重

      我这里是自定义的dtt系数卷积核权重,直接上权重代码:

    2.1 dtt系数权重Code

      def dtt_matrix(n): 这个函数是n*n的DTT系数矩阵,笔者的是8*8的系数矩阵。

           def dtt_kernel(out_channels, in_channels, kernel_size): 这个方法是设定权重,权重需要包括4个参数(输出通道数,输入通道数,卷积核高,卷积核宽),这里有很多细节要注意,宝宝们要亲自躺下坑,才能映像深刻也,我就不深究了哈,(#^.^#)。

    import numpy as np
    import torch
    import torch.nn as nn
    
    
    # ================================
    # DTT coefficient matrix of n * n
    # ================================
    def dtt_matrix(n):
        dtt_coe = np.zeros([n, n], dtype='float32')
        for i in range(0, n):
            dtt_coe[0, i] = 1/np.sqrt(n)
            dtt_coe[1, i] = (2*i + 1 - n)*np.sqrt(3/(n*(np.power(n, 2) - 1)))
        for i in range(1, n-1):
            dtt_coe[i+1, 0] = -np.sqrt((n-i-1)/(n+i+1)) * np.sqrt((2*(i+1)+1)/(2*(i+1)-1)) * dtt_coe[i, 0]
            dtt_coe[i+1, 1] = (1 + (i+1)*(i+2)/(1-n)) * dtt_coe[i+1, 0]
            dtt_coe[i+1, n-1] = np.power(-1, i+1) * dtt_coe[i+1, 0]
            dtt_coe[i+1, n-2] = np.power(-1, i+1) * dtt_coe[i+1, 1]
            for j in range(2, int(n/2)):
                t1 = (-(i+1) * (i+2) - (2*j-1) * (j-n-1) - j)/(j*(n-j))
                t2 = ((j-1) * (j-n-1))/(j * (n-j))
                dtt_coe[i+1, j] = t1 * dtt_coe[i+1, j-1] + t2 * dtt_coe[i+1, j-2]
                dtt_coe[i+1, n-j-1] = np.power(-1, i-1) * dtt_coe[i+1, j]
        return dtt_coe
    
    
    # ===============================================================
    # DTT coefficient matrix of (out_channels * in_channels * n * n)
    # ===============================================================
    def dtt_kernel(out_channels, in_channels, kernel_size):
        dtt_coe = dtt_matrix(kernel_size)
        dtt_coe = np.array(dtt_coe)
    
        dtt_weight = np.zeros([out_channels, in_channels, kernel_size, kernel_size], dtype='float32')
        temp = np.zeros([out_channels, in_channels, kernel_size, kernel_size], dtype='float32')
    
        order = 0
        for i in range(0, kernel_size):
            for j in range(0, kernel_size):
                dtt_row = dtt_coe[i, :]
                dtt_col = dtt_coe[:, j]
                dtt_row = dtt_row.reshape(len(dtt_row), 1)
                dtt_col = dtt_col.reshape(1, len(dtt_col))
                # print("dtt_row: ", dtt_row)
                # print("dtt_col: ", dtt_col)
                # print("i:", i, "j: ", j)
                temp[order, 0, :, :] = np.dot(dtt_row, dtt_col)
                order = order + 1
        for i in range(0, in_channels):
            for j in range(0, out_channels):
                # dtt_weight[j, i, :, :] = flip_180(temp[j, 0, :, :])
                dtt_weight[j, i, :, :] = temp[j, 0, :, :]
        return torch.tensor(dtt_weight)

    2.2 'same'方式卷积

      如果宝宝需要保持卷积前后的数据尺寸保持不变,即'same'方式卷积,那么你直接使用我这个卷积核(提一下哟,这个我也是借自某位前辈的,我当时没备注哇,先在这里感谢那位前辈,前辈如果路过,还请留言小生哈,(#^.^#))。

    import torch.utils.data
    from torch.nn import functional as F
    import math
    import torch
    from torch.nn.parameter import Parameter
    from torch.nn.functional import pad
    from torch.nn.modules import Module
    from torch.nn.modules.utils import _single, _pair, _triple
    
    class _ConvNd(Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride,
                     padding, dilation, transposed, output_padding, groups, bias):
            super(_ConvNd, self).__init__()
            if in_channels % groups != 0:
                raise ValueError('in_channels must be divisible by groups')
            if out_channels % groups != 0:
                raise ValueError('out_channels must be divisible by groups')
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = transposed
            self.output_padding = output_padding
            self.groups = groups
            if transposed:
                self.weight = Parameter(torch.Tensor(
                    in_channels, out_channels // groups, *kernel_size))
            else:
                self.weight = Parameter(torch.Tensor(
                    out_channels, in_channels // groups, *kernel_size))
            if bias:
                self.bias = Parameter(torch.Tensor(out_channels))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
    
        def reset_parameters(self):
            n = self.in_channels
            for k in self.kernel_size:
                n *= k
            stdv = 1. / math.sqrt(n)
            self.weight.data.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.data.uniform_(-stdv, stdv)
    
        def __repr__(self):
            s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
                 ', stride={stride}')
            if self.padding != (0,) * len(self.padding):
                s += ', padding={padding}'
            if self.dilation != (1,) * len(self.dilation):
                s += ', dilation={dilation}'
            if self.output_padding != (0,) * len(self.output_padding):
                s += ', output_padding={output_padding}'
            if self.groups != 1:
                s += ', groups={groups}'
            if self.bias is None:
                s += ', bias=False'
            s += ')'
            return s.format(name=self.__class__.__name__, **self.__dict__)
    
    class Conv2d(_ConvNd):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=0, dilation=1, groups=1, bias=True):
            kernel_size = _pair(kernel_size)
            stride = _pair(stride)
            padding = _pair(padding)
            dilation = _pair(dilation)
            super(Conv2d, self).__init__(
                in_channels, out_channels, kernel_size, stride, padding, dilation,
                False, _pair(0), groups, bias)
        def forward(self, input):
            return conv2d_same_padding(input, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
    
    # custom con2d, because pytorch don't have "padding='same'" option.
    
    def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
        input_rows = input.size(2)
        filter_rows = weight.size(2)
        effective_filter_size_rows = (filter_rows - 1) * dilation[0] + 1
        out_rows = (input_rows + stride[0] - 1) // stride[0]
    
        input_cols = input.size(3)
        filter_cols = weight.size(3)
        effective_filter_size_cols = (filter_cols - 1) * dilation[1] + 1
        out_cols = (input_cols + stride[1] - 1) // stride[1]
    
        padding_needed = max(0, (out_rows - 1) * stride[0] + effective_filter_size_rows -input_rows)
        padding_rows = max(0, (out_rows - 1) * stride[0] +
                            (filter_rows - 1) * dilation[0] + 1 - input_rows)
        rows_odd = (padding_rows % 2 != 0)
        padding_cols = max(0, (out_cols - 1) * stride[1] +
                           (filter_cols - 1) * dilation[1] + 1 - input_cols)
        cols_odd = (padding_cols % 2 != 0)
        if rows_odd or cols_odd:
            input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
        return F.conv2d(input, weight, bias, stride,
                      padding=(padding_rows // 2, padding_cols // 2),
                      dilation=dilation, groups=groups)

     2.3 将权重赋给卷积核

      此处才是宝宝们最关心的吧,不慌,这就来了哈,开心(*^▽^*),进入正文了(#^.^#)。

      这里给了一个简单的网络模型(一个固定卷积+3个全连接,全连接是1*1的Conv2d),代码里我给了注释,宝宝们应该能秒懂滴,(*^▽^*)!

    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    import dtt_kernel
    import util
    import paddingSame
    
    # 定义权重
    dtt_weight1 = dtt_kernel.dtt_kernel(64, 2, 8)
    
    
    class DttNet(nn.Module):
        def __init__(self):
            super(DttNet, self).__init__()
    self.conv1
    = paddingSame.Conv2d(2, 64, 8)
         # 将权重赋给卷积核 self.conv1.weight
    = nn.Parameter(dtt_weight1, requires_grad=False) self.fc1 = util.fc(64, 512, 1) self.fc2 = util.fc(512, 128, 1) self.fc3 = util.fc(128, 2, 1, last=True) def forward(self, x): x = self.conv1(x) x = self.fc1(x) x = self.fc2(x) x = self.fc3(x) return x

     2.4 补充我的util类

    import torch.nn as nn
    
    
    def conv(in_channels, out_channels, kernel_size, stride=1, dilation=1, batch_norm=True):
        if batch_norm:
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=(kernel_size // 2)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
        else:
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=(kernel_size // 2)),
                nn.ReLU()
            )
    
    
    def fc(in_channels, out_channels, kernel_size, stride=1, bias=True, last=False):
        if last:
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=(kernel_size // 2)),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=(kernel_size // 2)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )

    3. 总结

      哇哦,写完了耶,不晓得宝宝们有没得收获呢,o((⊙﹏⊙))o,o((⊙﹏⊙))o。大家不懂的可以再下面留言哟,我会时常关注我家的园子呢。若有不足之处,宝宝们也在留言区吱我一下哟,我们下次再见,┏(^0^)┛┏(^0^)┛。

  • 相关阅读:
    前端工程化之动态数据代理
    webapp开发之需要知道的css细节
    html-webpack-plugin详解
    file-loader引起的html-webpack-plugin坑
    浅谈react受控组件与非受控组件
    React创建组件的三种方式及其区别
    react项目开发中遇到的问题
    css伪元素:before和:after用法详解
    python之文件操作
    python之range和xrange
  • 原文地址:https://www.cnblogs.com/haifwu/p/12818399.html
Copyright © 2020-2023  润新知