# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)
# Datasets and loaders ----------------------------------------------------
# Caching directory for the MNIST files, relative to the working directory.
dir <- file.path(".", "mnist")
# Modify the MNIST dataset so the target is identical to the input.
# Each item returned by the parent dataset keeps its `x` (the image after
# `transform`) and has its `y` replaced by that very same tensor, turning the
# labelled dataset into a reconstruction dataset for an autoencoder.
mnist_dataset2 <- torch::dataset(
  inherit = mnist_dataset,
  .getitem = function(i) {
    item <- super$.getitem(i)
    # The reconstruction target is the input itself.
    item[["y"]] <- item[["x"]]
    item
  }
)
# Both splits read from the same cache directory and convert images to
# tensors. `download = TRUE` is passed to BOTH splits so the test split no
# longer silently relies on the train split having been constructed first.
# NOTE(review): torchvision is expected to skip the download when the files
# are already cached in `dir` — confirm against the installed version.
train_ds <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
test_ds <- mnist_dataset2(
  dir,
  train = FALSE,
  download = TRUE,
  transform = transform_to_tensor
)
# Only the training data is shuffled; the test loader keeps dataset order so
# predictions made over it line up with `test_ds` item by item.
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)
# Building the network ---------------------------------------------------
# Convolutional autoencoder. Assuming 1x28x28 inputs, the two 5x5
# convolutions (default stride 1, no padding) shrink the feature maps
# 28 -> 24 -> 20 while widening channels 1 -> 6 -> 16; the two transposed
# convolutions mirror this back (20 -> 24 -> 28, channels 16 -> 6 -> 1).
net <- nn_module(
  "Net",
  initialize = function() {
    # Encoder: image -> 16 feature maps.
    self$encoder <- nn_sequential(
      nn_conv2d(1, 6, kernel_size = 5),
      nn_relu(),
      nn_conv2d(6, 16, kernel_size = 5),
      nn_relu()
    )
    # Decoder: feature maps -> single-channel image. The final sigmoid keeps
    # outputs in (0, 1), presumably matching the scale of the input pixels.
    self$decoder <- nn_sequential(
      nn_conv_transpose2d(16, 6, kernel_size = 5),
      nn_relu(),
      nn_conv_transpose2d(6, 1, kernel_size = 5),
      nn_sigmoid()
    )
  },
  # Reconstruction pass: encode the input, then decode it back.
  forward = function(x) {
    code <- self$encoder(x)
    self$decoder(code)
  },
  # Encoder features flattened from dimension 2 onward, keeping the first
  # (batch) dimension intact.
  # NOTE(review): luz's predict() is expected to call this method instead of
  # forward() when the module defines it — confirm for the luz version used.
  predict = function(x) {
    torch_flatten(self$encoder(x), start_dim = 2)
  }
)
# Train -------------------------------------------------------------------
# One epoch of Adam on a mean-squared-error reconstruction loss, using the
# test loader as validation data.
fitted <- fit(
  setup(net, loss = nn_mse_loss(), optimizer = optim_adam),
  train_dl,
  epochs = 1,
  valid_data = test_dl
)
# Create predictions ------------------------------------------------------
# Run the fitted model over the (unshuffled) test loader.
# NOTE(review): presumably this yields the flattened encoder output from
# Net's `predict` method rather than reconstructions — confirm with luz docs.
preds <- fitted %>%
  predict(test_dl)
# Serialize ---------------------------------------------------------------
# Persist the fitted luz model to a file in the working directory.
luz_save(obj = fitted, path = "mnist-autoencoder.pt")