Build a Simple Recurrent Neural Network with Keras

Earlier on this blog, we went over How to Build a Recurrent Neural Network from Scratch, How to Build a Neural Network from Scratch in Python 3, and How to Build a Neural Network with Sci-Kit Learn. Continuing the neural network series, this post covers how to build a recurrent neural network (RNN) with the Keras SimpleRNN layer in TensorFlow.

In this post we’ll use Keras and TensorFlow to create a simple RNN, then train and test it on the MNIST dataset. Here are the steps we’ll go through:

  1. Creating a Simple Recurrent Neural Network with Keras
    1. Importing the Right Modules
    2. Adding Layers to Your Model
  2. Training and Testing our RNN on the MNIST Dataset
    1. Load the MNIST dataset
    2. Compile the Recurrent Neural Network
    3. Train and Fit the Model
    4. Test the RNN Model

To follow along, you’ll need to install TensorFlow, which you can do by running the following command in your terminal:

pip install tensorflow 

Creating a Simple Recurrent Neural Network with Keras

Simple RNN Image from GitHub

Keras and TensorFlow make building neural networks much easier than writing them from scratch. The best reason to build a neural network from scratch is to understand how it works; in practice, a library like TensorFlow is the better approach. Let’s take a look at how straightforward it is.

Importing the Right Modules

The first thing we need to do is import the right modules. For this example, we’re working with TensorFlow. We don’t technically need the bottom two imports, but they save typing: when we add layers, we can write layers.SimpleRNN instead of tf.keras.layers.SimpleRNN.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Adding Layers to Your Model

The first thing we’re going to do is set up our model by adding layers. In this example we’ll create a three-layer model. To start, we’ll set up a Sequential model, which is simply a linear stack of layers where each layer passes its output to the next. After setting up the model, we’ll add a SimpleRNN layer with 64 units that expects input of shape (None, 28). Each MNIST image is 28 by 28 pixels, so the RNN reads an image as a sequence of rows: each time step is one row of 28 pixel values, and None means the number of time steps is not fixed in advance. You’ll have to adjust the input_shape parameter based on your dataset.
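To make the (None, 28) input shape concrete, here is a minimal NumPy sketch of the recurrence a SimpleRNN layer computes with its default tanh activation: it reads one 28-pixel row per time step, carries a 64-value hidden state forward, and (by default) returns only the final hidden state. The function name simple_rnn_forward, the weight names W, U, and b, and the random initialization are illustrative assumptions, not Keras internals.

```python
import numpy as np

def simple_rnn_forward(x, W, U, b):
    """Run a tanh RNN over a (timesteps, features) sequence; return the last hidden state.

    x: (timesteps, input_dim), W: (input_dim, units), U: (units, units), b: (units,)
    """
    h = np.zeros(U.shape[0])  # the hidden state starts at all zeros
    for x_t in x:  # one time step per row of the input
        h = np.tanh(x_t @ W + h @ U + b)  # mix the new row with the previous state
    return h

rng = np.random.default_rng(0)
units, input_dim = 64, 28
W = rng.normal(scale=0.1, size=(input_dim, units))  # input-to-hidden weights (hypothetical init)
U = rng.normal(scale=0.1, size=(units, units))      # hidden-to-hidden weights (hypothetical init)
b = np.zeros(units)

image = rng.random((28, 28))  # stand-in for one normalized MNIST image: 28 rows of 28 pixels
h_last = simple_rnn_forward(image, W, U, b)
print(h_last.shape)  # (64,)
```

Because the loop runs over however many rows it is given, the same weights work for any sequence length, which is what the None in the input shape allows.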

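After the SimpleRNN layer, we’ll add a BatchNormalization layer, which normalizes its inputs to roughly zero mean and unit variance. It behaves differently during training and inference: during training it normalizes with the statistics of the current batch and updates running averages of the mean and variance, and at inference it uses those running averages instead. Finally, we’ll add a Dense layer, which is simply a fully connected layer. It has 10 units because there are 10 possible classes (the digits 0 through 9) in MNIST.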
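The training/inference difference is easiest to see in code. The sketch below is a simplified NumPy version of batch normalization, assuming Keras’s default momentum=0.99 and epsilon=1e-3 and its default initial moving statistics (zeros for the mean, ones for the variance); the function batch_norm and the random batch are hypothetical.

```python
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var, training,
               momentum=0.99, epsilon=1e-3):
    """Normalize a (batch, features) array the way a batch-norm layer does.

    Returns the output and the (possibly updated) moving statistics.
    """
    if training:
        # Training: normalize with the current batch's statistics...
        mean, var = x.mean(axis=0), x.var(axis=0)
        # ...and nudge the running averages toward them for later inference.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference: normalize with the running averages from training.
        mean, var = moving_mean, moving_var
    out = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return out, moving_mean, moving_var

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(32, 4))  # a batch of 32 examples, 4 features
gamma, beta = np.ones(4), np.zeros(4)             # learned scale and shift (initial values)
mm, mv = np.zeros(4), np.ones(4)                  # initial moving mean and variance

train_out, mm, mv = batch_norm(x, gamma, beta, mm, mv, training=True)
infer_out, _, _ = batch_norm(x, gamma, beta, mm, mv, training=False)
print(train_out.mean(axis=0).round(3))  # each feature's mean is (numerically) 0
```

After a single update the running averages are still far from the batch statistics, so the inference output is very different from the training output; they only converge after many training steps.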

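model = keras.Sequential()
# Each input is a sequence of any length (None) with 28 features per time step
model.add(layers.SimpleRNN(64, input_shape=(None, 28)))
# Normalizes the 64 values the RNN outputs
model.add(layers.BatchNormalization())
# One raw score (logit) per digit class; no activation here
model.add(layers.Dense(10))
# summary() prints the table itself (and returns None), so no print() is needed
model.summary()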
Keras Simple RNN Model Summary
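The parameter counts in the summary can be checked by hand. This sketch applies the usual per-layer formulas (a SimpleRNN has an input kernel, a recurrent kernel, and a bias; BatchNormalization keeps four vectors per feature; Dense has a weight matrix plus a bias); the helper function names are hypothetical. The totals should match what model.summary() reports for this model: 6,858 parameters, of which the 128 BatchNormalization moving statistics are non-trainable.

```python
def simple_rnn_params(input_dim, units):
    # input kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return input_dim * units + units * units + units

def batch_norm_params(features):
    # gamma and beta are trainable; the moving mean and moving variance are not
    trainable, non_trainable = 2 * features, 2 * features
    return trainable, non_trainable

def dense_params(input_dim, units):
    # weight matrix (input_dim x units) + bias (units)
    return input_dim * units + units

rnn = simple_rnn_params(28, 64)
bn_train, bn_frozen = batch_norm_params(64)
dense = dense_params(64, 10)
total = rnn + bn_train + bn_frozen + dense
print(rnn, bn_train + bn_frozen, dense, total)  # 5952 256 650 6858
```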

Training and Testing our RNN on the MNIST Dataset

At this point, we’ve set up our three-layer RNN with a SimpleRNN layer, a BatchNormalization layer, and a fully connected Dense layer. Now let’s train it on the MNIST dataset.

Load the MNIST dataset

The first thing we’ll do is load the MNIST dataset from Keras. The load_data() function returns the data already split into training and test sets. After loading, we’ll scale both sets by dividing by 255: MNIST images are grayscale with pixel values from 0 to 255, so this maps every pixel into the range [0, 1]. Finally, we’ll set aside one test image and its label to try the model on later.

mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
sample, sample_label = x_test[0], y_test[0]
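Here is a small NumPy sketch of what the scaling step does and how the image shape lines up with the RNN’s input. It uses a hypothetical array of random pixels with MNIST’s dtype and image size instead of the real dataset, so it runs without downloading anything.

```python
import numpy as np

# Hypothetical stand-in for x_train: 5 fake "images" with MNIST's dtype and size.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

scaled = fake_images / 255.0  # uint8 divided by a float gives float64 values in [0, 1]
print(fake_images.dtype, scaled.dtype)            # uint8 float64
print(scaled.min() >= 0.0, scaled.max() <= 1.0)   # True True

# The RNN reads each image as 28 time steps (rows) of 28 features (pixels),
# so a batch has shape (batch, timesteps, features).
first_row = scaled[0, 0]  # the first time step of the first image
print(scaled.shape, first_row.shape)  # (5, 28, 28) (28,)
```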

Compile the Recurrent Neural Network

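Before we train our RNN, we have to compile it. Compiling a model in Keras configures it for training: it sets the loss function to minimize, the optimizer that updates the weights, and the metrics to report. For our example, we’ll use sparse categorical cross-entropy as the loss (with from_logits=True, since our final Dense layer has no activation and outputs raw scores), stochastic gradient descent ("sgd") as the optimizer, and accuracy as the metric.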

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="sgd",
    metrics=["accuracy"],
)
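Because the final Dense layer has no activation, the model outputs raw scores (logits), and from_logits=True tells the loss to apply the softmax itself. The sketch below shows, in plain NumPy, the quantity sparse categorical cross-entropy computes from logits and integer labels; it illustrates the formula rather than Keras’s implementation, and the function name is hypothetical.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy for integer labels, given raw (pre-softmax) scores.

    logits: (batch, classes) floats; labels: (batch,) ints.
    """
    # Stable log-softmax: subtract each row's max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example
    return float(-log_probs[np.arange(len(labels)), labels].mean())

uniform = np.zeros((2, 10))  # a model with no preference among the 10 digits
print(round(sparse_ce_from_logits(uniform, np.array([3, 7])), 4))  # 2.3026, i.e. ln(10)
```

With all-zero logits every class gets probability 1/10, so the loss is ln(10), a useful reference point for the loss early in training.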

Train and Fit the Model

Now that the model is compiled, let’s train it. To train a model in Keras, we call its fit method and pass in the training inputs and labels (x and y), the validation data, the batch_size, and the number of epochs. For this example, we’ll train for just 1 epoch.

model.fit(
    x_train, y_train, validation_data=(x_test, y_test), batch_size=64, epochs=1
)
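The batch_size determines how many weight updates happen per epoch. This arithmetic sketch assumes the standard MNIST split of 60,000 training images; with batch_size=64 the last batch of each epoch is smaller than the rest.

```python
import math

train_size = 60_000  # images in the MNIST training split
batch_size = 64

steps_per_epoch = math.ceil(train_size / batch_size)  # round up to include the partial batch
last_batch = train_size - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 938 32
```

So a single epoch here means 938 gradient steps, which is the step count you should see in Keras’s progress bar during fit.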

Test the RNN Model

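We’ve set up the RNN, compiled it, and trained it. Now let’s test it on the sample we set aside earlier. Since the model expects a batch, we add a batch dimension with tf.expand_dims, pass the result to the model’s predict method, and take the index of the highest score with tf.argmax. Then we print the predicted digit next to the true label.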

result = tf.argmax(model.predict(tf.expand_dims(sample, 0)), axis=1)
print(result.numpy(), sample_label)
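The tf.expand_dims and tf.argmax calls are easier to follow with the shapes spelled out. This NumPy sketch mirrors the test step; the logits array is a hypothetical stand-in for the output of model.predict, not a real prediction.

```python
import numpy as np

sample = np.zeros((28, 28))        # stand-in for one normalized test image
batch = np.expand_dims(sample, 0)  # add a batch axis: (28, 28) -> (1, 28, 28)
print(batch.shape)                 # (1, 28, 28)

# Hypothetical logits standing in for model.predict(batch): one row of 10 scores.
logits = np.array([[0.1, -1.2, 0.3, 0.0, 2.5, -0.7, 0.2, 1.1, -0.4, 0.6]])
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)  # softmax turns logits into probabilities

predicted = np.argmax(logits, axis=1)  # index of the largest score in each row
print(predicted, np.argmax(probs, axis=1))  # [4] [4]
```

Softmax never changes which score is largest, so taking the argmax of the logits directly, as the test code does, gives the same predicted digit as taking the argmax of the probabilities.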

With 10 epochs of training instead of 1 (epochs=10 in the fit call), the Keras SimpleRNN reaches an accuracy of about 96%, which is pretty good.

Keras Simple RNN after 10 Epochs

Build a Simple RNN with Keras Summary

That’s it; that’s all there is to building a simple RNN with Keras and TensorFlow. In this post we went over how to set up a model by adding different layers, specifically the SimpleRNN, BatchNormalization, and Dense layers. Then we loaded the MNIST dataset, compiled the model by passing it a loss function, an optimizer, and metrics to judge it on, fit the model to the data, and ran a test on one sample.


Yujian Tang
