The Best RNN for Image Classification: RNN, LSTM, or GRU?

Recurrent Neural Networks (RNNs) are neural networks designed to make predictions on sequence data. Images are not traditionally treated as sequences, but they can be modeled as one, for example by reading an image row by row. Today we’re going to test how well three RNN architectures (Simple RNNs, LSTMs, and GRUs) perform on image classification using the MNIST digits dataset.

Overview of a Comparison of RNN Architectures for Image Classification

In this post we will go over the following topics:

  • What is the MNIST Digits Dataset?
  • What are Recurrent Neural Networks?
    • Simple Recurrent Neural Networks
    • Long Short-Term Memory (LSTM) Models
    • Gated Recurrent Unit (GRU) Models
  • Keras and TensorFlow for Building Neural Networks
  • Comparing RNN Models on Image Data Classification
    • Image Classification Accuracy with Simple RNNs
    • Image Classification Accuracy with LSTMs
    • Image Classification Accuracy with GRUs
  • RNN vs LSTM vs GRU on Image Classification Summary

What is the MNIST Digits Dataset?

MNIST Dataset, Image from InteliDig

The MNIST Digits Dataset is a set of 70,000 images of handwritten digits: 60,000 for training and 10,000 for testing. Each image is a 28×28 grayscale image labeled with the digit it shows. It is one of the best-known datasets for neural networks and a common benchmark for comparing how well different networks learn. You can find more information about it on the MNIST Datasets Homepage.
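To make the idea of an image as a sequence concrete, here is a minimal pure-Python sketch that reads one 28×28 image row by row, giving 28 timesteps of 28 pixel values each. The flat 784-value input, its row-major layout, and the function name image_to_sequence are assumptions for illustration. In Keras you would not need this step: tf.keras.datasets.mnist.load_data() already returns the training images as an array of shape (60000, 28, 28), which a recurrent layer can read directly as (timesteps, features).

```python
# Hypothetical helper: treat one 28x28 image as a sequence of 28 timesteps,
# where each timestep is one row of 28 pixel intensities.
# Assumption: pixels arrive as a flat list of 784 integers in [0, 255],
# stored row by row (row-major order).

ROWS, COLS = 28, 28


def image_to_sequence(flat_pixels):
    """Split a flat 784-pixel image into 28 rows scaled to [0, 1]."""
    if len(flat_pixels) != ROWS * COLS:
        raise ValueError(f"expected {ROWS * COLS} pixels, got {len(flat_pixels)}")
    return [
        [p / 255.0 for p in flat_pixels[r * COLS:(r + 1) * COLS]]
        for r in range(ROWS)
    ]


# Usage with a dummy all-black image (not real MNIST data):
sequence = image_to_sequence([0] * 784)
print(len(sequence), len(sequence[0]))  # 28 timesteps, 28 features each
```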

What are Recurrent Neural Networks?

Recurrent Neural Networks are neural networks that contain one or more recurrent layers. Traditional neural networks contain only feedforward layers, which means each cell or node in a layer passes its output on to the next layer. A recurrent cell also feeds its output back into itself, once for each element of the input sequence. Thus, just like a recursive function calls itself, a recurrent cell uses its own previous output.

RNNs are well suited to predicting sequence data, in which each output can depend on earlier elements of the sequence rather than on a single data point. This is why RNNs are best known for text data and Natural Language Processing (NLP). Learn more about RNNs through the Core Concepts of NLP.

Simple Recurrent Neural Networks

Simple Recurrent Neural Networks are the basic RNN architecture. The cells or nodes in a simple RNN do not have gates. Each layer is fully connected to the next layer, just like in a traditional neural network. To be classified as a simple recurrent neural network, a network must have at least one recurrent layer and must not contain LSTM or GRU layers. A simple recurrent layer can be added to a Keras model with the layers.SimpleRNN class.
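The recurrence in a simple RNN can be sketched in plain Python. This is a hypothetical, minimal forward pass (no training), assuming the tanh activation that Keras' SimpleRNN uses by default and an all-zero initial state; the names simple_rnn, W, U, and b are illustrative, not a library API. Note there are no gates: each new hidden state is just a squashed sum of the current input and the previous hidden state.

```python
import math


def simple_rnn(sequence, W, U, b):
    """Run a gate-free recurrent layer over a sequence; return the final state.

    Each step computes h_t = tanh(W x_t + U h_{t-1} + b), so the layer's
    output at one step is fed back in at the next step.
    W: units x input_dim, U: units x units, b: units (plain nested lists).
    """
    units = len(b)
    h = [0.0] * units  # initial hidden state is all zeros
    for x in sequence:
        # The comprehension reads the previous h before it is replaced.
        h = [
            math.tanh(
                sum(W[i][j] * x[j] for j in range(len(x)))
                + sum(U[i][k] * h[k] for k in range(units))
                + b[i]
            )
            for i in range(units)
        ]
    return h


# Usage: one unit, one input feature, two timesteps.
h = simple_rnn([[1.0], [0.0]], W=[[1.0]], U=[[0.5]], b=[0.0])
print(h)  # equals [tanh(0.5 * tanh(1.0))]: the first input still matters at step 2
```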

Long Short-Term Memory (LSTM) Models

Long Short-Term Memory (LSTM) models are a variation on the RNN architecture. To classify as an LSTM, a neural network must have at least one LSTM layer. At their core, LSTMs have the same recurrent behavior as simple RNNs, but LSTM nodes are not like regular RNN nodes: they have three gates (the input gate, the output gate, and the forget gate) in addition to the candidate value that a simple RNN cell computes. Because each gate and the candidate need their own set of weights, an LSTM layer has four times as many parameters as a simple RNN layer with the same number of units. The gated cell state helps LSTMs mitigate the vanishing gradient problem that affects simple RNNs. LSTMs can be implemented in Keras via the layers.LSTM class, or with layers.LSTMCell wrapped in a layers.RNN layer.
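A single LSTM step can be sketched in plain Python with one unit and a scalar input, which keeps the gate arithmetic visible. This is a hypothetical illustration, not Keras' implementation; the params layout and the names lstm_step and sigmoid are assumptions. It follows the standard LSTM equations: three sigmoid gates (input, forget, output) plus a tanh candidate value, each with its own input weight, recurrent weight, and bias. Those four weight sets are where the factor of four in parameters comes from.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, params):
    """One step of a single-unit LSTM with a scalar input.

    params maps each of "input", "forget", "output", and "candidate" to a
    (w, u, b) tuple: input weight, recurrent weight, and bias.
    Returns the new hidden state h and cell state c.
    """
    def pre(name):
        w, u, b = params[name]
        return w * x + u * h_prev + b

    i = sigmoid(pre("input"))        # input gate: how much new info to write
    f = sigmoid(pre("forget"))       # forget gate: how much old cell state to keep
    o = sigmoid(pre("output"))       # output gate: how much cell state to expose
    g = math.tanh(pre("candidate"))  # candidate values to write
    c = f * c_prev + i * g           # new cell state (additive update)
    h = o * math.tanh(c)             # new hidden state / output
    return h, c


# Usage: with all-zero weights every gate is 0.5 and the candidate is 0,
# so the cell state simply halves.
zero = {name: (0.0, 0.0, 0.0) for name in ("input", "forget", "output", "candidate")}
h, c = lstm_step(0.2, 0.1, 1.0, zero)
print(h, c)
```

The cell state c is updated additively (f * c_prev + i * g) rather than being pushed through a fresh nonlinearity at every step, which is the usual explanation for why LSTMs cope better with vanishing gradients.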

Gated Recurrent Unit (GRU) Models

Gated Recurrent Units (GRUs) are another variation on the recurrent neural network design. GRU cells are similar to LSTM cells: unlike a cell or node in a traditional neural network, they contain gates. A GRU has two gates, a reset gate and an update gate, where the update gate plays the combined role of the LSTM's input and forget gates. Unlike LSTMs, GRUs have no output gate and no separate cell state. GRUs were introduced in 2014 as an alternative to LSTMs. They show similar performance most of the time but have fewer trainable parameters, and they have been reported to outperform LSTMs on some datasets, particularly smaller ones. GRU layers can be added to a neural network in Keras through the layers.GRU class.
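For comparison, here is a hypothetical single-unit GRU step in the same plain-Python style; the names gru_step and params are illustrative. It follows the original 2014 formulation, where the reset gate scales the previous state before the recurrent weight is applied. Keras' layers.GRU in TensorFlow 2 defaults to reset_after=True, which applies the reset gate after the recurrent multiplication and adds a second bias, so its numbers would differ slightly from this sketch.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(x, h_prev, params):
    """One step of a single-unit GRU with a scalar input.

    params maps "update", "reset", and "candidate" to (w, u, b) tuples:
    input weight, recurrent weight, and bias.
    """
    w_z, u_z, b_z = params["update"]
    w_r, u_r, b_r = params["reset"]
    w_h, u_h, b_h = params["candidate"]

    z = sigmoid(w_z * x + u_z * h_prev + b_z)  # update gate: keep old vs. take new
    r = sigmoid(w_r * x + u_r * h_prev + b_r)  # reset gate: how much history feeds the candidate
    h_tilde = math.tanh(w_h * x + u_h * (r * h_prev) + b_h)  # candidate state
    # Blend old state and candidate; there is no output gate or cell state.
    return z * h_prev + (1.0 - z) * h_tilde


# Usage: with all-zero weights, z = 0.5 and the candidate is 0,
# so the new state is half the old one.
zero = {name: (0.0, 0.0, 0.0) for name in ("update", "reset", "candidate")}
print(gru_step(0.3, 1.0, zero))
```

Three weight sets instead of the LSTM's four is why a GRU layer has roughly three quarters of the parameters of an LSTM layer of the same size.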

Keras and TensorFlow for Building Neural Networks

We’re going to use Keras on TensorFlow to build, train, and test our neural networks. Keras is a high-level API for building neural networks, and TensorFlow is the lower-level framework it runs on. This means we get TensorFlow’s backend while working through the Keras interface. The advantage is that Keras is easier to work with.

Comparing RNN Models on Image Data Classification

Now that we’ve learned a bit about the three best-known RNN types (simple RNNs, LSTMs, and GRUs), let’s see how each one performs on image classification. Each model is built in Keras with the Sequential model structure. They all have three layers: a 64-unit recurrent layer of the architecture being tested, a batch normalization layer, and a dense output layer of 10 units, one per digit. Each neural network is trained for 10 epochs. To see how to train each model, see How to Build a Simple RNN in Keras, How to Build an LSTM in Keras, and How to Build a GRU in Keras.

Image Classification Accuracy with Simple RNNs

A model with a simple RNN layer of 64 units (an output size of 64), followed by a batch normalization layer and a dense output layer of 10 units, has 6858 parameters. It achieves roughly 96% accuracy on the MNIST dataset after 10 epochs.

Image Classification Accuracy with LSTM Models

An LSTM model set up like our simple RNN model, with a 64-unit LSTM layer, a batch normalization layer, and a fully connected output layer, has 24714 parameters. It achieves roughly 96% accuracy after being trained for 10 epochs.

Image Classification Accuracy with GRU Models

A GRU model with the same setup as the LSTM and simple RNN models has 18954 parameters. It achieves roughly 95% accuracy after being trained for 10 epochs.
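The parameter counts above can be reproduced by hand. This sketch assumes each image is fed as 28 timesteps of 28 features (one row per step), a 64-unit recurrent layer, batch normalization, and a 10-unit dense layer; under those assumptions the formulas below give exactly 6858, 24714, and 18954. The GRU figure matches Keras' TensorFlow 2 default of reset_after=True, and the totals include batch normalization's 128 non-trainable moving statistics, as the total reported by Keras' model.summary() does. The helper names are illustrative.

```python
def simple_rnn_params(input_dim, units):
    # kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return units * (input_dim + units + 1)


def lstm_params(input_dim, units):
    # four copies of the simple RNN weights: input, forget, output gates + candidate
    return 4 * simple_rnn_params(input_dim, units)


def gru_params(input_dim, units, reset_after=True):
    # three copies (update gate, reset gate, candidate); with reset_after=True
    # each copy gets a second bias vector for the recurrent part
    extra_bias = 1 if reset_after else 0
    return 3 * units * (input_dim + units + 1 + extra_bias)


def batch_norm_params(units):
    # gamma and beta (trainable) + moving mean and variance (non-trainable)
    return 4 * units


def dense_params(input_dim, units):
    # weight matrix + bias
    return input_dim * units + units


TIMESTEP_FEATURES = 28  # one image row per timestep
UNITS = 64
CLASSES = 10

head = batch_norm_params(UNITS) + dense_params(UNITS, CLASSES)  # shared by all three
totals = {
    "SimpleRNN": simple_rnn_params(TIMESTEP_FEATURES, UNITS) + head,
    "LSTM": lstm_params(TIMESTEP_FEATURES, UNITS) + head,
    "GRU": gru_params(TIMESTEP_FEATURES, UNITS) + head,
}
for name, total in totals.items():
    print(name, total)  # SimpleRNN 6858, LSTM 24714, GRU 18954
```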

RNN vs LSTM vs GRU on Image Classification Summary

Note that none of the three RNN architectures reached its maximum validation accuracy in the 10th epoch. This suggests that we may not need to train them for the full 10 epochs on the MNIST dataset. All three reach similar levels of accuracy, with the GRU model slightly lower.
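Checking which epoch gave the best validation accuracy takes only a few lines. This is a minimal sketch with a hypothetical helper name, best_epoch. It assumes the model was compiled with metrics=["accuracy"] and trained with validation data, in which case Keras' model.fit() returns a History object whose history["val_accuracy"] list holds one value per epoch.

```python
def best_epoch(val_accuracy):
    """Return the 1-based epoch with the highest validation accuracy.

    val_accuracy has one value per epoch, e.g. the list stored under
    history.history["val_accuracy"] after a Keras model.fit() call.
    Ties go to the earliest epoch, since max() keeps the first maximum.
    """
    if not val_accuracy:
        raise ValueError("val_accuracy must not be empty")
    best_index = max(range(len(val_accuracy)), key=lambda i: val_accuracy[i])
    return best_index + 1
```

Rather than fixing the epoch count up front, Keras can also stop automatically: passing callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=2, restore_best_weights=True)] to model.fit() stops training once validation accuracy has not improved for two epochs and restores the best weights. The patience value here is an arbitrary example.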

Learn More

To learn more, feel free to reach out to me @yujian_tang on Twitter, connect with me on LinkedIn, and join our Discord. Remember to follow the blog to stay updated with cool Python projects and ways to level up your Software and Python skills! If you liked this article, please Tweet it, share it on LinkedIn, or tell your friends!


Yujian Tang

I started my professional software career interning for IBM in high school after winning ACSL two years in a row. I got into AI/ML in college where I published a first author paper to IEEE Big Data. After college I worked on the AutoML infrastructure at Amazon before leaving to work in startups. I believe I create the highest quality software content so that’s what I’m doing now. Drop a comment to let me know!

