Load Libraries

library(keras)

VAE Encoder Network

Map each 28-by-28 image to a two-dimensional Gaussian distribution in latent space (latent_dim = 2L), parameterized by a two-dimensional mean (z_mean) and a two-dimensional variance (exp(z_log_var)).

# Shared configuration. `batch_size` is not used by the encoder itself; it is
# read later when decoding latent samples for plotting.
latent_dim <- 2L
batch_size <- 16
img_shape <- c(28, 28, 1)
K <- backend()

# Encoder input followed by a stack of four 3x3 ReLU convolutions; the second
# convolution uses a stride of 2.
input_img <- layer_input(shape = img_shape)
conv_stack <- input_img %>%
  layer_conv_2d(
    filters = 32, kernel_size = 3, padding = "same", activation = "relu"
  ) %>%
  layer_conv_2d(
    filters = 64, kernel_size = 3, padding = "same", activation = "relu",
    strides = c(2, 2)
  ) %>%
  layer_conv_2d(
    filters = 64, kernel_size = 3, padding = "same", activation = "relu"
  ) %>%
  layer_conv_2d(
    filters = 64, kernel_size = 3, padding = "same", activation = "relu"
  )

# Remember the convolutional output shape so the decoder can reshape back to it.
shape_before_flattening <- K$int_shape(conv_stack)

# Flatten the feature maps and project them onto a 32-unit hidden layer.
x <- conv_stack %>%
  layer_flatten() %>%
  layer_dense(units = 32, activation = "relu")

# Two linear heads on the shared hidden layer: one output per latent dimension
# for the mean and for the log-variance.
z_mean <- layer_dense(x, units = latent_dim)

z_log_var <- layer_dense(x, units = latent_dim)

Latent Space Sampling

#' Draw a latent point using the reparameterization trick
#'
#' Computes `z_mean + sd * epsilon`, where `epsilon` is standard-normal noise
#' and `sd = exp(0.5 * z_log_var)`.
#'
#' @param args A list of two tensors, `z_mean` and `z_log_var`, in that order.
#' @return A tensor of sampled latent points.
sampling <- function(args) {
  c(z_mean, z_log_var) %<-% args
  # One noise vector of length `latent_dim` per row of `z_mean`
  # (presumably one row per sample in the batch -- confirm against the caller).
  epsilon <- K$random_normal(shape = list(K$shape(z_mean)[1], latent_dim),
                             mean = 0, stddev = 1)
  # `z_log_var` is the log of the *variance* (the KL term of the loss uses
  # exp(z_log_var) as the variance), so the standard deviation that scales the
  # noise is exp(z_log_var / 2), not exp(z_log_var).
  z_mean + K$exp(0.5 * z_log_var) * epsilon
}

# Wrap the sampling function in a lambda layer fed with both encoder heads.
z <- layer_lambda(list(z_mean, z_log_var), sampling)

VAE Decoder Network

Map a point in the two-dim latent space to a 28-by-28 image.

# The decoder takes a latent vector shaped like `z` (batch dimension dropped).
decoder_input <- layer_input(K$int_shape(z)[-1])

# Shape of the encoder's last convolutional output, without the batch dimension.
feature_map_shape <- shape_before_flattening[-1]

# Expand the latent vector to the size of that feature map, reshape it, then
# upsample with a stride-2 transposed convolution and squash to one sigmoid
# channel.
x <- decoder_input %>%
  layer_dense(
    units = prod(as.integer(feature_map_shape)),
    activation = "relu"
  ) %>%
  layer_reshape(target_shape = feature_map_shape) %>%
  layer_conv_2d_transpose(
    filters = 32, kernel_size = 3, padding = "same", activation = "relu",
    strides = c(2, 2)
  ) %>%
  layer_conv_2d(
    filters = 1, kernel_size = 3, padding = "same", activation = "sigmoid"
  )

# Package the decoder as a standalone model so it can be reused for generation.
decoder <- keras_model(inputs = decoder_input, outputs = x)
summary(decoder)
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_2 (InputLayer)             (None, 2)                     0           
## ___________________________________________________________________________
## dense_4 (Dense)                  (None, 12544)                 37632       
## ___________________________________________________________________________
## reshape_1 (Reshape)              (None, 14, 14, 64)            0           
## ___________________________________________________________________________
## conv2d_transpose_1 (Conv2DTransp (None, 28, 28, 32)            18464       
## ___________________________________________________________________________
## conv2d_5 (Conv2D)                (None, 28, 28, 1)             289         
## ===========================================================================
## Total params: 56,385
## Trainable params: 56,385
## Non-trainable params: 0
## ___________________________________________________________________________
# Apply the decoder model to the sampled latent point.
z_decoded <- z %>% decoder()

