Using the Keras Functional API to construct a Residual Neural Network

What is a Residual Neural Network?

In principle, neural networks should get better results as they get more layers. A deeper network can learn anything a shallower version of itself can, plus (possibly) more. If, for a given dataset, there is nothing more a network can learn by adding more layers, then it can simply learn the identity mapping for those additional layers. In this way, it preserves the information from the previous layers and cannot do worse than the shallower network. A network should be able to learn at least the identity mapping if it doesn't find something better than that.

But in practice, things are not like that. Deeper networks are harder to optimize. Each extra layer we add makes training more difficult: it becomes harder for the optimization algorithm to find the right parameters. As we add more layers, the network gets better results up to a certain point; beyond that, as we continue to add layers, the accuracy starts to drop.

Residual Networks attempt to solve this issue by adding so-called skip connections. A skip connection is depicted in the image above. As I said previously, deeper networks should be able to learn at least identity mappings, and this is what skip connections provide: they add an identity mapping from one point in the network to a later point, and then let the network learn only the extra 𝐹(𝑥), so that the block outputs 𝐹(𝑥) + 𝑥. If there is nothing more the network can learn, it only has to learn 𝐹(𝑥) = 0. It turns out that it is easier for the network to learn a mapping close to 0 than to learn the identity mapping.

A block with a skip connection as in the image above is called a residual block, and a Residual Neural Network (ResNet) is just a stack of such blocks.
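A residual block outputs 𝐹(𝑥) + 𝑥: the learned branch 𝐹(𝑥) plus the identity skip connection. To make this concrete, here is a toy NumPy sketch (not Keras, and not the article's code) with a hypothetical two-layer branch acting on a plain vector. When the branch outputs 0, the block reduces exactly to the identity mapping.

```python
import numpy as np

def residual_forward(x, w1, w2):
    """Toy residual block on a vector: output = F(x) + x,
    where F(x) = w2 @ relu(w1 @ x) is a hypothetical two-layer residual branch."""
    f = w2 @ np.maximum(w1 @ x, 0.0)  # residual branch F(x)
    return f + x                      # identity skip connection adds x back

rng = np.random.default_rng(0)
x = rng.normal(size=4)

# If the branch learns F(x) = 0 (here: all-zero weights), the block is the identity
zero_w = np.zeros((4, 4))
identity_out = residual_forward(x, zero_w, zero_w)
print(np.allclose(identity_out, x))  # True

# With nonzero weights, the block outputs x plus a learned correction F(x)
w1 = rng.normal(size=(4, 4))
w2 = rng.normal(size=(4, 4))
correction = residual_forward(x, w1, w2) - x
print(correction)
```

Pushing the branch weights toward zero is all it takes to recover the identity, which is the intuition behind why residual blocks make "do no harm" easy to learn.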

An interesting fact is that our brains have structures similar to residual networks; for example, cortical layer VI neurons get input from layer I, skipping intermediary layers.

A short introduction to the Keras Functional API

If you are reading this, you are probably already familiar with the Sequential class, which lets you easily construct a neural network by stacking layers one after another, like this:

from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential([
    Dense(32, input_shape=(784,)),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])

But this way of building neural networks is not sufficient for our needs: with the Sequential class, we can't add skip connections. Keras also has the Model class, which can be used together with the functional API to build more complex network architectures.

Calling keras.layers.Input returns a tensor object. A Keras layer object can also be used as a function: you call it with a tensor as its argument, and it returns a new tensor that can then be passed as input to another layer, and so on.

As an example:

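import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras.utils import to_categorical

inputs = Input(shape=(784,))
output_1 = Dense(64, activation='relu')(inputs)
output_2 = Dense(64, activation='relu')(output_1)
predictions = Dense(10, activation='softmax')(output_2)

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Placeholder data so the snippet runs: 1000 random samples, one-hot labels for 10 classes
data = np.random.random((1000, 784))
labels = to_categorical(np.random.randint(10, size=(1000,)), num_classes=10)
model.fit(data, labels)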

But the above code still constructs a sequential network, so this functional syntax gives us nothing extra yet. Its real use comes with the so-called Merge layers, which combine several input tensors into one. A few examples of these layers are Add, Subtract, Multiply, and Average. The one we need for building residual blocks is Add.
An example that uses Add:

from keras.layers import Input, Dense, Add
from keras.models import Model

input1 = Input(shape=(16,))
x1 = Dense(8, activation='relu')(input1)
input2 = Input(shape=(32,))
x2 = Dense(8, activation='relu')(input2)

added = Add()([x1, x2])

out = Dense(4)(added)
model = Model(inputs=[input1, input2], outputs=out)
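The same Add layer is all a skip connection needs. As a minimal sketch (layer sizes chosen arbitrarily for illustration, not taken from the article), the block below passes a single input through a small Dense branch and adds the original input back. The last Dense layer keeps the width at 16 because Add requires both tensors to have the same shape.

```python
import numpy as np
from keras.layers import Input, Dense, Add
from keras.models import Model

inputs = Input(shape=(16,))
# Residual branch F(x): two Dense layers that keep the width at 16 so shapes match
fx = Dense(16, activation='relu')(inputs)
fx = Dense(16)(fx)
# Skip connection: add the unchanged input back onto F(x)
outputs = Add()([inputs, fx])

model = Model(inputs=inputs, outputs=outputs)
model.summary()

# Run one random sample through the (untrained) model
sample = np.random.random((1, 16)).astype("float32")
print(model.predict(sample, verbose=0).shape)  # (1, 16)
```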

This is by no means a comprehensive guide to the Keras functional API. If you want to learn more, please refer to the docs.

Let’s implement a ResNet

Next, we will implement a ResNet along with its plain (without skip connections) counterpart, for comparison.

The ResNet that we will build here has the following structure:

  • Input with shape (32, 32, 3)
  • 1 Conv2D layer, with 64 filters
  • 2, 5, 5, and 2 residual blocks with 64, 128, 256, and 512 filters, respectively
  • AveragePooling2D layer with pool size = 4
  • Flatten layer
  • Dense layer with 10 output nodes

It has a total of 30 conv+dense layers, not counting the 1×1 convolutions on the downsampling shortcuts described below. All other kernel sizes are 3×3. We use ReLU activation and BatchNormalization after the conv layers.
The plain version is the same except for the skip connections.
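The numbers in the list above fit together as follows. This plain-Python arithmetic sketch assumes that the first residual block of every stage after the first is the one that downsamples (the exact placement is our assumption; three halvings are what make the final 4×4 pooling line up).

```python
# Arithmetic check of the architecture listed above (plain Python, no Keras needed)
input_size = 32
blocks_per_stage = [2, 5, 5, 2]
filters_per_stage = [64, 128, 256, 512]

# 1 initial conv + 2 convs per residual block + 1 final Dense layer
conv_dense_layers = 1 + 2 * sum(blocks_per_stage) + 1
print(conv_dense_layers)  # 30

# Assumption: the first block of every stage after the first uses downsample=True,
# and a stride-2 conv with padding="same" halves the (even) spatial size
sizes = []
size = input_size
for stage, filters in enumerate(filters_per_stage):
    if stage > 0:
        size //= 2
    sizes.append((size, size, filters))
print(sizes)  # [(32, 32, 64), (16, 16, 128), (8, 8, 256), (4, 4, 512)]

# AveragePooling2D with pool size 4 shrinks the final 4x4 map to 1x1,
# so Flatten produces 1 * 1 * 512 features for the Dense layer
pooled = sizes[-1][0] // 4
print(pooled * pooled * filters_per_stage[-1])  # 512
```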

First, we create a helper function that takes a tensor as input and applies ReLU followed by batch normalization to it:

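from tensorflow import Tensor
from keras.layers import ReLU, BatchNormalization

def relu_bn(inputs: Tensor) -> Tensor:
    # Apply ReLU, then batch normalization, to the incoming tensor
    relu = ReLU()(inputs)
    bn = BatchNormalization()(relu)
    return bn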

Then we create a function for constructing a residual block. It takes a tensor x as input and passes it through 2 conv layers, with ReLU and batch normalization in between; let's call the output of these 2 conv layers y. It then adds the input x to y, applies ReLU and batch normalization, and returns the resulting tensor. When downsample == True, the first conv layer uses strides=2 to halve the spatial size of the output, and we apply a conv layer with kernel_size=1 and strides=2 to the input x so that it has the same shape as y (same spatial size and same number of filters). This is needed because the Add layer requires its input tensors to have the same shape.

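from tensorflow import Tensor
from keras.layers import Conv2D, Add

def residual_block(x: Tensor, downsample: bool, filters: int,
                   kernel_size: int = 3) -> Tensor:
    # First conv: strides=2 halves the spatial size when downsampling
    y = Conv2D(kernel_size=kernel_size,
               strides=(1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    # Second conv keeps the spatial size
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    if downsample:
        # 1x1 conv with strides=2 projects x to the same shape as y
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    # Skip connection: add the (possibly projected) input to y
    out = Add()([x, y])
    out = relu_bn(out)
    return out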
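The 1×1, strides=2 shortcut in residual_block() can look opaque, so here is a NumPy sketch (not the Keras implementation) of what it computes on a single (H, W, C) feature map: keep every other row and column, then apply the same linear channel mix at each remaining pixel. For even H and W with padding="same", no padding is added, so the kept pixels start at (0, 0). The kernel and bias values are hypothetical.

```python
import numpy as np

def conv1x1_stride2(x, kernel, bias):
    """Sketch of a 1x1 convolution with strides=2 on an (H, W, C_in) feature map.

    A 1x1 kernel mixes channels at each pixel independently; stride 2 keeps
    every other row and column (pixels (0, 0), (0, 2), ... for even H and W).
    """
    subsampled = x[::2, ::2, :]        # (H/2, W/2, C_in)
    return subsampled @ kernel + bias  # (H/2, W/2, C_out)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 32, 64))    # e.g. a feature map from the 64-filter stage
kernel = rng.normal(size=(64, 128))  # hypothetical 1x1 kernel: 64 -> 128 channels
bias = np.zeros(128)

shortcut = conv1x1_stride2(x, kernel, bias)
print(shortcut.shape)  # (16, 16, 128), matching y in a downsampling block
```

Both the halved spatial size and the new channel count come out of this one cheap operation, which is why it is used to make x compatible with y before the Add.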

The create_res_net() function puts everything together.

Here is the full code for this:
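As a hedged reconstruction, the following sketch assembles the pieces into a create_res_net() function that follows the structure listed earlier. Details the text does not specify are assumptions: which block of each stage downsamples (here, the first block of every stage after the first) and the compile settings (Adam, with sparse categorical cross-entropy for integer CIFAR-10 labels). To stay self-contained, it repeats condensed versions of relu_bn() and residual_block().

```python
from keras.layers import (Input, Conv2D, ReLU, BatchNormalization, Add,
                          AveragePooling2D, Flatten, Dense)
from keras.models import Model


def relu_bn(inputs):
    # ReLU followed by batch normalization, as in the helper above
    return BatchNormalization()(ReLU()(inputs))


def residual_block(x, downsample, filters, kernel_size=3):
    # Condensed version of residual_block() above
    y = Conv2D(filters, kernel_size, strides=2 if downsample else 1, padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(filters, kernel_size, strides=1, padding="same")(y)
    if downsample:
        x = Conv2D(filters, 1, strides=2, padding="same")(x)
    return relu_bn(Add()([x, y]))


def create_res_net():
    # Sketch of the listed architecture; compile settings are assumptions
    inputs = Input(shape=(32, 32, 3))
    t = Conv2D(64, 3, strides=1, padding="same")(inputs)
    t = relu_bn(t)

    num_blocks_list = [2, 5, 5, 2]
    num_filters = 64
    for i, num_blocks in enumerate(num_blocks_list):
        for j in range(num_blocks):
            # Assumption: downsample in the first block of every stage except the first
            t = residual_block(t, downsample=(j == 0 and i != 0), filters=num_filters)
        num_filters *= 2

    t = AveragePooling2D(4)(t)
    t = Flatten()(t)
    outputs = Dense(10, activation="softmax")(t)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


model = create_res_net()
model.summary()
```

The summary should list 29 3×3 convolutions plus one Dense layer (the 30 conv+dense layers), along with three 1×1 shortcut convolutions.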

The plain network is constructed in a similar way, but it doesn’t have skip connections and we don’t use the residual_block() helper function; everything is done inside create_plain_net().

The code for the plain network:
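A minimal sketch of a create_plain_net() counterpart, under the same assumptions as a ResNet version (downsampling in the first block of each later stage, Adam with sparse categorical cross-entropy). Each "block" here is simply two 3×3 convolutions, each followed by ReLU and batch normalization, with no Add and no 1×1 shortcut; that per-block layout is our assumption of what "the same except for the skip connections" means.

```python
from keras.layers import (Input, Conv2D, ReLU, BatchNormalization,
                          AveragePooling2D, Flatten, Dense)
from keras.models import Model


def relu_bn(inputs):
    # ReLU followed by batch normalization
    return BatchNormalization()(ReLU()(inputs))


def create_plain_net():
    # Same layout as the ResNet, but each block is two plain convs with no skip connection
    inputs = Input(shape=(32, 32, 3))
    t = relu_bn(Conv2D(64, 3, padding="same")(inputs))

    num_filters = 64
    for i, num_blocks in enumerate([2, 5, 5, 2]):
        for j in range(num_blocks):
            downsample = (j == 0 and i != 0)  # assumption, as in the ResNet sketch
            t = relu_bn(Conv2D(num_filters, 3, strides=2 if downsample else 1,
                               padding="same")(t))
            t = relu_bn(Conv2D(num_filters, 3, padding="same")(t))
        num_filters *= 2

    t = Flatten()(AveragePooling2D(4)(t))
    outputs = Dense(10, activation="softmax")(t)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


plain_model = create_plain_net()
plain_model.summary()
```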

Training on CIFAR-10 and seeing the results

CIFAR-10 is a dataset of 32×32 RGB images in 10 categories. It contains 50,000 training images and 10,000 test images.

Below is a sample of 10 random images from each class:

We will train both ResNet and PlainNet on this dataset for 20 epochs, and then compare the results.

Training took about 55 minutes for each of ResNet and PlainNet on a machine with one NVIDIA Tesla K80 GPU. There is no significant difference in training time between the two.

The results that we got are shown below.

So, using a ResNet gave us a 1.59% increase in validation accuracy on this dataset. The difference should be larger for deeper networks. Feel free to experiment and see what results you get.


References

  1. Deep Residual Learning for Image Recognition
  2. Residual neural network — Wikipedia
  3. Guide to the Functional API — Keras documentation
  4. Model (functional API) — Keras documentation
  5. Merge Layers — Keras documentation
  6. CIFAR-10 and CIFAR-100 datasets

I hope you found this information useful, and thanks for reading!


This article is also posted on Medium.


Dorian

Passionate about Data Science, AI, Programming & Math
