Your first variational autoencoder

Version 1.0

(C) 2020 - Umberto Michelucci, Michela Sperti

This notebook is part of the book Applied Deep Learning: A Case-Based Approach, 2nd edition, published by Apress, by U. Michelucci and M. Sperti.

Performance TIP

Enable GPU acceleration for this notebook to speed up training considerably.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt

Random Sampling

The following custom layer performs the random sampling of the latent vector \(z\) from the Gaussian distribution parametrized by z_mean and z_log_var.

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
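
This layer implements the reparametrization trick: instead of sampling \(z\) directly from \(\mathcal{N}(\mu, \sigma^2)\), which would make the sampling step non-differentiable, it draws \(\epsilon \sim \mathcal{N}(0, I)\) and computes

\[
z = \mu + \exp\!\left(\tfrac{1}{2} \log \sigma^2\right) \epsilon
\]

where \(\mu\) is z_mean and \(\log \sigma^2\) is z_log_var. Gradients can then flow through \(\mu\) and \(\log \sigma^2\) during backpropagation.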

Encoder Part

latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 14, 14, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 7, 64)     18496       conv2d[0][0]                     
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3136)         0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 16)           50192       flatten[0][0]                    
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense[0][0]                      
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense[0][0]                      
__________________________________________________________________________________________________
sampling (Sampling)             (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________
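
As a quick sanity check (a sketch, not part of the original notebook), we can push a dummy batch through the encoder and confirm that all three outputs have shape (batch, latent_dim):

dummy_batch = tf.zeros((4, 28, 28, 1))  # four blank "images"
z_mean_out, z_log_var_out, z_out = encoder(dummy_batch)
print(z_mean_out.shape, z_log_var_out.shape, z_out.shape)  # (4, 2) (4, 2) (4, 2)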

Decoder Part

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1)         289       
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
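
Similarly (again just a sketch), a random latent vector should decode to a single 28 x 28 grayscale image:

z_sample = tf.random.normal((1, latent_dim))  # one random point in latent space
print(decoder(z_sample).shape)  # (1, 28, 28, 1)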

Putting it all together

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        # Keras may pass an (x, y) tuple; we only need the images.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Reconstruction loss: mean binary cross-entropy per pixel,
            # rescaled by 28 * 28 to the per-image sum.
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I).
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
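
The kl_loss expression is the closed-form KL divergence between the approximate posterior \(\mathcal{N}(\mu, \sigma^2)\) and the standard normal prior, which for each latent dimension reads

\[
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right) = -\frac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)
\]

The code averages this quantity over the batch and the latent dimensions and adds it to the rescaled reconstruction loss.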

Fitting the model

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
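
Pixel values are scaled to the [0, 1] range so that they match the decoder's sigmoid output, which is what the binary cross-entropy reconstruction loss expects.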

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=20, batch_size=128)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/20
547/547 [==============================] - 4s 7ms/step - loss: 212.1613 - reconstruction_loss: 210.9766 - kl_loss: 1.1847
Epoch 2/20
547/547 [==============================] - 4s 7ms/step - loss: 166.5538 - reconstruction_loss: 163.6692 - kl_loss: 2.8846
Epoch 3/20
547/547 [==============================] - 4s 7ms/step - loss: 159.2968 - reconstruction_loss: 156.2214 - kl_loss: 3.0754
Epoch 4/20
547/547 [==============================] - 4s 7ms/step - loss: 156.6721 - reconstruction_loss: 153.5094 - kl_loss: 3.1627
Epoch 5/20
547/547 [==============================] - 4s 7ms/step - loss: 155.1916 - reconstruction_loss: 151.9821 - kl_loss: 3.2095
Epoch 6/20
547/547 [==============================] - 4s 7ms/step - loss: 154.0873 - reconstruction_loss: 150.8478 - kl_loss: 3.2395
Epoch 7/20
547/547 [==============================] - 4s 7ms/step - loss: 153.1824 - reconstruction_loss: 149.9195 - kl_loss: 3.2629
Epoch 8/20
547/547 [==============================] - 4s 7ms/step - loss: 152.5317 - reconstruction_loss: 149.2627 - kl_loss: 3.2690
Epoch 9/20
547/547 [==============================] - 4s 7ms/step - loss: 151.9381 - reconstruction_loss: 148.6525 - kl_loss: 3.2856
Epoch 10/20
547/547 [==============================] - 4s 6ms/step - loss: 151.4633 - reconstruction_loss: 148.1544 - kl_loss: 3.3089
Epoch 11/20
547/547 [==============================] - 4s 7ms/step - loss: 151.0695 - reconstruction_loss: 147.7615 - kl_loss: 3.3081
Epoch 12/20
547/547 [==============================] - 4s 6ms/step - loss: 150.6613 - reconstruction_loss: 147.3283 - kl_loss: 3.3331
Epoch 13/20
547/547 [==============================] - 4s 7ms/step - loss: 150.3673 - reconstruction_loss: 147.0154 - kl_loss: 3.3518
Epoch 14/20
547/547 [==============================] - 4s 7ms/step - loss: 150.0603 - reconstruction_loss: 146.7028 - kl_loss: 3.3575
Epoch 15/20
547/547 [==============================] - 4s 7ms/step - loss: 149.8645 - reconstruction_loss: 146.4998 - kl_loss: 3.3647
Epoch 16/20
547/547 [==============================] - 4s 7ms/step - loss: 149.5632 - reconstruction_loss: 146.1746 - kl_loss: 3.3886
Epoch 17/20
547/547 [==============================] - 4s 7ms/step - loss: 149.3225 - reconstruction_loss: 145.9330 - kl_loss: 3.3895
Epoch 18/20
547/547 [==============================] - 4s 6ms/step - loss: 149.2139 - reconstruction_loss: 145.8156 - kl_loss: 3.3983
Epoch 19/20
547/547 [==============================] - 4s 7ms/step - loss: 148.9671 - reconstruction_loss: 145.5531 - kl_loss: 3.4140
Epoch 20/20
547/547 [==============================] - 4s 7ms/step - loss: 148.8045 - reconstruction_loss: 145.3841 - kl_loss: 3.4205
<tensorflow.python.keras.callbacks.History at 0x7f4f7018a6a0>
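
Once training is complete, the decoder alone acts as a generator: any point of the two-dimensional latent space decodes to a 28 x 28 digit image. For example:
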
example = np.array([[1.1, 0.5]])
x_decoded = decoder.predict(example)
digit = x_decoded[0].reshape(28, 28)
plt.imshow(digit, cmap="Greys_r")
(Figure: digit decoded from the latent point (1.1, 0.5).)

example = np.array([[0.9, 0.2]])
x_decoded = decoder.predict(example)
digit = x_decoded[0].reshape(28, 28)
plt.imshow(digit, cmap="Greys_r")
(Figure: digit decoded from the latent point (0.9, 0.2).)

The function below decodes an evenly spaced grid of latent points and tiles the resulting digits into one large image, giving a picture of the 2D manifold the model has learned.


def plot_latent(encoder, decoder):
    # display an n x n 2D manifold of digits
    n = 20
    digit_size = 28
    scale = 1.0
    figsize = 15
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)
(Figure: 20 x 20 grid of digits decoded from an evenly spaced grid of latent points.)
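
A complementary view (a sketch, not part of the original notebook) is to encode the labelled test digits and scatter-plot their z_mean coordinates, which shows how the digit classes arrange themselves in the latent plane:

(_, _), (x_test, y_test) = keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1).astype("float32") / 255
z_mean, _, _ = encoder.predict(x_test)
plt.figure(figsize=(8, 8))
plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test, cmap="tab10", s=2)
plt.colorbar(label="digit class")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()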