Custom Layer (used to compute the VAE loss)

library(R6)
# Pass-through Keras layer whose only job is to register the VAE loss.
#
# `vae_loss()` combines a binary cross-entropy reconstruction term with a
# KL-style regularization term scaled by -5e-4. Note that `z_mean` and
# `z_log_var` are not arguments: they are looked up in the enclosing
# environment when the loss is built.
#
# `call()` expects `inputs` to be a list of (original image, decoded image),
# adds the loss to the layer, and returns the original image unchanged.
CustomVariationalLayer <- R6Class(
  "CustomVariationalLayer",
  inherit = KerasLayer,
  public = list(
    vae_loss = function(x, z_decoded) {
      flat_input <- K$flatten(x)
      flat_output <- K$flatten(z_decoded)
      reconstruction <- metric_binary_crossentropy(flat_input, flat_output)
      kl_inner <- 1 + z_log_var - K$square(z_mean) - K$exp(z_log_var)
      regularizer <- -5e-4 * K$mean(kl_inner, axis = -1L)
      K$mean(reconstruction + regularizer)
    },
    call = function(inputs, mask = NULL) {
      self$add_loss(
        self$vae_loss(inputs[[1]], inputs[[2]]),
        inputs = inputs
      )
      # The layer output is not used for the loss; hand back the first input.
      inputs[[1]]
    }
  )
)

#' Apply a `CustomVariationalLayer` to `object`
#'
#' @param object The object the layer is applied to, forwarded to
#'   `create_layer()`.
#' @return Whatever `create_layer()` returns for `object`.
layer_variational <- function(object) {
  create_layer(
    layer_class = CustomVariationalLayer,
    object = object,
    args = list()
  )
}

# Feed the original image and its reconstruction through the loss layer.
y <- layer_variational(list(input_img, z_decoded))

Train VAE on MNIST data

Specify loss and optimizer

# End-to-end model from input image to the loss layer's output. No external
# loss is passed to compile(): the loss is added inside the custom layer.
vae <- keras_model(input_img, y)
compile(vae, optimizer = "rmsprop", loss = NULL)

# Load MNIST and split it into train/test images and labels.
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Divide pixel values by 255 and append a trailing channel dimension of size 1.
x_train <- array_reshape(x_train / 255, dim = c(dim(x_train), 1))
x_test <- array_reshape(x_test / 255, dim = c(dim(x_test), 1))

# No targets are supplied (`y = NULL`), matching the `loss = NULL` compile call.
fit(
  vae,
  x = x_train,
  y = NULL,
  epochs = 10,
  batch_size = 64,
  validation_data = list(x_test, NULL)
)
# Decode an n-by-n grid of latent points and draw each result as one cell of
# an n-by-n plot layout.
n <- 10
digit_size <- 28

# Grid coordinates are standard-normal quantiles of evenly spaced
# probabilities between 0.01 and 0.99.
grid_x <- qnorm(seq(0.01, 0.99, length.out = n))
grid_y <- qnorm(seq(0.01, 0.99, length.out = n))

# Save the current graphics parameters so they can be restored afterwards.
op <- par(mfrow = c(n, n), mar = c(0, 0, 0, 0), bg = "black")
for (i in seq_along(grid_x)) {
  yi <- grid_x[[i]]
  for (j in seq_along(grid_y)) {
    xi <- grid_y[[j]]
    # Repeat the single latent point (xi, yi) on every row of a
    # batch_size-by-2 matrix so one full batch is passed to predict().
    z_sample <- matrix(c(xi, yi), nrow = batch_size, ncol = 2, byrow = TRUE)
    x_decoded <- decoder %>% predict(z_sample, batch_size = batch_size)
    # All rows of the batch are identical, so only the first image is kept.
    digit <- array_reshape(x_decoded[1, , , ], dim = c(digit_size, digit_size))
    plot(as.raster(digit))
  }
}

par(op)