Lab: Batch Normalization Layer
What is a batch normalization layer?
It is a layer that normalizes the previous layer's output before it is passed to the activation function. The original paper was published by Sergey Ioffe and Christian Szegedy in 2015.
In the network, it is inserted between a layer's output and the activation function (dense → batch norm → activation).
Why batch normalization?
The distribution of each layer's inputs changes because the weights of the previous layers change as we update them with gradient descent. This is called covariate shift, and it makes training the network difficult.
For example, if the activation is a ReLU and its inputs are shifted below zero, no unit will be activated and no gradient will flow back through the layer!
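As a tiny illustration of this point (the numbers are made up for the example), a pre-activation distribution pushed far below zero leaves a ReLU with nothing to pass through:
import numpy as np

relu = lambda z: np.maximum(z, 0.0)

# Pre-activations shifted well below zero: every unit outputs 0,
# so no signal (and no gradient) passes through the ReLU.
z = np.array([-2.5, -1.3, -4.0, -0.7, -3.2])
print(relu(z))  # [0. 0. 0. 0. 0.]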
One thing also worth mentioning is that the $\gamma$ and $\beta$ parameters in $$ y = \gamma \hat{x} + \beta $$ are also trainable.
This means that if we don't need the batch normalization, its parameters can be updated in such a way that they offset the normalization step.
For example, assume that
\begin{align}
\gamma &= \sqrt{\sigma^2_B + \epsilon} \\
\beta &= \mu_B
\end{align}
then
$$ y_i = \gamma \hat{x}_i + \beta = x_i $$
Also note that $\mu$ and $\sigma$ are accumulated as moving averages during training. At test time, however, these stored $\mu$ and $\sigma$ are used as fixed values.
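To make these formulas concrete, here is a minimal NumPy sketch of what a batch normalization layer computes (the function name, `momentum`, and `epsilon` values are illustrative assumptions, not the TensorFlow implementation used later): at training time it normalizes with the batch statistics and updates the moving averages, at test time it reuses the stored moving averages. It also checks the identity case above.
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var,
               training, momentum=0.99, epsilon=1e-3):
    """Toy batch normalization over the batch axis of a 2-D input."""
    if training:
        mu = x.mean(axis=0)                  # batch mean
        var = x.var(axis=0)                  # batch variance
        # Update the moving averages that will be used at test time
        moving_mean[:] = momentum * moving_mean + (1 - momentum) * mu
        moving_var[:] = momentum * moving_var + (1 - momentum) * var
    else:
        mu, var = moving_mean, moving_var    # fixed statistics at test time
    x_hat = (x - mu) / np.sqrt(var + epsilon)
    return gamma * x_hat + beta              # scale and shift

# If gamma = sqrt(var + eps) and beta = mu, the layer reduces to the identity
x = np.random.randn(128, 4)
y = batch_norm(x,
               gamma=np.sqrt(x.var(axis=0) + 1e-3),
               beta=x.mean(axis=0),
               moving_mean=np.zeros(4), moving_var=np.ones(4),
               training=True)
print(np.allclose(y, x))  # True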
Conclusion
- Always use batch normalization!
Enough talk: how to implement it in TensorFlow
1. Load Libraries
- We use the famous MNIST dataset
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
mnist.train.images.shape
(55000, 784)
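For reference, `read_data_sets` is deprecated in recent TensorFlow releases. A roughly equivalent loading path through `tf.keras.datasets` might look like the sketch below; the reshaping and one-hot code are my own assumptions, and note that it returns the full 60,000-image training set rather than the 55,000/5,000 train/validation split used above. The rest of this lab keeps using the `mnist` object already loaded.
import numpy as np
import tensorflow as tf

# Hypothetical alternative to input_data.read_data_sets (not used in the rest of this lab)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten to [N, 784] and scale to [0, 1] to match mnist.train.images
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0

# One-hot encode the labels to match one_hot=True above
y_train = np.eye(10, dtype=np.float32)[y_train]
y_test = np.eye(10, dtype=np.float32)[y_test]

print(x_train.shape, y_train.shape)  # (60000, 784) (60000, 10)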
2. Define Model & Solver Classes
- Object-oriented programming allows us to define multiple models easily
- Why do we separate the model and solver classes?
  - We can simply swap out the model class in the Solver class when we need a different network architecture
  - Usually we only need one solver class
class Model:
    """Network Model Class

    Note that this class has only the constructor.
    The actual model is defined inside the constructor.

    Attributes
    ----------
    X : tf.float32
        This is a tensorflow placeholder for MNIST images
        Expected shape is [None, 784]
    y : tf.float32
        This is a tensorflow placeholder for MNIST labels (one hot encoded)
        Expected shape is [None, 10]
    mode : tf.bool
        This is used for the batch normalization
        It's `True` at training time and `False` at test time
    loss : tf.float32
        The loss function is a softmax cross entropy
    train_op
        This is simply the training op that minimizes the loss
    accuracy : tf.float32
        The accuracy operation

    Examples
    ----------
    >>> model = Model("Batch Norm", 32, 10)
    """
    def __init__(self, name, input_dim, output_dim, hidden_dims=[32, 32], use_batchnorm=True, activation_fn=tf.nn.relu, optimizer=tf.train.AdamOptimizer, lr=0.01):
        """ Constructor

        Parameters
        --------
        name : str
            The name of this network
            The entire network will be created under `tf.variable_scope(name)`
        input_dim : int
            The input dimension
            In this example, 784
        output_dim : int
            The number of output labels
            There are 10 labels
        hidden_dims : list (default: [32, 32])
            len(hidden_dims) = number of layers
            each element is the number of hidden units
        use_batchnorm : bool (default: True)
            If true, it will create the batchnormalization layer
        activation_fn : TF functions (default: tf.nn.relu)
            Activation Function
        optimizer : TF optimizer (default: tf.train.AdamOptimizer)
            Optimizer Function
        lr : float (default: 0.01)
            Learning rate
        """
        with tf.variable_scope(name):
            # Placeholders are defined
            self.X = tf.placeholder(tf.float32, [None, input_dim], name='X')
            self.y = tf.placeholder(tf.float32, [None, output_dim], name='y')
            self.mode = tf.placeholder(tf.bool, name='train_mode')

            # Loop over hidden layers
            net = self.X
            for i, h_dim in enumerate(hidden_dims):
                with tf.variable_scope('layer{}'.format(i)):
                    net = tf.layers.dense(net, h_dim)
                    if use_batchnorm:
                        net = tf.layers.batch_normalization(net, training=self.mode)
                    net = activation_fn(net)

            # Attach fully connected layers
            net = tf.contrib.layers.flatten(net)
            net = tf.layers.dense(net, output_dim)

            self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=net, labels=self.y)
            self.loss = tf.reduce_mean(self.loss, name='loss')
            # When using batch normalization layers,
            # it is necessary to manually add their update operations
            # as dependencies of the train op, because the moving-average
            # updates are not run automatically otherwise
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=name)
            with tf.control_dependencies(update_ops):
                self.train_op = optimizer(lr).minimize(self.loss)

            # Accuracy etc
            softmax = tf.nn.softmax(net, name='softmax')
            self.accuracy = tf.equal(tf.argmax(softmax, 1), tf.argmax(self.y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.accuracy, tf.float32))
class Solver:
    """Solver class

    This class will contain the model class and session

    Attributes
    ----------
    model : Model class
    sess : TF session

    Methods
    ----------
    train(X, y)
        Runs the train_op and returns the loss
    evaluate(X, y, batch_size=None)
        Returns "Loss" and "Accuracy"
        If batch_size is given, the evaluation runs in chunks of batch_size
        because most GPU memory cannot hold the entire training data at once

    Example
    ----------
    >>> sess = tf.InteractiveSession()
    >>> model = Model("BatchNorm", 32, 10)
    >>> solver = Solver(sess, model)

    # Train
    >>> solver.train(X, y)

    # Evaluate
    >>> solver.evaluate(X, y)
    """
    def __init__(self, sess, model):
        self.model = model
        self.sess = sess

    def train(self, X, y):
        feed = {
            self.model.X: X,
            self.model.y: y,
            self.model.mode: True
        }
        train_op = self.model.train_op
        loss = self.model.loss

        return self.sess.run([train_op, loss], feed_dict=feed)

    def evaluate(self, X, y, batch_size=None):
        if batch_size:
            N = X.shape[0]

            total_loss = 0
            total_acc = 0

            for i in range(0, N, batch_size):
                X_batch = X[i:i + batch_size]
                y_batch = y[i:i + batch_size]

                feed = {
                    self.model.X: X_batch,
                    self.model.y: y_batch,
                    self.model.mode: False
                }

                loss = self.model.loss
                accuracy = self.model.accuracy

                step_loss, step_acc = self.sess.run([loss, accuracy], feed_dict=feed)

                total_loss += step_loss * X_batch.shape[0]
                total_acc += step_acc * X_batch.shape[0]

            total_loss /= N
            total_acc /= N

            return total_loss, total_acc
        else:
            feed = {
                self.model.X: X,
                self.model.y: y,
                self.model.mode: False
            }

            loss = self.model.loss
            accuracy = self.model.accuracy

            return self.sess.run([loss, accuracy], feed_dict=feed)
3. Instantiate Model/Solver classes
input_dim = 784
output_dim = 10
N = 55000
tf.reset_default_graph()
sess = tf.InteractiveSession()
# We create two models: one with batch norm and the other without
bn = Model('batchnorm', input_dim, output_dim, use_batchnorm=True)
nn = Model('no_norm', input_dim, output_dim, use_batchnorm=False)
# We create two solvers: to train both models at the same time for comparison
# Usually we only need one solver class
bn_solver = Solver(sess, bn)
nn_solver = Solver(sess, nn)
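As a quick sanity check of the `UPDATE_OPS` handling in the `Model` class, one can count the update operations that `tf.layers.batch_normalization` registered under each scope; only the batch-norm model should have any. (The expected count of 4, two per batch-norm layer, is an assumption and may vary across TensorFlow versions.)
# Only the 'batchnorm' model registers moving mean/variance update ops
print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='batchnorm')))  # expected: 4 (2 per BN layer)
print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='no_norm')))    # expected: 0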
epoch_n = 10
batch_size = 32
# Save Losses and Accuracies every epoch
# We are going to plot them later
train_losses = []
train_accs = []
valid_losses = []
valid_accs = []
4. Run the train step
init = tf.global_variables_initializer()
sess.run(init)
for epoch in range(epoch_n):
    for _ in range(N//batch_size):
        X_batch, y_batch = mnist.train.next_batch(batch_size)

        _, bn_loss = bn_solver.train(X_batch, y_batch)
        _, nn_loss = nn_solver.train(X_batch, y_batch)

    b_loss, b_acc = bn_solver.evaluate(mnist.train.images, mnist.train.labels, batch_size)
    n_loss, n_acc = nn_solver.evaluate(mnist.train.images, mnist.train.labels, batch_size)

    # Save train losses/acc
    train_losses.append([b_loss, n_loss])
    train_accs.append([b_acc, n_acc])
    print(f'[Epoch {epoch}-TRAIN] Batchnorm Loss(Acc): {b_loss:.5f}({b_acc:.2%}) vs No Batchnorm Loss(Acc): {n_loss:.5f}({n_acc:.2%})')

    b_loss, b_acc = bn_solver.evaluate(mnist.validation.images, mnist.validation.labels)
    n_loss, n_acc = nn_solver.evaluate(mnist.validation.images, mnist.validation.labels)

    # Save valid losses/acc
    valid_losses.append([b_loss, n_loss])
    valid_accs.append([b_acc, n_acc])
    print(f'[Epoch {epoch}-VALID] Batchnorm Loss(Acc): {b_loss:.5f}({b_acc:.2%}) vs No Batchnorm Loss(Acc): {n_loss:.5f}({n_acc:.2%})')
    print()
[Epoch 0-TRAIN] Batchnorm Loss(Acc): 0.15541(95.31%) vs No Batchnorm Loss(Acc): 0.20668(93.89%)
[Epoch 0-VALID] Batchnorm Loss(Acc): 0.17589(94.60%) vs No Batchnorm Loss(Acc): 0.21615(93.52%)
[Epoch 1-TRAIN] Batchnorm Loss(Acc): 0.10887(96.68%) vs No Batchnorm Loss(Acc): 0.16064(95.04%)
[Epoch 1-VALID] Batchnorm Loss(Acc): 0.12936(96.22%) vs No Batchnorm Loss(Acc): 0.17713(94.68%)
[Epoch 2-TRAIN] Batchnorm Loss(Acc): 0.10877(96.53%) vs No Batchnorm Loss(Acc): 0.16177(95.23%)
[Epoch 2-VALID] Batchnorm Loss(Acc): 0.13271(96.28%) vs No Batchnorm Loss(Acc): 0.17715(94.72%)
[Epoch 3-TRAIN] Batchnorm Loss(Acc): 0.07497(97.68%) vs No Batchnorm Loss(Acc): 0.15190(95.52%)
[Epoch 3-VALID] Batchnorm Loss(Acc): 0.10931(96.80%) vs No Batchnorm Loss(Acc): 0.18461(95.24%)
[Epoch 4-TRAIN] Batchnorm Loss(Acc): 0.07478(97.68%) vs No Batchnorm Loss(Acc): 0.14759(95.91%)
[Epoch 4-VALID] Batchnorm Loss(Acc): 0.10948(96.62%) vs No Batchnorm Loss(Acc): 0.17635(95.24%)
[Epoch 5-TRAIN] Batchnorm Loss(Acc): 0.05865(98.09%) vs No Batchnorm Loss(Acc): 0.12529(96.31%)
[Epoch 5-VALID] Batchnorm Loss(Acc): 0.09065(97.12%) vs No Batchnorm Loss(Acc): 0.16717(95.34%)
[Epoch 6-TRAIN] Batchnorm Loss(Acc): 0.05874(98.15%) vs No Batchnorm Loss(Acc): 0.15819(95.55%)
[Epoch 6-VALID] Batchnorm Loss(Acc): 0.09372(97.24%) vs No Batchnorm Loss(Acc): 0.19886(95.12%)
[Epoch 7-TRAIN] Batchnorm Loss(Acc): 0.04970(98.40%) vs No Batchnorm Loss(Acc): 0.11202(96.88%)
[Epoch 7-VALID] Batchnorm Loss(Acc): 0.09236(97.20%) vs No Batchnorm Loss(Acc): 0.17182(95.82%)
[Epoch 8-TRAIN] Batchnorm Loss(Acc): 0.04792(98.50%) vs No Batchnorm Loss(Acc): 0.12621(96.40%)
[Epoch 8-VALID] Batchnorm Loss(Acc): 0.09268(97.24%) vs No Batchnorm Loss(Acc): 0.18629(95.38%)
[Epoch 9-TRAIN] Batchnorm Loss(Acc): 0.05247(98.30%) vs No Batchnorm Loss(Acc): 0.16342(95.51%)
[Epoch 9-VALID] Batchnorm Loss(Acc): 0.10053(97.20%) vs No Batchnorm Loss(Acc): 0.23571(94.26%)
5. Performance Comparison
- With batch normalization, the loss is lower and the accuracy is higher!
bn_solver.evaluate(mnist.test.images, mnist.test.labels)
[0.105519876, 0.9713]
nn_solver.evaluate(mnist.test.images, mnist.test.labels)
[0.23670065, 0.9449]
def plot_compare(loss_list: list, ylim=None, title=None) -> None:
    bn = [i[0] for i in loss_list]
    nn = [i[1] for i in loss_list]

    plt.figure(figsize=(15, 10))
    plt.plot(bn, label='With BN')
    plt.plot(nn, label='Without BN')
    if ylim:
        plt.ylim(ylim)
    if title:
        plt.title(title)
    plt.legend()
    plt.grid(True)  # plt.grid('on') is deprecated in newer matplotlib versions
    plt.show()
plot_compare(train_losses, title='Training Loss at Epoch')
plot_compare(train_accs, [0, 1.0], title="Training Acc at Epoch")
plot_compare(valid_losses, title='Validation Loss at Epoch')
plot_compare(valid_accs, [0, 1.], title='Validation Acc at Epoch')