# Variational autoencoder (VAE) trained on MNIST with luz
# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)
library(dotty)

# Data --------------------------------------------------------------------

# MNIST dataset, requested with `download = TRUE`. The transform turns each
# image into a tensor, flattens it to a single dimension and divides it by 255.
data <- torchvision::mnist_dataset(
  transform = function(img) {
    flat <- torch_tensor(img)$view(c(-1))
    flat / 255
  },
  download = TRUE
)

# Encoder -----------------------------------------------------------------

# Encoder network: one shared hidden layer (ReLU) feeding two parallel linear
# heads, `fc_mean` and `fc_logvar`.
#
# Constructor arguments:
#   d        - number of input features; the hidden layer has d / 2 units.
#   latent_d - number of outputs produced by each of the two heads.
encoder <- nn_module(
  initialize = function(d, latent_d) {
    hidden_d <- d / 2
    self$fc1 <- nn_linear(d, hidden_d)
    self$fc_mean <- nn_linear(hidden_d, latent_d)
    self$fc_logvar <- nn_linear(hidden_d, latent_d)
  },
  # Returns an unnamed two-element list: the output of `fc_mean` followed by
  # the output of `fc_logvar`, both computed from the same hidden activation.
  forward = function(x) {
    hidden <- nnf_relu(self$fc1(x))
    list(
      self$fc_mean(hidden),
      self$fc_logvar(hidden)
    )
  }
)


# Autoencoder -------------------------------------------------------------

# Variational autoencoder built from `encoder` and a small decoder that mirrors
# it (latent_d -> d / 2 -> d, finished with a sigmoid).
#
# Constructor arguments:
#   d        - number of input features.
#   latent_d - size of the latent space.
autoencoder <- nn_module(
  initialize = function(d, latent_d) {
    hidden_d <- d / 2
    self$encoder <- encoder(d, latent_d)
    self$decoder <- nn_sequential(
      nn_linear(latent_d, hidden_d),
      nn_relu(),
      nn_linear(hidden_d, d),
      nn_sigmoid()
    )
  },
  # Training objective for a batch `x`: summed squared reconstruction error
  # plus the KL-divergence term computed from the encoder outputs.
  # Extra arguments in `...` are ignored.
  loss = function(x, ...) {
    stats <- self$encoder(x)
    m <- stats[[1]]
    logvar <- stats[[2]]

    # Sample z = mean + sd * noise, with sd = exp(0.5 * logvar), so the
    # sample stays a differentiable function of the encoder outputs.
    std <- torch_exp(0.5 * logvar)
    z <- m + std * torch_randn_like(logvar)
    x_recon <- self$decoder(z)

    recon_loss <- nnf_mse_loss(x_recon, x, reduction = "sum")
    kld_loss <- -0.5 * torch_sum(1 + logvar - m$pow(2) - torch_exp(logvar))
    recon_loss + kld_loss
  },
  # type = "encode" (the default): returns the encoder's first output (the
  #   mean) for the inputs `x`.
  # type = "decode": treats `x` as latent points and returns the decoder
  #   output.
  predict = function(x, type = c("encode", "decode")) {
    type <- match.arg(type)
    switch(type,
      encode = self$encoder(x)[[1]],
      decode = self$decoder(x)
    )
  }
)


# Fit ---------------------------------------------------------------------

# Train the autoencoder with Adam (learning rate 1e-3) for 20 epochs on
# shuffled batches of 128. The module is built with d = 28 * 28 input
# features and a 2-dimensional latent space.
model <- autoencoder |>
  setup(optimizer = optim_adam) |>
  set_opt_hparams(lr = 1e-3) |>
  set_hparams(latent_d = 2, d = 28 * 28) |>
  fit(
    data,
    dataloader_options = list(shuffle = TRUE, batch_size = 128),
    epochs = 20
  )

# Latent representation ----------------------------------------------------

# Encode the whole dataset. No `type` is passed, so the module's `predict`
# method falls back to its first choice, "encode".
# NOTE(review): relies on luz dispatching predict() to the module's `predict`
# method - confirm against the installed luz version.
preds <- predict(model, data)

# One row per observation, one column (V1, V2, ...) per latent dimension.
df <- as.data.frame(as.matrix(preds))

# Attach the class label of every observation. The dataset object is `data`;
# the previous code indexed an undefined `mnist` object and failed here.
# vapply() guarantees a plain numeric vector, one label per observation.
df$labels <- vapply(
  seq_len(length(data)),
  \(i) data[i]$y,
  numeric(1)
)

# Visualization ------------------------------------------------------------

library(ggplot2)

# Scatter plot of the first two latent coordinates, one small point per row of
# `df`, coloured by label.
ggplot(
  data = df,
  mapping = aes(x = V1, y = V2, color = as.factor(labels))
) +
  geom_point(size = 0.1)

# Visualize the latent space -----------------------------------------------

# Decode a grid_size x grid_size lattice of latent points spanning the observed
# range of V1 and V2, and draw the decoded 28 x 28 images as a single mosaic.
grid_size <- 15
v1 <- seq(from = min(df$V1), to = max(df$V1), length.out = grid_size)
v2 <- seq(from = min(df$V2), to = max(df$V2), length.out = grid_size)

# For each value of v1, decode the points (v, v2[1]), ..., (v, v2[grid_size])
# and stack the resulting images vertically into one strip; the strips are
# then joined side by side along dimension 2.
preds <- as.array(
  torch_cat(
    lapply(v1, function(v) {
      decoded <- predict(model, cbind(v, v2), type = "decode")
      images <- torch_unbind(decoded$view(c(-1, 28, 28)), dim = 1)
      torch_cat(images)
    }),
    dim = 2
  )
)

plot(as.raster(preds))