Recurrent Batch Normalization

by: Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre & Aaron Courville

ArXiv, 2016

This paper investigates the effect of batch normalization in recurrent neural networks. Since it was first proposed in Google's Inception v2 network, batch normalization has become a standard technique for training deep neural networks. However, previous work [2] reported that batch normalization of the hidden-to-hidden transition may hurt the performance of LSTMs. This paper, in contrast with [2], shows that batch normalization can improve the performance of LSTMs when applied properly, and that the scale parameter (gamma) of the batch normalization is crucial for avoiding vanishing gradients.
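
For concreteness, here is a minimal sketch of the batch normalization transform with its learnable scale (gamma) and shift (beta) parameters; the names and shapes are illustrative:

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then rescale by gamma and shift by beta.

    x: array of shape (batch, features); gamma, beta: arrays of shape (features,).
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)   # zero mean, unit variance per feature
    return gamma * x_hat + beta               # learnable scale and shift
```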

Techniques such as dropout and batch normalization have been widely applied in training convolutional neural networks. In theory, one could directly apply dropout and batch normalization at each time step of a recurrent neural network, since a recurrent network is just a very deep feed-forward network with parameters shared across time steps (and is often implemented that way through unrolling). In practice, however, the large depth (number of time steps) of recurrent networks compared with ordinary ConvNets may prohibit naive application of these techniques.

Previous work [1] applies dropout only to the input-to-hidden transitions of recurrent networks, not to the hidden-to-hidden transitions, so that the input is corrupted by a fixed number of dropout operations independent of the number of time steps. Following the same intuition, [2] applies batch normalization only to the input-to-hidden transitions.
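
A minimal sketch of the idea from [1], using hypothetical helper names and a plain tanh RNN step for brevity: dropout corrupts only the input-to-hidden path, while the recurrent state passes through untouched.

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(x, p):
    """Standard inverted dropout: zero out units with probability p, rescale the rest."""
    mask = (rng.random(x.shape) >= p) / (1.0 - p)
    return x * mask

def rnn_step(x_t, h_prev, W_xh, W_hh, b, p_drop=0.5):
    """One recurrent step: dropout is applied to the input projection only, never to
    the hidden-to-hidden path, so the amount of corruption per sequence does not
    grow with the number of time steps."""
    return np.tanh(dropout(x_t, p_drop) @ W_xh + h_prev @ W_hh + b)
```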

This paper, however, applies batch normalization to both the input-to-hidden and hidden-to-hidden transitions of recurrent LSTM networks, as well as to the cell state before the output. In other words, batch normalization is applied everywhere except the cell state update, so that the dynamics of the LSTM cell are preserved.
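
Below is a rough sketch of one step of the batch-normalized LSTM described above (a simplified reading of the paper's Eqns. 6-8, not the authors' code); following the paper, the shifts of the two inner normalizations are dropped because they are redundant with the bias b.

```python
import numpy as np

def bn(x, gamma, eps=1e-5):
    """Per-feature batch normalization with scale gamma and no shift
    (the shift is absorbed into the LSTM bias b)."""
    x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    return gamma * x_hat

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bn_lstm_step(x_t, h_prev, c_prev, W_x, W_h, b, gammas, beta_c):
    """One BN-LSTM step: batch norm on the input-to-hidden and hidden-to-hidden
    projections and on the cell state before the output, but not on the cell update."""
    g_x, g_h, g_c = gammas                       # scale parameters, e.g. all initialized to 0.1
    z = bn(x_t @ W_x, g_x) + bn(h_prev @ W_h, g_h) + b
    i, f, o, g = np.split(z, 4, axis=1)          # input, forget, output gates and candidate
    c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)   # cell update itself is left unnormalized
    h_t = sigmoid(o) * np.tanh(bn(c_t, g_c) + beta_c)     # cell state normalized before the output
    return h_t, c_t
```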

The paper empirically shows that when the scale parameter (gamma) is initialized to 1.0, as is common practice for batch normalization, the gradient vanishes when backpropagating through time. This problem can be fixed by initializing gamma to a smaller value. The saturation of tanh is conjectured to be the main cause of the vanishing gradient, and the paper recommends initializing the scale parameter to 0.1 and the shift parameter to 0. Several experiments then show that the batch-normalized LSTM proposed in this paper outperforms the vanilla LSTM in different scenarios.
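
A small numerical sketch of this conjecture: with roughly unit-variance pre-activations (gamma = 1.0), tanh operates partly in its saturated regime and its average derivative is well below 1, whereas gamma = 0.1 keeps the pre-activations in the near-linear regime.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(1_000_000)   # roughly unit-variance pre-activations after batch norm

for gamma in (1.0, 0.1):
    d = 1.0 - np.tanh(gamma * z) ** 2    # d/du tanh(u) = 1 - tanh(u)^2, evaluated at u = gamma * z
    print(f"gamma = {gamma}: mean derivative of tanh ~ {d.mean():.2f}")
# gamma = 1.0 gives a mean derivative noticeably below 1, while gamma = 0.1 gives one close to 1,
# so backpropagating through many time steps shrinks the gradient far less with the smaller scale.
```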

In its analysis, the paper empirically investigates the vanishing gradient problem in the batch-normalized LSTM and links it to the tanh function. However, it would be better to see a mathematical analysis of the gradient flow. It also seems to me that the normalization of the cell state in the output (the BN in Eqn. 8 of the paper) is especially worth investigating: all of the other batch normalization terms (those in Eqn. 6) can be merged into the weight matrices at test time, but the one in Eqn. 8 cannot, so an extra affine transform has to be introduced into the LSTM at test time. An ablation study of the effect of each individual batch normalization term on the final performance would also be worthwhile.
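
As a minimal numerical sketch of the merging argument (illustrative names and shapes, not the paper's notation): at test time batch normalization uses fixed population statistics, so BN applied to a linear projection is itself an affine function of the input and can be folded into the weights and bias.

```python
import numpy as np

def fold_bn(W, gamma, mu, var, eps=1e-5):
    """Test-time folding: with fixed population statistics (mu, var),
    BN(x @ W; gamma) == x @ W_folded + b_shift for every x,
    so this normalization adds no extra cost at inference."""
    scale = gamma / np.sqrt(var + eps)   # per-output-unit rescaling
    W_folded = W * scale                 # scales the columns of W
    b_shift = -mu * scale                # constant shift absorbed into the bias
    return W_folded, b_shift

# Quick check with random values (shapes are illustrative).
rng = np.random.default_rng(0)
x, W = rng.standard_normal((8, 5)), rng.standard_normal((5, 3))
gamma, mu, var = rng.standard_normal(3), rng.standard_normal(3), rng.random(3) + 0.1
W_f, b_s = fold_bn(W, gamma, mu, var)
bn_out = gamma * (x @ W - mu) / np.sqrt(var + 1e-5)
assert np.allclose(bn_out, x @ W_f + b_s)
```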

In summary, this paper provides useful advice on how batch normalization can be applied to LSTM networks. It would be interesting to see further investigation of the gradient flow in recurrent neural networks with batch normalization.

References
[1] W. Zaremba, I. Sutskever, and O. Vinyals, "Recurrent Neural Network Regularization," arXiv preprint arXiv:1409.2329, 2014.
[2] C. Laurent, G. Pereyra, P. Brakel, Y. Zhang, and Y. Bengio, "Batch Normalized Recurrent Neural Networks," arXiv preprint arXiv:1510.01378, 2015.
