Version 1.0

(C) 2020 - Umberto Michelucci, Michela Sperti

This notebook is part of the book *Applied Deep Learning: A Case-Based Approach*, 2nd edition, published by Apress, by U. Michelucci and M. Sperti.

## Performance tip

Activate GPU acceleration in the notebook to make it go much faster.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt


## Random Sampling

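The following layer draws the random sample $z$ with the reparameterization trick: $z = \mu + \sigma \odot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \exp\left(\tfrac{1}{2}\log\sigma^2\right)$.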

class Sampling(layers.Layer):
    """Sample z, the vector encoding a digit, from (z_mean, z_log_var).

    Implements z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon
    drawn from a standard normal distribution. Keeping the randomness in
    epsilon leaves z differentiable with respect to z_mean and z_log_var.
    """

    def call(self, inputs):
        """Return one random sample of z per input row.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of tensors with the same
                shape.

        Returns:
            A tensor with the same shape and dtype as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Draw the noise with exactly the shape and dtype of z_mean so the
        # layer also works for inputs that are not rank 2 and does not mix
        # float types when z_mean is not in the backend's default float type.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


## Encoder part

# Size of the latent space; 2 so that it can be plotted directly.
latent_dim = 2

# Encoder: maps a 28x28x1 image to (z_mean, z_log_var, z).
encoder_inputs = keras.Input(shape=(28, 28, 1))
# Two stride-2 convolutions reduce the feature map 28x28 -> 14x14 -> 7x7
# (see the output shapes in the model summary).
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
# Two Dense heads (no activation argument passed) produce the mean and the
# log-variance of the latent distribution.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Draw a latent sample z from (z_mean, z_log_var).
z = Sampling()([z_mean, z_log_var])
# The encoder exposes all three tensors: the loss needs z_mean and z_log_var,
# the decoder needs z.
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 28, 28, 1)]  0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 14, 14, 32)   320         input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 7, 64)     18496       conv2d[0][0]
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3136)         0           conv2d_1[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 16)           50192       flatten[0][0]
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense[0][0]
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense[0][0]
__________________________________________________________________________________________________
sampling (Sampling)             (None, 2)            0           z_mean[0][0]
z_log_var[0][0]
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________


## Decoder Part

# Decoder: maps a latent vector of size latent_dim back to a 28x28x1 image.
latent_inputs = keras.Input(shape=(latent_dim,))
# Expand the latent vector to 7*7*64 units and reshape it into a 7x7x64
# feature map, the same size the encoder had before flattening.
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
# Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28
# (see the output shapes in the model summary).
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Final layer has a single channel with a sigmoid, so outputs lie in [0, 1].
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 2)]               0
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              9408
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 64)        36928
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 32)        18464
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1)         289
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________


## Putting it all together

class VAE(keras.Model):
    """Variational autoencoder combining an encoder and a decoder.

    The model overrides ``train_step`` so that ``fit`` optimises the sum of
    a reconstruction loss and a KL-divergence loss. It must be compiled with
    an optimizer before ``fit`` is called, because ``train_step`` uses
    ``self.optimizer``.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: Model mapping ``z`` to a reconstruction of the input.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        """Run one optimisation step on a batch and return the losses.

        Args:
            data: A batch of inputs, or a tuple whose first element is the
                batch (any further elements are ignored).

        Returns:
            A dict with the keys ``"loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"``.
        """
        if isinstance(data, tuple):
            data = data[0]
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            # Use the sub-models owned by this instance, not module globals.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            # Scale the mean per-pixel loss by the number of pixels
            # (28 x 28 inputs).
            reconstruction_loss *= 28 * 28
            # NOTE(review): the KL term is averaged over the latent
            # dimensions as well as the batch; the per-sample sum over
            # latent dimensions is the textbook form. Kept as-is to match
            # the logged losses - confirm the weighting is intended.
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        # Without these two lines the weights are never updated.
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }


## Fitting the model

# Labels are not needed, so train and test images are merged into one set.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
# Add a trailing channel axis and rescale pixel values by 1/255.
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

vae = VAE(encoder, decoder)
# The model must be compiled before fit(): compile() attaches the optimizer.
# No loss is passed because the VAE computes its own loss in train_step.
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=20, batch_size=128)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/20
547/547 [==============================] - 4s 7ms/step - loss: 212.1613 - reconstruction_loss: 210.9766 - kl_loss: 1.1847
Epoch 2/20
547/547 [==============================] - 4s 7ms/step - loss: 166.5538 - reconstruction_loss: 163.6692 - kl_loss: 2.8846
Epoch 3/20
547/547 [==============================] - 4s 7ms/step - loss: 159.2968 - reconstruction_loss: 156.2214 - kl_loss: 3.0754
Epoch 4/20
547/547 [==============================] - 4s 7ms/step - loss: 156.6721 - reconstruction_loss: 153.5094 - kl_loss: 3.1627
Epoch 5/20
547/547 [==============================] - 4s 7ms/step - loss: 155.1916 - reconstruction_loss: 151.9821 - kl_loss: 3.2095
Epoch 6/20
547/547 [==============================] - 4s 7ms/step - loss: 154.0873 - reconstruction_loss: 150.8478 - kl_loss: 3.2395
Epoch 7/20
547/547 [==============================] - 4s 7ms/step - loss: 153.1824 - reconstruction_loss: 149.9195 - kl_loss: 3.2629
Epoch 8/20
547/547 [==============================] - 4s 7ms/step - loss: 152.5317 - reconstruction_loss: 149.2627 - kl_loss: 3.2690
Epoch 9/20
547/547 [==============================] - 4s 7ms/step - loss: 151.9381 - reconstruction_loss: 148.6525 - kl_loss: 3.2856
Epoch 10/20
547/547 [==============================] - 4s 6ms/step - loss: 151.4633 - reconstruction_loss: 148.1544 - kl_loss: 3.3089
Epoch 11/20
547/547 [==============================] - 4s 7ms/step - loss: 151.0695 - reconstruction_loss: 147.7615 - kl_loss: 3.3081
Epoch 12/20
547/547 [==============================] - 4s 6ms/step - loss: 150.6613 - reconstruction_loss: 147.3283 - kl_loss: 3.3331
Epoch 13/20
547/547 [==============================] - 4s 7ms/step - loss: 150.3673 - reconstruction_loss: 147.0154 - kl_loss: 3.3518
Epoch 14/20
547/547 [==============================] - 4s 7ms/step - loss: 150.0603 - reconstruction_loss: 146.7028 - kl_loss: 3.3575
Epoch 15/20
547/547 [==============================] - 4s 7ms/step - loss: 149.8645 - reconstruction_loss: 146.4998 - kl_loss: 3.3647
Epoch 16/20
547/547 [==============================] - 4s 7ms/step - loss: 149.5632 - reconstruction_loss: 146.1746 - kl_loss: 3.3886
Epoch 17/20
547/547 [==============================] - 4s 7ms/step - loss: 149.3225 - reconstruction_loss: 145.9330 - kl_loss: 3.3895
Epoch 18/20
547/547 [==============================] - 4s 6ms/step - loss: 149.2139 - reconstruction_loss: 145.8156 - kl_loss: 3.3983
Epoch 19/20
547/547 [==============================] - 4s 7ms/step - loss: 148.9671 - reconstruction_loss: 145.5531 - kl_loss: 3.4140
Epoch 20/20
547/547 [==============================] - 4s 7ms/step - loss: 148.8045 - reconstruction_loss: 145.3841 - kl_loss: 3.4205

<tensorflow.python.keras.callbacks.History at 0x7f4f7018a6a0>

# Decode a single hand-picked point of the 2D latent space into an image.
example = np.array([[1.1, 0.5]])
x_decoded = decoder.predict(example)
# Drop the batch and channel axes to get a 28x28 array for plotting.
digit = x_decoded[0].reshape(28, 28)
plt.imshow(digit, cmap="Greys_r")

<matplotlib.image.AxesImage at 0x7f4f124da0b8>

# Decode a second latent point for comparison.
example = np.array([[0.9, 0.2]])
x_decoded = decoder.predict(example)
# Drop the batch and channel axes to get a 28x28 array for plotting.
digit = x_decoded[0].reshape(28, 28)
plt.imshow(digit, cmap="Greys_r")

<matplotlib.image.AxesImage at 0x7f7c7f22ac88>

import matplotlib.pyplot as plt

def plot_latent(encoder, decoder, n=20, scale=1.0, digit_size=28, figsize=15):
    """Display an n*n 2D manifold of digits decoded from the latent space.

    A regular grid of points in [-scale, scale] x [-scale, scale] is decoded
    and the resulting images are tiled into one large picture.

    Args:
        encoder: Unused; kept so existing calls keep working.
        decoder: Model mapping a batch of 2D latent points to images that
            can be reshaped to ``(digit_size, digit_size)``.
        n: Number of grid points along each axis.
        scale: Half-width of the latent range sampled on each axis.
        digit_size: Side length, in pixels, of one decoded image.
        figsize: Side length of the (square) matplotlib figure.
    """
    # Linearly spaced coordinates corresponding to the 2D plot of digit
    # classes in the latent space; y is reversed so it grows upwards.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Row i * n + j of z_grid is the latent point (grid_x[j], grid_y[i]).
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_grid = np.stack([xx.ravel(), yy.ravel()], axis=1)

    # Decode the whole grid in one call instead of n * n separate calls.
    decoded = np.asarray(decoder.predict(z_grid))
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    # Reorder to (grid row, pixel row, grid col, pixel col) so that the
    # reshape places tile (i, j) at rows i*digit_size.. and columns
    # j*digit_size.. of the mosaic.
    figure = tiles.transpose(0, 2, 1, 3).reshape(digit_size * n, digit_size * n)

    plt.figure(figsize=(figsize, figsize))
    # One tick at the centre of each tile: exactly n positions, matching the
    # n labels (the previous arange bound produced n + 1 positions).
    pixel_range = digit_size // 2 + digit_size * np.arange(n)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

# Draw the grid of decoded digits using the decoder built above.
plot_latent(encoder, decoder)