# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)
# Datasets and loaders ----------------------------------------------------
dir <- "./mnist" # caching directory

# MNIST dataset whose items are image triplets instead of (image, label) pairs.
#
# For item `index`:
#   * anchor   - the image stored at `index`;
#   * positive - a random image with the same label as the anchor, excluding
#                the anchor itself whenever another such image exists;
#   * negative - a random image with a different label.
# Each item is `list(list(anchor, negative, positive), list())`: the second
# element is empty because triplet training has no target.
triplet_mnist_dataset <- dataset(
  inherit = mnist_dataset,
  .getitem = function(index) {
    # Draw one element of `candidates`. `sample(x, 1)` would treat a
    # length-one numeric `x` as `1:x`, so index through `sample.int()`.
    pick_one <- function(candidates) {
      candidates[sample.int(length(candidates), 1L)]
    }
    # Apply the transform only when one was supplied; calling a NULL
    # `transform` would fail with "attempt to apply non-function".
    prepare <- function(img) {
      if (is.null(self$transform)) img else self$transform(img)
    }

    label <- self$targets[index]
    same_label <- which(self$targets == label)
    # Avoid pairing the anchor with itself (a zero-distance positive) ...
    positive_pool <- setdiff(same_label, index)
    if (length(positive_pool) == 0L) {
      # ... unless it is the only image of its class.
      positive_pool <- same_label
    }

    anchor <- self$data[index, , ]
    negative <- self$data[pick_one(which(self$targets != label)), , ]
    positive <- self$data[pick_one(positive_pool), , ]

    list(
      list( # input is a list with 3 images.
        anchor = prepare(anchor),
        negative = prepare(negative),
        positive = prepare(positive)
      ),
      list() # no 'target'
    )
  }
)
# Both splits read from the same cache directory `dir`. `download = TRUE` is
# passed for the test split too, so it no longer depends on `train_ds` having
# been constructed (and downloaded) first. NOTE(review): assumes the download
# is skipped when the files are already cached — confirm for the torchvision
# version in use.
train_ds <- triplet_mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
test_ds <- triplet_mnist_dataset(
  dir,
  train = FALSE,
  download = TRUE,
  transform = transform_to_tensor
)
# Batches of 128 triplets; only the training loader shuffles.
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)
# Define the network ------------------------------------------------------
# Convolutional encoder producing an `embedding_dim`-sized vector per image.
#
# The first convolution expects single-channel input. `fc1` takes 9216
# = 64 * 12 * 12 features, which is what a 28 x 28 image yields after two
# unpadded 3x3 convolutions and a 2x2 max-pool.
#
# @param embedding_dim Number of output features of the final linear layer.
net <- nn_module(
  "Net",
  initialize = function(embedding_dim) {
    # Registration order is kept as-is: it determines parameter order and
    # the sequence in which initial weights are drawn.
    self$conv1 <- nn_conv2d(
      in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1
    )
    self$dropout1 <- nn_dropout(p = 0.25)
    self$dropout2 <- nn_dropout(p = 0.5)
    self$fc1 <- nn_linear(in_features = 9216, out_features = 512)
    self$fc2 <- nn_linear(in_features = 512, out_features = embedding_dim)
  },
  forward = function(x) {
    # Convolutional stage: conv -> ReLU -> conv -> ReLU -> max-pool -> dropout.
    h <- nnf_relu(self$conv1(x))
    h <- nnf_relu(self$conv2(h))
    h <- self$dropout1(nnf_max_pool2d(h, 2))
    # Dense stage: flatten everything after the batch dimension, then
    # linear -> ReLU -> dropout -> linear projection to the embedding.
    h <- torch_flatten(h, start_dim = 2)
    h <- self$dropout2(nnf_relu(self$fc1(h)))
    self$fc2(h)
  }
)
# Triplet-loss wrapper around a single shared `net` encoder.
#
# @param embedding_dim Size of the embedding produced by `net`.
# @param margin Margin passed to `nn_triplet_margin_loss()`.
#
# NOTE(review): this module defines no `forward`. `loss()` reads
# `input$anchor`, `input$positive` and `input$negative`, so it presumably
# receives the batched triplet list directly. Confirm this against the luz
# training loop in use.
triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    # One encoder shared by anchor, positive and negative images.
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  loss = function(input, ...) {
    # Embed every element of `input` in its original order; names carry over.
    z <- Map(function(images) self$embedding(images), input)
    anchor <- z$anchor
    positive <- z$positive
    negative <- z$negative
    self$criterion(anchor, positive, negative)
  }
)
# Train: attach Adam, fix the embedding size to 2, then fit for 10 epochs
# with `test_dl` as validation data. The intermediate spec lives only inside
# `local()`, so no extra global variables are created.
fitted <- local({
  spec <- setup(triplet_model, optimizer = optim_adam)
  spec <- set_hparams(spec, embedding_dim = 2)
  fit(spec, train_dl, epochs = 10, valid_data = test_dl)
})
# Serializing
luz_save(fitted, "triplet.pt")