# Segmentation example: encoder-decoder network with skip connections trained on the Oxford pet dataset.
# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(torchdatasets)
library(luz)

# Datasets and loaders ----------------------------------------------------

# Caching directory for the downloaded dataset.
dir <- file.path(".", "pets")

# A light wrapper around `oxford_pet_dataset` that converts input images to
# tensors, resizes both images and masks to `size`, and adds an
# `augmentation` argument for transformations that must be applied to an
# image and its mask in sync (e.g. flipping, cropping).
#
# Arguments:
#   ...           Forwarded unchanged to `oxford_pet_dataset`'s initializer
#                 (e.g. the root directory, `download`, `split`).
#   augmentation  Optional function. It is called via `do.call()` with the
#                 elements of each item returned by the parent `.getitem()`,
#                 and its return value becomes the item. When NULL, the
#                 elements are returned unchanged as a list.
#   size          Target size passed to `transform_resize()` for both the
#                 image and the mask.
pet_dataset <- torch::dataset(
  inherit = oxford_pet_dataset,
  initialize = function(..., augmentation = NULL, size = c(224, 224)) {

    input_transform <- function(x) {
      x %>%
        transform_to_tensor() %>%
        transform_resize(size)
    }

    # Masks are converted to long (integer) tensors and resized with
    # `interpolation = 0` (nearest neighbour), so label values are not
    # blended. A leading dimension is added for the resize and removed again.
    target_transform <- function(x) {
      x <- torch_tensor(x, dtype = torch_long())
      x <- x[newaxis, ..]
      x <- transform_resize(x, size, interpolation = 0)
      x[1, ..]
    }

    # `split` is not a formal argument of this initializer, so a bare `split`
    # here resolves to `base::split()` (a function). Read it from the
    # forwarded arguments instead; this is NULL when `split` was not given by
    # name. The parent initializer still receives it through `...`.
    self$split <- list(...)[["split"]]

    super$initialize(
      ...,
      transform = input_transform,
      target_transform = target_transform
    )

    self$augmentation <- if (is.null(augmentation)) {
      function(...) list(...)
    } else {
      augmentation
    }
  },
  # Fetch the parent item, then apply the (possibly identity) augmentation.
  .getitem = function(i) {
    items <- super$.getitem(i)
    do.call(self$augmentation, items)
  }
)

# Training and validation splits; `download = TRUE` allows fetching the data
# into `dir`.
train_ds <- pet_dataset(dir, download = TRUE, split = "train")
valid_ds <- pet_dataset(dir, download = TRUE, split = "valid")

# Only the training loader shuffles; validation batches keep dataset order.
train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = 32)

# Define the network ------------------------------------------------------

# Encoder built on a pre-trained MobileNetV2. Its `features` stack is cut
# into stages whose intermediate outputs feed the decoder's skip connections.
# All encoder parameters are frozen (`requires_grad_(FALSE)`), so they are
# not updated during training.
encoder <- torch::nn_module(
  initialize = function() {
    model <- model_mobilenet_v2(pretrained = TRUE)
    self$stages <- nn_module_list(list(
      # Identity stage: exposes the raw input as the shallowest feature.
      nn_identity(),
      model$features[1:2],
      model$features[3:4],
      model$features[5:7],
      model$features[8:14],
      model$features[15:18]
    ))

    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }
  },
  # Runs the stages in sequence and returns a list with every stage's output,
  # shallowest first (the first element is the input itself).
  forward = function(x) {
    features <- vector("list", length(self$stages))
    for (i in seq_along(self$stages)) {
      x <- self$stages[[i]](x)
      features[[i]] <- x
    }
    features
  }
)

# One decoder step. A kernel-2, stride-2 transposed convolution doubles the
# spatial size. The result is concatenated with the matching skip feature
# along the channel axis, and a 3x3 "same"-padded convolution fuses the two.
# ReLU follows both convolutions.
decoder_block <- nn_module(
  initialize = function(in_channels, skip_channels, out_channels) {
    self$upsample <- nn_conv_transpose2d(
      in_channels = in_channels,
      out_channels = out_channels,
      kernel_size = 2,
      stride = 2
    )
    self$activation <- nn_relu()
    self$conv <- nn_conv2d(
      in_channels = out_channels + skip_channels,
      out_channels = out_channels,
      kernel_size = 3,
      padding = "same"
    )
  },
  forward = function(x, skip) {
    upsampled <- self$activation(self$upsample(x))
    # dim = 2 is the channel axis of the (batch, channel, height, width) input.
    merged <- torch_cat(list(upsampled, skip), dim = 2)
    self$activation(self$conv(merged))
  }
)

