# MNIST classification with a small CNN, trained with luz and the mixup callback.
# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)

# Seed both R's random number generator and torch's, using the same value.
set.seed(seed = 1)
torch_manual_seed(seed = 1)

# Datasets and loaders ----------------------------------------------------

# Directory in which the MNIST files are cached.
dir <- "./mnist"

# Training split: download is requested, and every item is passed through
# `transform_to_tensor`.
train_ds <- mnist_dataset(
  root = dir,
  train = TRUE,
  download = TRUE,
  transform = transform_to_tensor
)

# Test split: read from the same directory with the same transform; no
# download is requested here.
test_ds <- mnist_dataset(
  root = dir,
  train = FALSE,
  transform = transform_to_tensor
)

# Both loaders use batches of 128; only the training loader shuffles.
train_dl <- dataloader(
  dataset = train_ds,
  batch_size = 128,
  shuffle = TRUE
)
test_dl <- dataloader(
  dataset = test_ds,
  batch_size = 128
)


# Building the network ---------------------------------------------------

# Convolutional classifier: two 3x3 convolutions, one 2x2 max-pool and two
# fully connected layers, with a dropout layer in front of each fully
# connected layer. `forward()` returns the raw output of the last linear layer
# (10 values per observation).
net <- nn_module(
  classname = "Net",
  initialize = function() {
    # 1 input channel -> 32 -> 64 feature maps, 3x3 kernels, stride 1.
    self$conv1 <- nn_conv2d(
      in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1
    )
    self$dropout1 <- nn_dropout(p = 0.25)
    self$dropout2 <- nn_dropout(p = 0.5)
    # 9216 = 64 * 12 * 12: the flattened size after the two convolutions
    # (28 -> 26 -> 24) and the 2x2 pooling (24 -> 12).
    # NOTE(review): this assumes 1 x 28 x 28 input images -- confirm if the
    # dataset changes.
    self$fc1 <- nn_linear(in_features = 9216, out_features = 128)
    self$fc2 <- nn_linear(in_features = 128, out_features = 10)
  },
  forward = function(x) {
    # Convolutional part: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool.
    h <- nnf_relu(self$conv1(x))
    h <- nnf_relu(self$conv2(h))
    h <- nnf_max_pool2d(h, 2)
    h <- self$dropout1(h)

    # Keep the first dimension and flatten everything from dimension 2 on.
    h <- torch_flatten(h, start_dim = 2)

    # Fully connected part: linear -> ReLU -> dropout -> linear.
    h <- nnf_relu(self$fc1(h))
    h <- self$dropout2(h)
    self$fc2(h)
  }
)

# Train -------------------------------------------------------------------

# Configure the module (loss, optimizer, metrics) and fit it for 10 epochs on
# the training loader, validating on the test loader.
fitted <- fit(
  setup(
    net,
    # The loss must be left un-reduced so that it can be mixed up.
    loss = nn_cross_entropy_loss(reduction = "none"),
    optimizer = optim_adam,
    # Accuracy is registered as a validation-only metric.
    metrics = luz_metric_set(valid_metrics = luz_metric_accuracy())
  ),
  train_dl,
  epochs = 10,
  valid_data = test_dl,
  callbacks = list(
    # With `auto_loss = TRUE` the mixup callback modifies the loss
    # automatically.
    luz_callback_mixup(alpha = 0.4, auto_loss = TRUE)
  )
)

# Making predictions ------------------------------------------------------

# Run the fitted model over the test loader and show the shape of the result.
preds <- predict(object = fitted, newdata = test_dl)
preds$shape


# Serialization -----------------------------------------------------------

# Write the fitted model to disk.
luz_save(obj = fitted, path = "mnist-mixup.pt")