Skip to contents
# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(torchdatasets)
library(luz)

set.seed(1)
torch_manual_seed(1)

# Datasets and loaders ----------------------------------------------------

dir <- "./dogs-vs-cats" # caching directory

ds <- torchdatasets::dogs_vs_cats_dataset(
  dir,
  download = TRUE,
  transform = . %>%
    torchvision::transform_to_tensor() %>%
    torchvision::transform_resize(size = c(224, 224)) %>% 
    torchvision::transform_normalize(rep(0.5, 3), rep(0.5, 3)),
  target_transform = function(x) as.double(x) - 1
)

train_id <- sample.int(length(ds), size = 0.7*length(ds))
train_ds <- dataset_subset(ds, indices = train_id)
valid_ds <- dataset_subset(ds, indices = which(!seq_along(ds) %in% train_id))

train_dl <- dataloader(train_ds, batch_size = 64, shuffle = TRUE, num_workers = 4)
valid_dl <- dataloader(valid_ds, batch_size = 64, num_workers = 4)

# Building the network ---------------------------------------------------

net <- torch::nn_module(
  initialize = function(num_classes) {
    self$model <- model_alexnet(pretrained = TRUE)

    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }

    self$model$classifier <- nn_sequential(
      nn_dropout(0.5),
      nn_linear(9216, 512),
      nn_relu(),
      nn_linear(512, 256),
      nn_relu(),
      nn_linear(256, num_classes)
    )
  },
  forward = function(x) {
    self$model(x)[,1]
  }
)

# Train -------------------------------------------------------------------

fitted <- net %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_binary_accuracy_with_logits()
    )
  ) %>%
  set_hparams(num_classes = 1) %>%
  set_opt_hparams(lr = 0.01) %>%
  fit(train_dl, epochs = 5, valid_data = valid_dl, verbose = TRUE)

# Make predictions --------------------------------------------------------

preds <- torch_sigmoid(predict(fitted, valid_dl))

# Serialization -----------------------------------------------------------
luz_save(fitted, "model-dogs-and-cats.pt")