# The decoder is a chain of `decoder_block`s, one per encoder stage. It walks
# from the deepest encoder feature back up to input resolution. Block `i`
# upsamples from `in_channels[i]` to `decoder_channels[i]` channels and
# concatenates a skip feature with `skip_channels[i]` channels. The final skip
# has 3 channels: the input image, as exposed by the encoder's identity stage.
#
# Arguments:
#   decoder_channels  Output channels of each decoder block, deepest first.
#   encoder_channels  Output channels of the encoder stages, shallowest first
#                     (the identity stage is not included).
decoder <- nn_module(
  initialize = function(
    decoder_channels = c(256, 128, 64, 32, 16),
    encoder_channels = c(16, 24, 32, 96, 320)
  ) {
    # One block is built per encoder stage. Mismatched lengths would
    # otherwise silently ignore channels or produce NA channel counts.
    if (length(decoder_channels) != length(encoder_channels)) {
      stop(
        "`decoder_channels` and `encoder_channels` must have the same length.",
        call. = FALSE
      )
    }

    encoder_channels <- rev(encoder_channels)
    skip_channels <- c(encoder_channels[-1], 3)
    in_channels <- c(encoder_channels[1], decoder_channels)

    depth <- length(encoder_channels)

    self$blocks <- nn_module_list()
    for (i in seq_len(depth)) {
      self$blocks$append(decoder_block(
        in_channels = in_channels[i],
        skip_channels = skip_channels[i],
        out_channels = decoder_channels[i]
      ))
    }
  },
  # `features` is the encoder's output list, shallowest first. After
  # reversing, the deepest feature is the starting point, and element `i + 1`
  # is the skip input for block `i`.
  forward = function(features) {
    features <- rev(features)
    x <- features[[1]]
    for (i in seq_along(self$blocks)) {
      x <- self$blocks[[i]](x, features[[i + 1]])
    }
    x
  }
)

# Full network: encoder -> decoder -> a 3x3 convolution that maps the
# decoder's 16 output channels to 3 per-pixel class scores (logits). These
# logits are what the cross-entropy loss consumes.
model <- nn_module(
  initialize = function() {
    self$encoder <- encoder()
    self$decoder <- decoder()
    self$output <- nn_sequential(
      nn_conv2d(
        in_channels = 16,
        out_channels = 3,
        kernel_size = 3,
        padding = "same"
      )
    )
  },
  forward = function(x) {
    features <- self$encoder(x)
    decoded <- self$decoder(features)
    self$output(decoded)
  }
)

# Train ---------------------------------------------

# Train with the cross-entropy loss. The dice loss would also be an option,
# but it is harder to optimize.
model <- setup(
  model,
  optimizer = optim_adam,
  loss = nn_cross_entropy_loss()
)

# Learning-rate range test; inspect the plot to choose a learning rate.
f <- lr_finder(model, train_dl)
plot(f)

fitted <- fit(
  set_opt_hparams(model, lr = 1e-3),
  train_dl,
  epochs = 10,
  valid_data = valid_dl
)

plot(fitted)

# Plot validation image ---------------------
library(raster)

# Predict on a one-element subset containing the 2nd validation example.
preds <- predict(fitted, dataloader(dataset_subset(valid_ds, 2)))

# Per-pixel predicted class: argmax over the class dimension (dim 1 once the
# batch dimension is indexed away). The result is moved to the CPU, converted
# to an R array and wrapped as a categorical raster.
mask <- raster::ratify(raster::raster(
  as.array(torch_argmax(preds[1, ..], dim = 1)$to(device = "cpu"))
))

# The matching input image, reordered from (channel, h, w) to (h, w, channel)
# for raster.
img <- valid_ds[2][[1]]$permute(c(2, 3, 1)) %>%
  as.array() %>%
  raster::brick()

# Draw the image, then overlay the predicted mask semi-transparently.
raster::plotRGB(img, scale = 1)
plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)