Reusable Keras code for building and training standard and supervised autoencoders. Examples use the MNIST dataset (784-dimensional inputs, 10 classes).

See Autoencoders for the conceptual background.


Data Preparation

Load and Normalize MNIST

from keras.datasets import mnist
import numpy as np
 
# Load the train and test splits as (images, integer labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every 28x28 image into one 784-wide row, cast to float32, and scale
# by 1/255. The row count is inferred (-1) rather than hard-coded, so the same
# lines keep working if a subset of the data is loaded.
x_train = x_train.reshape(-1, 28 * 28).astype('float32') / 255.
x_test = x_test.reshape(-1, 28 * 28).astype('float32') / 255.

One-Hot Encode Labels

def to_one_hot(y, num_class=10):
    """Return a one-hot (float) matrix for a sequence of integer class labels.

    Args:
        y: Array-like of integer labels, one per sample. A 2D input of shape
            (n, k) marks all k labels of a row, as the original loop did.
        num_class: Number of columns in the result.

    Returns:
        A ``numpy`` array of shape ``(len(y), num_class)`` filled with 0.
        except for a 1. at ``[i, y[i]]``.

    Raises:
        IndexError: If a label is negative, not an integer, or
            ``>= num_class``.
    """
    labels = np.asarray(y)
    results = np.zeros((len(labels), num_class))
    if labels.size == 0:
        # Nothing to mark; also avoids reshaping a zero-size array with -1.
        return results

    cols = labels.reshape(len(labels), -1)
    # Negative indices would silently wrap around to the last classes, so
    # reject them with the same exception type as any other bad label.
    if np.issubdtype(cols.dtype, np.integer) and (cols < 0).any():
        raise IndexError('labels must be in the range [0, num_class)')

    # One vectorized assignment instead of a Python-level loop over samples.
    rows = np.arange(len(labels))[:, None]
    results[rows, cols] = 1.
    return results
 
# One-hot label matrices for the train and test splits, built by the same
# helper with its default number of classes.
y_train_vec, y_test_vec = (to_one_hot(labels) for labels in (y_train, y_test))

Train / Validation Split

# Sizes of the two disjoint subsets drawn from the training data.
n_train, n_valid = 10000, 10000

# Shuffle every available sample index once (length taken from the data, not
# hard-coded), then slice consecutive, non-overlapping ranges for each subset.
# NOTE: the permutation is unseeded, so the split differs from run to run.
rand_indices = np.random.permutation(len(x_train))
train_indices = rand_indices[:n_train]
valid_indices = rand_indices[n_train:n_train + n_valid]

# Index inputs and one-hot labels with the same indices so rows stay aligned.
x_tr = x_train[train_indices, :]
y_tr = y_train_vec[train_indices, :]
x_val = x_train[valid_indices, :]
y_val = y_train_vec[valid_indices, :]

Unsupervised Autoencoder

Build Model

from keras.layers import Input, Dense
from keras import models
 
# Encoder: 784 -> 500 -> 300 -> 100 -> 2, tanh activations throughout.
input_img = Input(shape=(784,), name='input_img')
enc_dense_1 = Dense(units=500, activation='tanh')(input_img)
enc_dense_2 = Dense(units=300, activation='tanh')(enc_dense_1)
enc_dense_3 = Dense(units=100, activation='tanh')(enc_dense_2)
# Two-unit bottleneck: every input is encoded as a single 2D point.
bottleneck = Dense(units=2, activation='tanh')(enc_dense_3)

# Decoder: mirrors the encoder widths back up to 784.
dec_dense_1 = Dense(units=100, activation='tanh')(bottleneck)
dec_dense_2 = Dense(units=300, activation='tanh')(dec_dense_1)
dec_dense_3 = Dense(units=500, activation='tanh')(dec_dense_2)
# Sigmoid keeps each reconstructed value in [0, 1].
dec_dense_4 = Dense(units=784, activation='sigmoid')(dec_dense_3)

# Full autoencoder: input image in, reconstruction out.
ae = models.Model(inputs=input_img, outputs=dec_dense_4)
ae.summary()

Compile and Train

# Take the optimizers from the same `keras` package that provides the layers
# and models in this file. Mixing `keras` and `tensorflow.keras` objects can
# pair two different Keras implementations, depending on the installed versions.
from keras import optimizers

ae.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.RMSprop(learning_rate=0.002, momentum=0.01),
)

# The targets are the inputs themselves: the model is trained to reconstruct
# x_tr, and reconstruction of x_val is tracked as the validation loss.
history = ae.fit(
    x_tr, x_tr,
    batch_size=128,
    epochs=200,
    validation_data=(x_val, x_val),
)

Extract Encoder for Feature Visualization

# Encoder-only model: same input tensor, output taken at the 2-unit bottleneck,
# so it shares its layers with `ae`.
ae_encoder = models.Model(inputs=input_img, outputs=bottleneck)
# One 2D code per test image, i.e. shape (10000, 2) for the full test split.
encoded_test = ae_encoder.predict(x=x_test)

Supervised Autoencoder (SAE)

Build Model (with classifier branch off bottleneck)

from keras.layers import Input, Dense, Dropout
from keras import models
 
