Your first variational autoencoder
Version 1.0
(C) 2020 - Umberto Michelucci, Michela Sperti
This notebook is part of the book Applied Deep Learning: a case based approach, 2nd edition from APRESS by U. Michelucci and M. Sperti.
Performance tip
Enable GPU acceleration in the notebook runtime; training runs much faster on a GPU.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
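Random sampling
The following layer draws the random sample \(z\) using the reparameterization trick: \(z = \mu + \exp\left(\tfrac{1}{2}\log\sigma^2\right)\epsilon\), where \(\mu\) is `z_mean`, \(\log\sigma^2\) is `z_log_var`, and \(\epsilon\) is drawn from a standard normal distribution.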
class Sampling(layers.Layer):
    """Reparameterization-trick layer that turns (z_mean, z_log_var) into a sample z.

    The layer is called with the pair ``[z_mean, z_log_var]`` and returns
    ``z_mean + exp(0.5 * z_log_var) * eps``. Here ``eps`` is fresh
    standard-normal noise whose shape is ``(rows, columns)`` of ``z_mean``.
    Writing the sample this way keeps it differentiable with respect to both
    inputs.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        # exp(0.5 * log(sigma^2)) == sigma, the standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise
Encoder part
# Size of the latent code z; two dimensions so it can be plotted directly.
latent_dim = 2

# Convolutional feature extractor: 28x28x1 -> 14x14x32 -> 7x7x64 -> 16 features.
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding="same", activation="relu")(encoder_inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(units=16, activation="relu")(x)

# Two linear heads (created in this order: z_mean, then z_log_var) give the
# mean and log-variance of the approximate posterior; Sampling draws z from them.
z_mean, z_log_var = (
    layers.Dense(latent_dim, name=head_name)(x) for head_name in ("z_mean", "z_log_var")
)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(inputs=encoder_inputs, outputs=[z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 14, 14, 32) 320 input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 7, 7, 64) 18496 conv2d[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 3136) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 16) 50192 flatten[0][0]
__________________________________________________________________________________________________
z_mean (Dense) (None, 2) 34 dense[0][0]
__________________________________________________________________________________________________
z_log_var (Dense) (None, 2) 34 dense[0][0]
__________________________________________________________________________________________________
sampling (Sampling) (None, 2) 0 z_mean[0][0]
z_log_var[0][0]
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________
Decoder part
# The decoder mirrors the encoder: latent vector -> 7x7x64 feature map
# -> 14x14x64 -> 28x28x32 -> 28x28x1 image.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(units=7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape(target_shape=(7, 7, 64))(x)
x = layers.Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding="same", activation="relu")(x)
# Sigmoid keeps each output pixel in [0, 1], the range expected by the
# binary cross-entropy reconstruction loss used during training.
decoder_outputs = layers.Conv2DTranspose(filters=1, kernel_size=3, padding="same", activation="sigmoid")(x)

decoder = keras.Model(inputs=latent_inputs, outputs=decoder_outputs, name="decoder")
decoder.summary()
Model: "decoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 2)] 0
_________________________________________________________________
dense_1 (Dense) (None, 3136) 9408
_________________________________________________________________
reshape (Reshape) (None, 7, 7, 64) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 64) 36928
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 32) 18464
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1) 289
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
Putting it all together
class VAE(keras.Model):
    """Variational autoencoder that trains a given encoder/decoder pair jointly.

    Args:
        encoder: model mapping an image batch to ``[z_mean, z_log_var, z]``.
        decoder: model mapping a batch of latent vectors ``z`` back to images.
        **kwargs: forwarded unchanged to ``keras.Model``.

    Training minimises the negative ELBO. This is the per-image reconstruction
    cross-entropy plus the KL divergence between the approximate posterior
    N(z_mean, exp(z_log_var)) and the standard normal prior, both averaged
    over the batch.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Encode ``inputs``, sample ``z`` and return the decoded reconstruction."""
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step on a batch and return the loss components.

        Returns:
            dict with ``"loss"``, ``"reconstruction_loss"`` and ``"kl_loss"``.
        """
        # fit() may deliver (x, y) tuples; only the images are needed.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Use the wrapped sub-models, not module-level globals, so a VAE built
            # from a different encoder/decoder actually runs those models.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy reduces the last (channel) axis; summing over the
            # remaining spatial axes gives a per-image total, then average over the
            # batch. Assumes (batch, height, width, channels) images - for 28x28
            # inputs this equals the per-pixel mean times 28 * 28.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            # Closed-form KL(q(z|x) || N(0, I)): sum over latent dimensions, then
            # average over the batch. Averaging over latent dimensions too would
            # down-weight the KL term by a factor of the latent dimensionality.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
Fitting the model
# Unsupervised training: discard the labels and pool train + test images.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Stack both splits, append a channel axis and rescale pixels from [0, 255] to [0, 1].
mnist_digits = np.concatenate((x_train, x_test))[..., np.newaxis].astype("float32") / 255

vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(x=mnist_digits, batch_size=128, epochs=20)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/20
547/547 [==============================] - 4s 7ms/step - loss: 212.1613 - reconstruction_loss: 210.9766 - kl_loss: 1.1847
Epoch 2/20
547/547 [==============================] - 4s 7ms/step - loss: 166.5538 - reconstruction_loss: 163.6692 - kl_loss: 2.8846
Epoch 3/20
547/547 [==============================] - 4s 7ms/step - loss: 159.2968 - reconstruction_loss: 156.2214 - kl_loss: 3.0754
Epoch 4/20
547/547 [==============================] - 4s 7ms/step - loss: 156.6721 - reconstruction_loss: 153.5094 - kl_loss: 3.1627
Epoch 5/20
547/547 [==============================] - 4s 7ms/step - loss: 155.1916 - reconstruction_loss: 151.9821 - kl_loss: 3.2095
Epoch 6/20
547/547 [==============================] - 4s 7ms/step - loss: 154.0873 - reconstruction_loss: 150.8478 - kl_loss: 3.2395
Epoch 7/20
547/547 [==============================] - 4s 7ms/step - loss: 153.1824 - reconstruction_loss: 149.9195 - kl_loss: 3.2629
Epoch 8/20
547/547 [==============================] - 4s 7ms/step - loss: 152.5317 - reconstruction_loss: 149.2627 - kl_loss: 3.2690
Epoch 9/20
547/547 [==============================] - 4s 7ms/step - loss: 151.9381 - reconstruction_loss: 148.6525 - kl_loss: 3.2856
Epoch 10/20
547/547 [==============================] - 4s 6ms/step - loss: 151.4633 - reconstruction_loss: 148.1544 - kl_loss: 3.3089
Epoch 11/20
547/547 [==============================] - 4s 7ms/step - loss: 151.0695 - reconstruction_loss: 147.7615 - kl_loss: 3.3081
Epoch 12/20
547/547 [==============================] - 4s 6ms/step - loss: 150.6613 - reconstruction_loss: 147.3283 - kl_loss: 3.3331
Epoch 13/20
547/547 [==============================] - 4s 7ms/step - loss: 150.3673 - reconstruction_loss: 147.0154 - kl_loss: 3.3518
Epoch 14/20
547/547 [==============================] - 4s 7ms/step - loss: 150.0603 - reconstruction_loss: 146.7028 - kl_loss: 3.3575
Epoch 15/20
547/547 [==============================] - 4s 7ms/step - loss: 149.8645 - reconstruction_loss: 146.4998 - kl_loss: 3.3647
Epoch 16/20
547/547 [==============================] - 4s 7ms/step - loss: 149.5632 - reconstruction_loss: 146.1746 - kl_loss: 3.3886
Epoch 17/20
547/547 [==============================] - 4s 7ms/step - loss: 149.3225 - reconstruction_loss: 145.9330 - kl_loss: 3.3895
Epoch 18/20
547/547 [==============================] - 4s 6ms/step - loss: 149.2139 - reconstruction_loss: 145.8156 - kl_loss: 3.3983
Epoch 19/20
547/547 [==============================] - 4s 7ms/step - loss: 148.9671 - reconstruction_loss: 145.5531 - kl_loss: 3.4140
Epoch 20/20
547/547 [==============================] - 4s 7ms/step - loss: 148.8045 - reconstruction_loss: 145.3841 - kl_loss: 3.4205
<tensorflow.python.keras.callbacks.History at 0x7f4f7018a6a0>
# Decode the single latent point z = (1.1, 0.5) and display the resulting image.
example = np.array([1.1, 0.5]).reshape(1, 2)
x_decoded = decoder.predict(example)
# Take the first (only) image in the batch and drop the trailing channel axis.
digit = x_decoded[0, :, :, 0]
plt.imshow(digit, cmap="Greys_r")
<matplotlib.image.AxesImage at 0x7f4f124da0b8>
data:image/s3,"s3://crabby-images/6127d/6127d0e84d8a484e0387c0f59045a1ef2599030e" alt="../_images/Variational_Autoencoders_12_1.png"
# Decode the single latent point z = (0.9, 0.2) and display the resulting image.
example = np.array([0.9, 0.2]).reshape(1, 2)
x_decoded = decoder.predict(example)
# Take the first (only) image in the batch and drop the trailing channel axis.
digit = x_decoded[0, :, :, 0]
plt.imshow(digit, cmap="Greys_r")
<matplotlib.image.AxesImage at 0x7f7c7f22ac88>
data:image/s3,"s3://crabby-images/43b93/43b9350cb5849db140f7d3ac230a1874361bd2a9" alt="../_images/Variational_Autoencoders_13_1.png"
import matplotlib.pyplot as plt
def plot_latent(encoder, decoder, n=20, scale=1.0, digit_size=28, figsize=15):
    """Display an n x n mosaic of digits decoded from a regular 2-D latent grid.

    Args:
        encoder: unused; kept so existing ``plot_latent(encoder, decoder)``
            calls keep working.
        decoder: model whose ``predict`` maps a (batch, 2) array of latent
            points to images of ``digit_size * digit_size`` pixels each.
        n: number of grid points along each latent axis.
        scale: the grid spans [-scale, scale] on both latent axes.
        digit_size: side length, in pixels, of each decoded image.
        figsize: width and height of the matplotlib figure, in inches.
    """
    # Linearly spaced latent coordinates. grid_y is reversed so the first
    # (top) row of the mosaic corresponds to the largest z[1].
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode every grid point in one batched predict() call instead of n * n
    # separate calls. Row-major order: point i * n + j is (grid_x[j], grid_y[i]).
    zx, zy = np.meshgrid(grid_x, grid_y)
    z_grid = np.column_stack([zx.ravel(), zy.ravel()])
    decoded = np.asarray(decoder.predict(z_grid))

    # (n*n, ...) -> (n, n, h, w) -> (n, h, n, w) -> one (n*h, n*w) mosaic, so that
    # tile (i, j) occupies rows i*h:(i+1)*h and columns j*w:(j+1)*w.
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    canvas = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # Exactly n ticks, one at the centre of each tile, to match the n labels.
    # (The previous range produced n + 1 positions, one outside the image.)
    pixel_range = digit_size // 2 + digit_size * np.arange(n)
    plt.xticks(pixel_range, np.round(grid_x, 1))
    plt.yticks(pixel_range, np.round(grid_y, 1))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(canvas, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)
data:image/s3,"s3://crabby-images/2593a/2593af14d215c11a62c6909d4f9533c678d005a3" alt="../_images/Variational_Autoencoders_14_0.png"