Simple example of a custom training loop

Version 1.00

(C) 2020 - Umberto Michelucci, Michela Sperti

This notebook is part of the book Applied Deep Learning: a case based approach, 2nd edition from APRESS by U. Michelucci and M. Sperti.

The goal of this notebook is to show what a custom training loop looks like.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

print(tf.__version__)
2.7.0

A custom training loop is based on the fundamental capability of TensorFlow (on which Keras is built) to evaluate gradients and derivatives automatically, without you doing any math. Remember that training a network with backpropagation means calculating the gradients of the loss function with respect to the network’s trainable parameters.

In this example, let’s start by seeing how to evaluate the derivative of the function

\[ y=x^2 \]

and let’s try to evaluate it at \(x=3\). If you know calculus, you should quickly see that

\[ \left.\frac{dy}{dx}\right|_{x=3} = 2x\big|_{x=3} = 6 \]

This can easily be done using the tf.GradientTape() context (https://www.tensorflow.org/api_docs/python/tf/GradientTape), as you can see in the cell below.

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
  y = x**2

# dy = 2x * dx
dy_dx = tape.gradient(y, x)

print(dy_dx.numpy())
6.0
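The tape is not restricted to a single variable: you can ask for the gradients with respect to a list of variables in one call, which is exactly what will happen with the weights of a network. Below is a minimal sketch (the function and the values are chosen purely for illustration).

a = tf.Variable(3.0)
b = tf.Variable(2.0)

with tf.GradientTape() as tape:
  f = a**2 + b**3

# df/da = 2a = 6 and df/db = 3b^2 = 12 at (a, b) = (3, 2)
df_da, df_db = tape.gradient(f, [a, b])

print(df_da.numpy(), df_db.numpy())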

The same approach works when you are dealing with a neural network. In the example below you can see how to calculate the gradients of the loss function of a simple neural network with just one dense layer with 2 neurons.

layer = tf.keras.layers.Dense(2, activation='relu')
x = tf.constant([[1., 2., 3.]])

with tf.GradientTape() as tape:
  # Forward pass
  y = layer(x)
  loss = tf.reduce_mean(y**2)

# Calculate gradients with respect to every trainable variable
grad = tape.gradient(loss, layer.trainable_variables)

In the cell below you can see how to retrieve the names of the trainable parameters and the shapes of their gradients.

for var, g in zip(layer.trainable_variables, grad):
  print(f'{var.name}, shape: {g.shape}')
dense/kernel:0, shape: (3, 2)
dense/bias:0, shape: (2,)
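Once you have the variables and their gradients, one step of gradient descent is just a subtraction. The sketch below (the learning rate is an arbitrary illustrative value) does by hand what optimizer.apply_gradients() will do for us later.

learning_rate = 0.01
for var, g in zip(layer.trainable_variables, grad):
  # w <- w - learning_rate * dL/dw
  var.assign_sub(learning_rate * g)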

Custom training loop for a neural network with MNIST data

Now let’s apply this approach to a real network. Let’s consider a network with two dense layers, each having 64 neurons, used to classify MNIST images. If you don’t know what MNIST is, check THIS LINK. TL;DR: MNIST is a dataset composed of 70000 28x28 gray-level images of handwritten digits, with roughly 7000 images for each digit (0 to 9).
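If you want to convince yourself of those numbers, a quick check of the raw arrays is enough (a minimal sketch; load_data() downloads the dataset the first time it is called).

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
print(x_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)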

First things first: let’s define the network

inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_1 (Dense)             (None, 64)                50240     
                                                                 
 dense_2 (Dense)             (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 10)                650       
                                                                 
=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________
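Note that the last layer has no activation function, so the model returns raw logits rather than probabilities. A quick sanity check makes this visible (the random input below is just for illustration).

dummy_input = tf.random.normal((1, 784))
logits = model(dummy_input)
print(logits.numpy())  # 10 unnormalized scores, not probabilities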

Second, we need the typical components of a network model: an optimizer, a loss function, and the dataset. Nothing special to see here.

# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
# 
# Note that when developing a custom training loop you cannot
# use model.evaluate(), therefore you need to track the metrics
# manually.
#
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
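Since the metrics will be updated by hand inside the loop, it is worth seeing how a Keras metric object behaves in isolation. A minimal sketch (the labels and predictions below are made up for illustration):

demo_metric = keras.metrics.SparseCategoricalAccuracy()
demo_metric.update_state([1], [[0.1, 0.9, 0.0]])  # correct prediction
demo_metric.update_state([2], [[0.8, 0.1, 0.1]])  # wrong prediction
print(demo_metric.result().numpy())               # 0.5
demo_metric.reset_states()                        # start fresh, e.g. at a new epoch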

Now, finally, let’s train our small network with a custom training loop. Check the code and the comments and you should immediately see what each component does.

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * 64))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()
Start of epoch 0
Training loss (for one batch) at step 0: 0.5484
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.3190
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.2064
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.2227
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 0.3685
Seen so far: 51264 samples
Training acc over epoch: 0.8809

Start of epoch 1
Training loss (for one batch) at step 0: 0.1432
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.2973
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.2833
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.3132
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 0.4082
Seen so far: 51264 samples
Training acc over epoch: 0.9048

At the end of each epoch you can see the training accuracy that we decided to track. This short notebook should have given you an idea of how to implement a custom training loop with Keras.
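Note that the val_acc_metric defined above was never used in the loop. As a possible extension (a sketch, using the test set as a stand-in validation set), you could evaluate the model at the end of each epoch in the same manual fashion:

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

for x_batch_val, y_batch_val in val_dataset:
    val_logits = model(x_batch_val, training=False)
    val_acc_metric.update_state(y_batch_val, val_logits)

val_acc = val_acc_metric.result()
val_acc_metric.reset_states()
print("Validation acc: %.4f" % (float(val_acc),))

Because the metric state is reset after use, the same object can be reused at the end of every epoch inside the training loop above.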