class BatchNormalization(Module):
    """Spatial batch normalization over the channel axis of NCHW input.

    During training, normalizes each channel with the current batch's
    mean/variance and maintains exponential running averages for use at
    inference time.

    Args:
        in_feature: number of channels C of the (N, C, H, W) input.
        momentum: retention factor of the running statistics — the running
            average keeps `momentum` of its old value and blends in
            `1 - momentum` of the new batch statistic.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, in_feature, momentum=0.9, eps=1e-8):
        # Running statistics used at inference; start as broadcastable
        # scalars and become (1, C, 1, 1) arrays after the first training
        # step.
        self.mu = 0
        self.var = 1
        self.momentum = momentum
        self.eps = eps
        self.in_feature = in_feature
        # Learnable per-channel scale and shift.
        self.gamma = Parameter(np.ones(in_feature))
        self.beta = Parameter(np.zeros(in_feature))

    def forward(self, x):
        """Normalize x of shape (N, C, H, W); returns the same shape.

        NOTE(review): `self.train_mode` is never set in this class —
        assumed to be provided by the `Module` base; confirm.
        """
        gamma = self.gamma.value.reshape(1, -1, 1, 1)
        beta = self.beta.value.reshape(1, -1, 1, 1)
        if not self.train_mode:
            # Inference path: use the running statistics.
            y = (x - self.mu) / np.sqrt(self.var + self.eps)
            return y * gamma + beta
        # Training path: normalize with the batch statistics; cache the
        # intermediates needed by backward().
        self.b_mu = np.mean(x, axis=(0, 2, 3), keepdims=True)
        self.b_var = np.var(x, axis=(0, 2, 3), keepdims=True)
        self.y = (x - self.b_mu) / np.sqrt(self.b_var + self.eps)
        # Exponential moving average: keep `momentum` of the old running
        # value, blend in `1 - momentum` of the batch statistic.
        # (BUGFIX: the weights were previously swapped, making the running
        # stats track the most recent batch almost exclusively.)
        self.mu = self.mu * self.momentum + self.b_mu * (1 - self.momentum)
        n = x.size / x.shape[1]  # samples per channel: N * H * W
        # Bessel-corrected (unbiased) variance for the running estimate.
        unbiased_var = self.b_var * n / (n - 1)
        self.var = self.var * self.momentum + unbiased_var * (1 - self.momentum)
        return self.y * gamma + beta

    def backward(self, G):
        """Backprop through the training-mode forward pass.

        Sets `gamma.delta` / `beta.delta` (per-channel sums) and returns
        dL/dx of the same shape as G.

        BUGFIX: the previous implementation returned only the direct term
        G * gamma / sqrt(var + eps), dropping the gradient flowing through
        the batch mean and variance. The full batch-norm gradient is
            dx = gamma / sqrt(var + eps) * (G - mean(G) - y * mean(G * y))
        with means taken over the (N, H, W) axes per channel.
        """
        self.gamma.delta = np.sum(G * self.y, axis=(0, 2, 3))
        self.beta.delta = np.sum(G, axis=(0, 2, 3))
        gamma = self.gamma.value.reshape(1, -1, 1, 1)
        g_mean = np.mean(G, axis=(0, 2, 3), keepdims=True)
        gy_mean = np.mean(G * self.y, axis=(0, 2, 3), keepdims=True)
        return gamma * (G - g_mean - self.y * gy_mean) / np.sqrt(self.b_var + self.eps)