# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)
library(dotty)
# Data --------------------------------------------------------------------
# MNIST dataset object used for training and for the latent-space plots.
# `download = TRUE` asks torchvision to fetch the files. The transform turns
# each image into a tensor, flattens it to one dimension and divides by 255.
data <- torchvision::mnist_dataset(
  transform = function(img) {
    flat <- torch_tensor(img)$view(c(-1))
    flat / 255
  },
  download = TRUE
)
# Encoder -----------------------------------------------------------------
# Encoder network: a ReLU hidden layer of width d / 2 feeding two parallel
# linear heads that both output `latent_d` values.
#
# initialize(d, latent_d)
#   d         size of the input vector
#   latent_d  size of each output head
# forward(x)
#   returns an unnamed list: first the output of `fc_mean`, then the output
#   of `fc_logvar`. Callers in this file destructure it positionally.
encoder <- nn_module(
  initialize = function(d, latent_d) {
    hidden_d <- d / 2
    self$fc1 <- nn_linear(d, hidden_d)
    self$fc_mean <- nn_linear(hidden_d, latent_d)
    self$fc_logvar <- nn_linear(hidden_d, latent_d)
  },
  forward = function(x) {
    hidden <- nnf_relu(self$fc1(x))
    list(
      self$fc_mean(hidden),
      self$fc_logvar(hidden)
    )
  }
)
# Autoencoder -------------------------------------------------------------
# Variational autoencoder built from `encoder` plus a small decoder
# (linear -> ReLU -> linear -> sigmoid).
#
# initialize(d, latent_d)
#   d         size of the flattened input
#   latent_d  size of the latent code
# loss(x, ...)
#   Samples a latent code from the encoder outputs, decodes it and returns
#   the summed squared reconstruction error plus the KL-divergence term.
#   Extra arguments are accepted and ignored.
#   NOTE(review): no `forward` is defined; this presumably relies on luz
#   calling `loss` directly with the batch input -- confirm against luz docs.
# predict(x, type)
#   "encode" returns the first encoder output (the mean) for `x`;
#   "decode" runs `x` through the decoder.
autoencoder <- nn_module(
  initialize = function(d, latent_d) {
    hidden_d <- d / 2
    self$encoder <- encoder(d, latent_d)
    self$decoder <- nn_sequential(
      nn_linear(latent_d, hidden_d),
      nn_relu(),
      nn_linear(hidden_d, d),
      nn_sigmoid()
    )
  },
  loss = function(x, ...) {
    .[mu, log_var] <- self$encoder(x)

    # Reparameterisation: z = mu + sigma * eps, with sigma = exp(log_var / 2)
    sigma <- torch_exp(0.5 * log_var)
    eps <- torch_randn_like(log_var)
    z <- mu + sigma * eps

    reconstruction <- self$decoder(z)
    recon_loss <- nnf_mse_loss(reconstruction, x, reduction = "sum")

    kl_terms <- 1 + log_var - mu$pow(2) - log_var$exp()
    kld_loss <- -0.5 * torch_sum(kl_terms)

    recon_loss + kld_loss
  },
  predict = function(x, type = c("encode", "decode")) {
    type <- match.arg(type)
    switch(type,
      encode = self$encoder(x)[[1]],
      decode = self$decoder(x)
    )
  }
)
# Fit ---------------------------------------------------------------------
# Configure and train the autoencoder with luz: Adam optimiser with learning
# rate 1e-3, 28 * 28 inputs, a 2-dimensional latent code, 20 epochs over
# shuffled batches of 128.
model <- autoencoder |>
  setup(optimizer = optim_adam) |>
  set_hparams(d = 28 * 28, latent_d = 2) |>
  set_opt_hparams(lr = 1e-3) |>
  fit(
    data,
    epochs = 20,
    dataloader_options = list(
      batch_size = 128,
      shuffle = TRUE
    )
  )
# Latent representation ----------------------------------------------------
# Run every item of `data` through the fitted model. With no `type` given,
# the module's `predict` defaults to "encode" (presumably luz dispatches to
# that method -- confirm), so each row is a latent mean.
preds <- predict(model, data)
df <- as.data.frame(as.matrix(preds))
# Attach the label of each item. The dataset object is named `data`; the
# name `mnist` used here before is not defined anywhere in the visible script
# and made this line fail. `vapply` checks that every label is one number.
df$labels <- vapply(
  seq_len(length(data)),
  \(i) data[i]$y,
  numeric(1)
)
# Visualization ------------------------------------------------------------
library(ggplot2)
# Scatter plot of the two latent coordinates, one tiny point per item,
# coloured by its label.
ggplot(data = df, mapping = aes(V1, V2, color = as.factor(labels))) +
  geom_point(size = 0.1)
# Visualize the latent space -----------------------------------------------
# Decode a grid_size x grid_size lattice of latent points spanning the
# observed range of both latent coordinates, and tile the decoded 28 x 28
# images into one large picture.
grid_size <- 15
v1 <- seq(from = min(df$V1), to = max(df$V1), length.out = grid_size)
v2 <- seq(from = min(df$V2), to = max(df$V2), length.out = grid_size)

# One strip per value of v1: decode the points (v, v2[1]), ..., (v, v2[n])
# and stack the resulting images along the first dimension.
preds <- lapply(v1, \(v) {
  decoded <- predict(model, cbind(v, v2), type = "decode")
  tiles <- decoded$view(c(-1, 28, 28))$unbind(1)
  torch_cat(tiles)
})

# Join the strips along the second dimension and convert to an R array.
preds <- as.array(torch_cat(preds, dim = 2))
plot(as.raster(preds))