• BN归一化


    class BatchNormalization(Module):
        """Batch normalization over axis 1 of a 4-D input.

        Statistics are reduced over axes (0, 2, 3), and ``gamma``/``beta`` are
        broadcast as ``(1, in_feature, 1, 1)``, so axis 1 of the input must hold
        ``in_feature`` features (e.g. an NCHW image batch).

        In training mode the layer normalizes with the statistics of the current
        batch and updates running estimates; otherwise it normalizes with the
        running estimates.
        """

        def __init__(self, in_feature, momentum=0.9, eps=1e-8):
            """Create the layer.

            Args:
                in_feature: Number of features along axis 1 of the input.
                momentum: Weight given to the *current batch* statistics when
                    updating the running mean/variance (the previous running
                    value is weighted by ``1 - momentum``).
                eps: Constant added to the variance before the square root to
                    avoid division by zero.
            """
            # Initialise the base class so that `name` and `train_mode` exist;
            # without this, forward()/info() fail with AttributeError on a
            # freshly constructed layer.
            super().__init__("BatchNormalization")
            self.mu = 0      # running mean (scalar until the first training step)
            self.var = 1     # running variance (scalar until the first training step)
            self.momentum = momentum
            self.eps = eps
            self.in_feature = in_feature
            self.gamma = Parameter(np.ones(in_feature))    # learnable scale
            self.beta = Parameter(np.zeros(in_feature))    # learnable shift

        def forward(self, x):
            """Normalize ``x`` and apply the learnable scale and shift.

            Args:
                x: 4-D array with ``in_feature`` entries along axis 1.

            Returns:
                Array of the same shape as ``x``.
            """
            gamma = self.gamma.value.reshape(1, -1, 1, 1)
            beta = self.beta.value.reshape(1, -1, 1, 1)

            if not self.train_mode:
                # Inference: use the running estimates, leave all state untouched.
                y = (x - self.mu) / np.sqrt(self.var + self.eps)
                return y * gamma + beta

            # Training: per-feature statistics of the current batch. They are
            # cached on the instance because backward() needs them.
            self.b_mu = np.mean(x, axis=(0, 2, 3), keepdims=True)
            self.b_var = np.var(x, axis=(0, 2, 3), keepdims=True)
            self.y = (x - self.b_mu) / np.sqrt(self.b_var + self.eps)
            self.mu = self.b_mu * self.momentum + self.mu * (1 - self.momentum)

            # The running variance tracks the unbiased estimate (n / (n - 1)),
            # while normalization above uses the biased batch variance.
            n = x.size / x.shape[1]  # number of samples per feature
            # Guard n == 1: the correction factor would divide by zero and
            # poison the running variance with NaN.
            unbiased_var = self.b_var * n / max(n - 1, 1)
            self.var = unbiased_var * self.momentum + self.var * (1 - self.momentum)
            return self.y * gamma + beta

        def backward(self, G):
            """Back-propagate through the training-mode forward pass.

            Stores the gradients of ``gamma`` and ``beta`` in their ``delta``
            attributes and returns the gradient with respect to the input.
            Must be called after a training-mode ``forward``, which caches
            ``self.y`` and ``self.b_var``.

            Args:
                G: Gradient of the loss with respect to the layer output; same
                    shape as the output of ``forward``.

            Returns:
                Gradient of the loss with respect to the input ``x``.
            """
            axes = (0, 2, 3)
            self.gamma.delta = np.sum(G * self.y, axis=axes)
            self.beta.delta = np.sum(G, axis=axes)

            # The batch mean and variance are themselves functions of x, so the
            # input gradient has two correction terms besides the direct
            # G * gamma / sigma path:
            #   dx = gamma / sigma * (G - mean(G) - y * mean(G * y))
            # with the means taken over the same axes as the batch statistics.
            g_mean = np.mean(G, axis=axes, keepdims=True)
            gy_mean = np.mean(G * self.y, axis=axes, keepdims=True)
            scale = self.gamma.value.reshape(1, -1, 1, 1) / np.sqrt(self.b_var + self.eps)
            return scale * (G - g_mean - self.y * gy_mean)
    View Code
    class Module:
        """Base class for layers: tracks a name, a train/eval flag and children.

        Sub-modules and parameters are discovered by scanning the instance's
        own attributes, in attribute-insertion order.
        """

        def __init__(self, name):
            """Store the display name; a new module starts in eval mode."""
            self.name = name
            self.train_mode = False

        def __call__(self, *args):
            """Make the module callable by delegating to ``forward``."""
            return self.forward(*args)

        def train(self):
            """Switch this module and every sub-module to training mode."""
            self.train_mode = True
            for child in self.modules():
                child.train()

        def eval(self):
            """Switch this module and every sub-module to evaluation mode."""
            self.train_mode = False
            for child in self.modules():
                child.eval()

        def modules(self):
            """Return the direct attributes of this instance that are Modules."""
            return [value for value in vars(self).values() if isinstance(value, Module)]

        def params(self):
            """Return this module's Parameters followed by those of its sub-modules."""
            collected = [
                value for value in vars(self).values() if isinstance(value, Parameter)
            ]
            for child in self.modules():
                collected += child.params()
            return collected

        def info(self, n):
            """Return the name of this module and, indented below it, its sub-modules.

            Args:
                n: Nesting depth of this module; children are indented by
                    ``n + 1`` levels of two spaces.
            """
            depth = n + 1
            pad = '  ' * depth
            lines = [f"{self.name}"]
            lines.extend(pad + f"{child.info(depth)}" for child in self.modules())
            return "\n".join(lines)

        def __repr__(self):
            """Represent the module as its tree description rooted at depth 0."""
            return self.info(0)
    View Code
  • 相关阅读:
    PAT1038
    PAT1034
    PAT1033
    PAT1021
    PAT1030
    PAT1026
    PAT1063
    PAT1064
    PAT1053
    PAT1025
  • 原文地址:https://www.cnblogs.com/xiaoruirui/p/16834838.html
Copyright © 2020-2023  润新知