# Example: training an MNIST classifier with a virtual batch size in luz
# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)

# Datasets and loaders ----------------------------------------------------

dir <- "~/Downloads/mnist" # where the MNIST files are cached on disk

# Training split. `download = TRUE` asks torchvision to fetch the files into
# `dir`; every item is passed through `transform_to_tensor`.
train_ds <- mnist_dataset(
  root = dir,
  train = TRUE,
  download = TRUE,
  transform = transform_to_tensor
)

# Held-out split, read from the same directory without requesting a download.
test_ds <- mnist_dataset(
  root = dir,
  train = FALSE,
  download = FALSE,
  transform = transform_to_tensor
)

# Batches of 128 items each; only the training loader shuffles.
train_dl <- dataloader(dataset = train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(dataset = test_ds, batch_size = 128, shuffle = FALSE)


# Building the network ---------------------------------------------------

# Convolutional classifier with a custom luz training step that accumulates
# gradients over several batches before each optimizer update.
#
# Hyper-parameters (passed to `initialize`):
#   accumulate_batches  Number of consecutive training batches whose gradients
#                       are summed before the optimizer steps. Must be a single
#                       whole number >= 1. Defaults to 2.
#
# Methods:
#   forward(x)  Returns the output of the final linear layer (10 values per
#               input row); no softmax is applied here.
#   step()      Runs one batch. It reads and writes `ctx`, which is not defined
#               in this block and is expected to be supplied by luz when the
#               step is executed.
net <- nn_module(
  "Net",
  initialize = function(accumulate_batches = 2) {
    # Fail early: 0 would divide the loss by zero and make the modulo test in
    # `step()` evaluate to NaN; a fractional value would scale the loss by a
    # number that does not match how many batches are actually accumulated.
    if (!is.numeric(accumulate_batches) ||
        length(accumulate_batches) != 1L ||
        !is.finite(accumulate_batches) ||
        accumulate_batches < 1 ||
        accumulate_batches != round(accumulate_batches)) {
      stop(
        "`accumulate_batches` must be a single whole number >= 1.",
        call. = FALSE
      )
    }

    # Two 3x3 convolutions with stride 1: 1 -> 32 -> 64 channels.
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout2d(0.25)
    self$dropout2 <- nn_dropout2d(0.5)
    # 9216 = 64 * 12 * 12, i.e. the flattened size after the two convolutions
    # and the 2x2 max-pool when the input images are 28x28.
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, 10)

    self$accumulate_batches <- accumulate_batches
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x <- nnf_max_pool2d(x, 2)
    x <- self$dropout1(x)
    # Keep the first (batch) dimension and collapse everything after it.
    x <- torch_flatten(x, start_dim = 2)
    x <- self$fc1(x)
    x <- nnf_relu(x)
    x <- self$dropout2(x)
    x <- self$fc2(x)
    x
  },
  step = function() {
    ctx$pred <- ctx$model(ctx$input)
    loss <- ctx$model$loss(ctx$pred, ctx$target)

    if (ctx$training) {
      # Scale each batch's loss so the gradients summed over
      # `accumulate_batches` batches correspond to their average.
      loss <- loss / self$accumulate_batches
      loss$backward()
    }

    # Only update the weights (and clear the accumulated gradients) on every
    # `accumulate_batches`-th training iteration.
    # NOTE(review): gradients from a trailing group of fewer than
    # `accumulate_batches` batches are neither applied nor cleared here;
    # confirm how `ctx$iter` behaves across epochs.
    if (ctx$training && (ctx$iter %% self$accumulate_batches == 0)) {
      opt <- ctx$optimizers[[1]]
      opt$step()
      opt$zero_grad()
    }
  }
)

# Train -------------------------------------------------------------------

# Configure and train the model.
#
# luz documents `set_hparams()` as taking a module that has already been
# through `setup()`, so the hyper-parameter is attached after `setup()` and
# before `fit()`.
#
# `accumulate_batches = 10` is forwarded to the module's `initialize()`; with
# the training loader's batch size of 128, the optimizer steps once per 1280
# samples.
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  set_hparams(accumulate_batches = 10) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)


# Serialization -----------------------------------------------------------

# Write the fitted luz object to disk.
luz_save(obj = fitted, path = "mnist-virtual-batch_size.pt")