# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(torchdatasets)
library(luz)
# Datasets and loaders ----------------------------------------------------
dir <- "./pets" #caching directory
# A light wrapper around `oxford_pet_dataset` that resizes input images and
# masks to the specified `size` and adds an `augmentation` argument, so we can
# specify transformations that must stay in sync between images and masks,
# e.g. flipping or cropping (a sketch of such a function follows the dataset
# definition below).
pet_dataset <- torch::dataset(
  inherit = oxford_pet_dataset,
  initialize = function(..., augmentation = NULL, size = c(224, 224)) {
    input_transform <- function(x) {
      x %>%
        transform_to_tensor() %>%
        transform_resize(size)
    }
    target_transform <- function(x) {
      x <- torch_tensor(x, dtype = torch_long())
      x <- x[newaxis, ..]
      # nearest-neighbor interpolation so the mask labels stay integral
      x <- transform_resize(x, size, interpolation = 0)
      x[1, ..]
    }
    super$initialize(
      ...,
      transform = input_transform,
      target_transform = target_transform
    )
    if (is.null(augmentation))
      self$augmentation <- function(...) list(...)
    else
      self$augmentation <- augmentation
  },
  .getitem = function(i) {
    items <- super$.getitem(i)
    do.call(self$augmentation, items)
  }
)
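# As referenced above, a synced augmentation is just a function that receives
# the (image, mask) pair produced by `.getitem()` and returns the transformed
# pair. A minimal sketch of a random horizontal flip follows; the argument
# names (assuming the items are passed as `x` and `y`) and the flip probability
# are assumptions, not part of the original pipeline.
flip_augmentation <- function(x, y) {
  if (runif(1) < 0.5) {
    x <- torch_flip(x, dims = 3) # image is (C, H, W): flip the width axis
    y <- torch_flip(y, dims = 2) # mask is (H, W): flip the width axis
  }
  list(x = x, y = y)
}
# It could then be passed when building the training set, e.g.
# pet_dataset(dir, download = TRUE, split = "train", augmentation = flip_augmentation)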
train_ds <- pet_dataset(
  dir,
  download = TRUE,
  split = "train"
)
valid_ds <- pet_dataset(
  dir,
  download = TRUE,
  split = "valid"
)
train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = 32)
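# Quick sanity check on the input pipeline (a minimal sketch): draw a single
# batch and inspect its dimensions. We expect images of shape (32, 3, 224, 224)
# and integer masks of shape (32, 224, 224).
batch <- train_dl %>%
  dataloader_make_iter() %>%
  dataloader_next()
lapply(batch, dim)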
# Define the network ------------------------------------------------------
# We use a pre-trained MobileNet v2 as the encoder and keep the intermediate
# feature maps needed for the skip connections.
encoder <- torch::nn_module(
  initialize = function() {
    model <- model_mobilenet_v2(pretrained = TRUE)
    self$stages <- nn_module_list(list(
      nn_identity(),
      model$features[1:2],
      model$features[3:4],
      model$features[5:7],
      model$features[8:14],
      model$features[15:18]
    ))
    # freeze the pre-trained weights
    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }
  },
  forward = function(x) {
    features <- list()
    for (i in 1:length(self$stages)) {
      x <- self$stages[[i]](x)
      features[[length(features) + 1]] <- x
    }
    features
  }
)
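# Shape check for the encoder (a minimal sketch): on a dummy (1, 3, 224, 224)
# input, each stage roughly halves the spatial resolution, and the channel
# counts (3, 16, 24, 32, 96, 320) match the `encoder_channels` used by the
# decoder below.
enc <- encoder()
enc_features <- enc(torch_randn(1, 3, 224, 224))
lapply(enc_features, dim)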
# Each decoder block is an upsampling (transposed convolution) layer followed
# by a convolution with "same" padding.
decoder_block <- nn_module(
  initialize = function(in_channels, skip_channels, out_channels) {
    self$upsample <- nn_conv_transpose2d(
      in_channels = in_channels,
      out_channels = out_channels,
      kernel_size = 2,
      stride = 2
    )
    self$activation <- nn_relu()
    self$conv <- nn_conv2d(
      in_channels = out_channels + skip_channels,
      out_channels = out_channels,
      kernel_size = 3,
      padding = "same"
    )
  },
  forward = function(x, skip) {
    x <- x %>%
      self$upsample() %>%
      self$activation()
    # concatenate the upsampled features with the skip connection along channels
    input <- torch_cat(list(x, skip), dim = 2)
    input %>%
      self$conv() %>%
      self$activation()
  }
)
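# Shape check for a single decoder block (a minimal sketch, using the channel
# sizes of the deepest encoder stage): the transposed convolution doubles the
# spatial resolution, then the skip features are concatenated before the 3x3
# convolution.
blk <- decoder_block(in_channels = 320, skip_channels = 96, out_channels = 256)
dim(blk(torch_randn(1, 320, 7, 7), torch_randn(1, 96, 14, 14))) # 1 256 14 14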
# We build the decoder as a sequence of `decoder_block`s, with channel sizes
# chosen to match the encoder outputs.
decoder <- nn_module(
  initialize = function(
    decoder_channels = c(256, 128, 64, 32, 16),
    encoder_channels = c(16, 24, 32, 96, 320)
  ) {
    encoder_channels <- rev(encoder_channels)
    skip_channels <- c(encoder_channels[-1], 3)
    in_channels <- c(encoder_channels[1], decoder_channels)
    depth <- length(encoder_channels)
    self$blocks <- nn_module_list()
    for (i in seq_len(depth)) {
      self$blocks$append(decoder_block(
        in_channels = in_channels[i],
        skip_channels = skip_channels[i],
        out_channels = decoder_channels[i]
      ))
    }
  },
  forward = function(features) {
    features <- rev(features)
    x <- features[[1]]
    for (i in seq_along(self$blocks)) {
      x <- self$blocks[[i]](x, features[[i + 1]])
    }
    x
  }
)
# Finally, the model is the composition of encoder and decoder, plus an output
# layer that produces per-pixel logits for each of the possible classes.
model <- nn_module(
  initialize = function() {
    self$encoder <- encoder()
    self$decoder <- decoder()
    self$output <- nn_sequential(
      nn_conv2d(16, 3, 3, padding = "same")
    )
  },
  forward = function(x) {
    x %>%
      self$encoder() %>%
      self$decoder() %>%
      self$output()
  }
)
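# End-to-end shape check (a minimal sketch): the full model maps a batch of
# images to per-pixel logits over the 3 trimap classes, i.e. (N, 3, 224, 224),
# matching the (N, 224, 224) integer masks expected by the loss below.
net <- model()
dim(net(torch_randn(1, 3, 224, 224))) # 1 3 224 224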
# Train ---------------------------------------------
# We train using the cross entropy loss. We could have used the dice loss too,
# but it's harder to optimize (a sketch of it follows the `setup()` call below).
model <- model %>%
  setup(optimizer = optim_adam, loss = nn_cross_entropy_loss())
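# For reference, a soft Dice loss like the one mentioned above could be written
# as a custom `nn_module`. This is a minimal sketch (the smoothing constant and
# the per-class averaging are assumptions), and it is not used in the training
# below.
dice_loss <- nn_module(
  initialize = function(smooth = 1) {
    self$smooth <- smooth
  },
  forward = function(input, target) {
    # input: (N, C, H, W) logits; target: (N, H, W) class indices in 1..C
    probs <- nnf_softmax(input, dim = 2)
    n_classes <- dim(input)[2]
    dice <- 0
    for (cl in seq_len(n_classes)) {
      p <- probs[, cl, , ]
      t <- (target == cl)$to(dtype = probs$dtype)
      intersection <- torch_sum(p * t)
      union <- torch_sum(p) + torch_sum(t)
      dice <- dice + (2 * intersection + self$smooth) / (union + self$smooth)
    }
    1 - dice / n_classes
  }
)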
f <- lr_finder(model, train_dl)
plot(f)
fitted <- model %>%
  set_opt_hparams(lr = 1e-3) %>%
  fit(train_dl, epochs = 10, valid_data = valid_dl)
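# The fitted object can be evaluated on the validation loader and saved for
# later reuse (a minimal sketch; the file name is an assumption):
evaluate(fitted, valid_dl)
luz_save(fitted, "pets_unet.pt")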
plot(fitted)
# Plot validation image ---------------------
library(raster)
# Predict on a single validation example and overlay the predicted mask
# on the input image.
preds <- predict(fitted, dataloader(dataset_subset(valid_ds, 2)))
mask <- as.array(torch_argmax(preds[1, ..], dim = 1)$to(device = "cpu"))
mask <- raster::ratify(raster::raster(mask))
img <- raster::brick(as.array(valid_ds[2][[1]]$permute(c(2, 3, 1))))
raster::plotRGB(img, scale = 1)
plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)