# Shared encoder: 784 -> 500 -> 300 -> 100 -> 2, tanh activations throughout.
input_img = Input(shape=(784,), name='input_img')
enc_dense_1 = Dense(units=500, activation='tanh')(input_img)
enc_dense_2 = Dense(units=300, activation='tanh')(enc_dense_1)
enc_dense_3 = Dense(units=100, activation='tanh')(enc_dense_2)
bottleneck = Dense(units=2, activation='tanh')(enc_dense_3)

# Output 1 - decoder: reconstructs the 784-wide input from the bottleneck.
dec_dense_1 = Dense(units=100, activation='tanh')(bottleneck)
dec_dense_2 = Dense(units=300, activation='tanh')(dec_dense_1)
dec_dense_3 = Dense(units=500, activation='tanh')(dec_dense_2)
dec_dense_4 = Dense(units=784, activation='sigmoid')(dec_dense_3)

# Output 2 - classifier: also fed from the bottleneck, with Dropout(0.5) after
# each hidden layer, ending in a 10-way softmax.
cls_1 = Dense(units=64, activation='tanh')(bottleneck)
dropout_1 = Dropout(rate=0.5)(cls_1)
cls_2 = Dense(units=128, activation='tanh')(dropout_1)
dropout_2 = Dropout(rate=0.5)(cls_2)
cls_out = Dense(units=10, activation='softmax')(dropout_2)

# One input, two outputs, in this order: [reconstruction, class probabilities].
sae = models.Model(inputs=input_img, outputs=[dec_dense_4, cls_out])
sae.summary()

Compile with Multiple Losses

# One loss per model output, in the order the outputs were passed to the model:
# reconstruction first, classification second.
sae_losses = ['binary_crossentropy', 'categorical_crossentropy']
# The classification loss is down-weighted to a tenth of the reconstruction
# loss; the original note gives reduced overfitting as the motivation.
sae_loss_weights = [1, 0.1]

sae.compile(
    optimizer=optimizers.RMSprop(learning_rate=0.002, momentum=0.01),
    loss=sae_losses,
    loss_weights=sae_loss_weights,
)

Train with Multiple Targets

# Two targets per sample, matching the two model outputs: the input itself
# (reconstruction) and its one-hot label (classification).
history = sae.fit(
    x=x_tr,
    y=[x_tr, y_tr],
    validation_data=(x_val, [x_val, y_val]),
    epochs=30,
    batch_size=128,
)

Extract SAE Encoder and Evaluate Features

# Encoder-only view of the supervised autoencoder: input -> 2-unit bottleneck.
sae_encoder = models.Model(inputs=input_img, outputs=bottleneck)

# 2D features for each data subset. f_te is computed here but not consumed in
# this section.
f_tr = sae_encoder.predict(x=x_tr)
f_val = sae_encoder.predict(x=x_val)
f_te = sae_encoder.predict(x=x_test)

# Small stand-alone classifier that sees only the 2D features:
# 2 -> 128 -> 128 -> 10 (softmax).
input_feat = Input(shape=(2,))
h1 = Dense(units=128, activation='relu')(input_feat)
h2 = Dense(units=128, activation='relu')(h1)
out = Dense(units=10, activation='softmax')(h2)

clf = models.Model(inputs=input_feat, outputs=out)
clf.compile(
    optimizer=optimizers.RMSprop(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['acc'],
)
# Train on the training-subset features and validate on the validation-subset
# features.
clf.fit(
    x=f_tr,
    y=y_tr,
    validation_data=(f_val, y_val),
    epochs=30,
    batch_size=32,
)

Visualization

Visualize Reconstructions

import matplotlib.pyplot as plt
import numpy as np
 
# Reconstruct the test images and fold each 784-wide row back into 28x28.
# The leading dimension is inferred (-1) instead of hard-coding 10000.
ae_output = ae.predict(x_test).reshape((-1, 28, 28))

# 5x4 grid showing the first 20 reconstructions; zip stops once the axes are
# used up, so no explicit index range is needed.
fig, axes = plt.subplots(nrows=5, ncols=4, figsize=(4, 4))
for ax, image in zip(axes.flat, ae_output):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

Visualize 2D Latent Space (Color-Coded by Class)

# One matplotlib colour per class index 0-9; indexing the palette with the
# integer test labels yields one colour per test point.
colors = np.array(['r', 'g', 'b', 'm', 'c', 'k', 'y', 'purple', 'darkred', 'navy'])
colors_test = colors[y_test]

fig = plt.figure(figsize=(6, 6))
# The two bottleneck coordinates become the scatter plot's x and y.
latent_x, latent_y = encoded_test[:, 0], encoded_test[:, 1]
plt.scatter(latent_x, latent_y, s=10, c=colors_test, edgecolors=colors_test)
plt.axis('off')
plt.tight_layout()
plt.show()

Plot Training / Validation Loss

# Per-epoch loss values recorded by the most recent fit() call that was
# assigned to `history`.
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(loss))

# Training loss as blue dots, validation loss as a red line, on shared axes.
for series, fmt, label in (
    (loss, 'bo', 'Training Loss'),
    (val_loss, 'r', 'Validation Loss'),
):
    plt.plot(epochs, series, fmt, label=label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()