library(keras)
Map each image, 28-by-28, to a two-dimensional Gaussian distribution (latent_dim = 2L) with two-dim mean (z_mean) and two-dim variance (exp(z_log_var)).
# Handle to the Keras backend, used below for low-level tensor operations.
K <- backend()

# Input image shape (height, width, channels), the batch size used when
# tiling latent samples for prediction, and the latent-space dimension.
img_shape <- c(28, 28, 1)
batch_size <- 16
latent_dim <- 2L

# Encoder: a stack of four 3x3 ReLU convolutions applied to the input image.
# The second convolution uses a stride of 2 in both spatial dimensions.
input_img <- layer_input(shape = img_shape)

x <- layer_conv_2d(input_img, filters = 32, kernel_size = 3,
                   padding = "same", activation = "relu")
x <- layer_conv_2d(x, filters = 64, kernel_size = 3,
                   padding = "same", activation = "relu",
                   strides = c(2, 2))
x <- layer_conv_2d(x, filters = 64, kernel_size = 3,
                   padding = "same", activation = "relu")
x <- layer_conv_2d(x, filters = 64, kernel_size = 3,
                   padding = "same", activation = "relu")

# Remember the feature-map shape so the decoder can undo the flattening.
shape_before_flattening <- K$int_shape(x)

x <- layer_flatten(x)
x <- layer_dense(x, units = 32, activation = "relu")

# Two parallel dense heads on the shared 32-unit representation: one emits
# z_mean, the other z_log_var, each of size latent_dim.
z_mean <- layer_dense(x, units = latent_dim)
z_log_var <- layer_dense(x, units = latent_dim)
# Draw a latent point z from the Gaussian described by the encoder outputs
# (the reparameterization trick).
#
# args: a list of two tensors, list(z_mean, z_log_var), where z_log_var is
#       the log of the variance of the latent distribution.
# Returns a tensor z = z_mean + sigma * epsilon with epsilon ~ N(0, 1).
sampling <- function(args) {
  c(z_mean, z_log_var) %<-% args
  # One standard-normal draw per sample in the batch and per latent dimension.
  epsilon <- K$random_normal(
    shape = list(K$shape(z_mean)[1], latent_dim),
    mean = 0, stddev = 1
  )
  # z_log_var = log(sigma^2), so the standard deviation is
  # exp(0.5 * z_log_var). Using exp(z_log_var) here would scale the noise by
  # the variance and disagree with the KL term, which treats z_log_var as a
  # log-variance.
  z_mean + K$exp(0.5 * z_log_var) * epsilon
}
# Wrap the sampling function in a lambda layer so that drawing z becomes part
# of the model graph; the layer receives both encoder outputs as a list.
z <- layer_lambda(list(z_mean, z_log_var), sampling)
Map a point in the two-dim latent space to a 28-by-28 image.
# Decoder input has the shape of z without its leading (batch) dimension.
decoder_input <- layer_input(shape = K$int_shape(z)[-1])

# Expand the latent vector to as many units as the encoder feature map had
# before flattening, then reshape it back into that feature-map shape.
x <- layer_dense(decoder_input,
                 units = prod(as.integer(shape_before_flattening[-1])),
                 activation = "relu")
x <- layer_reshape(x, target_shape = shape_before_flattening[-1])

# A stride-2 transposed convolution followed by a single-filter sigmoid
# convolution that produces the output image.
x <- layer_conv_2d_transpose(x, filters = 32, kernel_size = 3,
                             padding = "same", activation = "relu",
                             strides = c(2, 2))
x <- layer_conv_2d(x, filters = 1, kernel_size = 3,
                   padding = "same", activation = "sigmoid")

# Stand-alone decoder model, reusable on arbitrary latent points.
decoder <- keras_model(inputs = decoder_input, outputs = x)
summary(decoder)
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## input_2 (InputLayer) (None, 2) 0
## ___________________________________________________________________________
## dense_4 (Dense) (None, 12544) 37632
## ___________________________________________________________________________
## reshape_1 (Reshape) (None, 14, 14, 64) 0
## ___________________________________________________________________________
## conv2d_transpose_1 (Conv2DTransp (None, 28, 28, 32) 18464
## ___________________________________________________________________________
## conv2d_5 (Conv2D) (None, 28, 28, 1) 289
## ===========================================================================
## Total params: 56,385
## Trainable params: 56,385
## Non-trainable params: 0
## ___________________________________________________________________________
# Apply the decoder model to the sampled latent tensor z.
z_decoded <- z %>% decoder()
library(R6)
# Custom Keras layer that attaches the VAE loss to the model through
# add_loss() and passes its first input through unchanged.
CustomVariationalLayer <- R6Class(
  "CustomVariationalLayer",
  inherit = KerasLayer,
  public = list(
    # Combined loss: binary cross-entropy between the flattened input and the
    # flattened reconstruction, plus a KL-style regularization term weighted
    # by -5e-4. z_mean and z_log_var are not arguments; they are looked up in
    # the enclosing environment.
    vae_loss = function(x, z_decoded) {
      flat_input <- K$flatten(x)
      flat_output <- K$flatten(z_decoded)
      reconstruction <- metric_binary_crossentropy(flat_input, flat_output)
      kl_inner <- 1 + z_log_var - K$square(z_mean) - K$exp(z_log_var)
      regularization <- -5e-4 * K$mean(kl_inner, axis = -1L)
      K$mean(reconstruction + regularization)
    },
    # inputs is a list: element 1 is the original image tensor, element 2 the
    # decoded tensor. The loss is registered on the layer and the original
    # image tensor is returned as the layer output.
    call = function(inputs, mask = NULL) {
      original <- inputs[[1]]
      reconstructed <- inputs[[2]]
      self$add_loss(self$vae_loss(original, reconstructed), inputs = inputs)
      original
    }
  )
)
# Pipe-friendly constructor for CustomVariationalLayer.
#
# object: what the layer is applied to (here, a list of two tensors).
# Returns whatever create_layer() returns for that object.
layer_variational <- function(object) {
  create_layer(
    layer_class = CustomVariationalLayer,
    object = object,
    args = list()
  )
}
# Feed the original image and its reconstruction to the custom layer, which
# registers the loss and returns the model output.
y <- layer_variational(list(input_img, z_decoded))
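Specify the loss and optimizer. The loss is already added inside the custom layer, so no external loss is passed to compile (loss = NULL).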
# End-to-end model from the input image to the custom layer's output.
vae <- keras_model(inputs = input_img, outputs = y)

# No loss is supplied here (loss = NULL); only the optimizer is configured.
compile(vae, optimizer = "rmsprop", loss = NULL)
# Load MNIST and pull out the train/test images and labels.
mnist <- dataset_mnist()

y_train <- mnist$train$y
y_test <- mnist$test$y

# Divide pixel values by 255 and append a trailing channel dimension of
# size 1 to each image array.
x_train <- array_reshape(
  mnist$train$x / 255,
  dim = c(dim(mnist$train$x), 1)
)
x_test <- array_reshape(
  mnist$test$x / 255,
  dim = c(dim(mnist$test$x), 1)
)
# Train the VAE. No targets are passed (y = NULL), and the test images are
# used as validation data, likewise without targets.
fit(
  vae,
  x = x_train,
  y = NULL,
  epochs = 10,
  batch_size = 64,
  validation_data = list(x_test, NULL)
)
# Decode an n-by-n grid of latent points and draw the resulting digits.
n <- 10
digit_size <- 28

# Grid coordinates are standard-normal quantiles at evenly spaced
# probabilities between 0.01 and 0.99.
grid_x <- qnorm(seq(0.01, 0.99, length.out = n))
grid_y <- qnorm(seq(0.01, 0.99, length.out = n))

# Switch to an n-by-n panel layout with no margins on a black background,
# keeping the previous settings in `op` so they can be restored afterwards.
op <- par(mfrow = c(n, n), mar = c(0, 0, 0, 0), bg = "black")
for (i in seq_along(grid_x)) {
  yi <- grid_x[[i]]
  for (j in seq_along(grid_y)) {
    xi <- grid_y[[j]]
    # Tile the single latent point (xi, yi) into batch_size identical rows.
    z_sample <- matrix(c(xi, yi), nrow = batch_size, ncol = 2, byrow = TRUE)
    x_decoded <- decoder %>% predict(z_sample, batch_size = batch_size)
    # All rows are identical, so only the first decoded image is drawn.
    digit <- array_reshape(x_decoded[1, , , ], dim = c(digit_size, digit_size))
    plot(as.raster(digit))
  }
}
# Restore the graphics parameters that were in effect before the grid.
par(op)