Text classification from scratch

This example is a port of ‘Text classification from scratch’ from the Keras documentation, by Mark Omernick and François Chollet.

First we implement a torch dataset that downloads and pre-processes the data. The initialize method is called when we instantiate a dataset. Our implementation:

  • Downloads the IMDB dataset if it doesn’t already exist in the root directory.
  • Extracts the files into the root directory.
  • Trains a tokenizer on the reviews in the training set, or loads a previously saved one from the root directory.

We also implement the .getitem method, which extracts a single element from the dataset and pre-processes the file contents, and the .length method, which returns the number of elements in the dataset.

library(torch)
library(tok)
library(luz)

# Hyper-parameters shared by the dataset and the model.
embedding_dim <- 128  # length of each token embedding vector
output_length <- 500  # every review is padded/truncated to this many tokens
vocab_size <- 20000   # upper bound on the tokenizer's vocabulary size

# Torch dataset over the IMDB movie-review sentiment corpus (aclImdb v1).
#
# Arguments to initialize():
#   output_length  Every encoded review is padded/truncated to this many tokens.
#   vocab_size     Maximum vocabulary size of the BPE tokenizer; also part of
#                  the cached tokenizer's file name.
#   root           Directory that holds (or will receive) the extracted
#                  `aclImdb` folder and the cached tokenizer JSON.
#   split          Either "train" or "test".
#   download       If TRUE, download and extract the archive when `aclImdb`
#                  is missing from `root`; if FALSE, error instead.
#
# Each item is a list with `x` (token ids shifted by +1, length output_length)
# and `y` (1 for a review under pos/, 0 for one under neg/).
imdb_dataset <- dataset(
  initialize = function(output_length, vocab_size, root, split = "train",
                        download = TRUE) {
    split <- match.arg(split, c("train", "test"))

    url <- "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    fpath <- file.path(root, "aclImdb")

    # Fetch the data if the extracted folder doesn't exist yet.
    if (!dir.exists(fpath)) {
      if (!download) {
        stop(
          "IMDB data not found in '", fpath, "'. ",
          "Use `download = TRUE` to download it.",
          call. = FALSE
        )
      }
      # Download into a temp file (removed afterwards), then extract into root.
      withr::with_tempfile("file", {
        # The archive is large (~80 MB): raise the default 60 s timeout, and
        # force binary mode so the tarball is never mangled on Windows.
        withr::with_options(
          list(timeout = max(600, getOption("timeout"))),
          download.file(url, file, mode = "wb")
        )
        untar(file, exdir = root)
      })
    }

    # List the review files of the split: positives (y = 1), then negatives
    # (y = 0). Building the frame from vectors also copes with empty listings.
    pos_files <- list.files(file.path(fpath, split, "pos"), full.names = TRUE)
    neg_files <- list.files(file.path(fpath, split, "neg"), full.names = TRUE)
    if (length(pos_files) + length(neg_files) == 0) {
      stop(
        "No review files found for split '", split, "' in '", fpath, "'.",
        call. = FALSE
      )
    }
    self$data <- data.frame(
      fname = c(pos_files, neg_files),
      y = rep(c(1, 0), c(length(pos_files), length(neg_files)))
    )

    # Train a BPE tokenizer on the train data (if one isn't cached yet).
    tokenizer_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tokenizer_path)) {
      self$tok <- tok::tokenizer$new(tok::model_bpe$new())
      self$tok$pre_tokenizer <- tok::pre_tokenizer_whitespace$new()

      # Train on review texts only. A recursive listing of train/ would also
      # include the bag-of-words `*.feat` files and the `urls_*.txt` lists,
      # polluting the vocabulary with non-review tokens.
      # NOTE(review): the tokenizer sees raw text, while .getitem encodes
      # lower-cased, punctuation-free text; consider aligning the two.
      train_dirs <- file.path(fpath, "train", c("pos", "neg", "unsup"))
      files <- list.files(train_dirs, pattern = "\\.txt$", full.names = TRUE)
      self$tok$train(files, tok::trainer_bpe$new(vocab_size = vocab_size))

      self$tok$save(tokenizer_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tokenizer_path)
    }

    # Every encoding comes out with exactly output_length ids.
    self$tok$enable_padding(length = output_length)
    self$tok$enable_truncation(max_length = output_length)
  },
  .getitem = function(i) {
    item <- self$data[i, ]

    # Read the review into a single string, lower-case it, replace `<br />`
    # tags with spaces, remove punctuation, then encode it with the tokenizer.
    encoding <- item$fname %>%
      readr::read_file() %>%
      stringr::str_to_lower() %>%
      stringr::str_replace_all("<br />", " ") %>%
      stringr::str_remove_all("[:punct:]") %>%
      self$tok$encode()

    list(
      # Shift ids by one: tokenizer ids presumably start at 0, while R torch's
      # nn_embedding uses 1-based indices.
      x = encoding$ids + 1L,
      y = item$y
    )
  },
  .length = function() {
    nrow(self$data)
  }
)

# Build both IMDB splits under ./imdb. The first call trains and caches the
# tokenizer; the second one loads it from disk.
train_ds <- imdb_dataset(
  output_length = output_length,
  vocab_size = vocab_size,
  root = "./imdb",
  split = "train"
)
test_ds <- imdb_dataset(
  output_length = output_length,
  vocab_size = vocab_size,
  root = "./imdb",
  split = "test"
)

We now define the model we want to train. The model is a 1D convnet: it starts with an embedding layer, followed by two convolutional layers, and ends with a classifier head at the output.

# 1D convolutional text classifier: embedding -> two strided conv layers ->
# global max pooling over the sequence -> small MLP head producing one logit.
#
# Arguments to initialize():
#   vocab_size     Number of rows in the embedding table.
#   embedding_dim  Size of each embedding vector (input channels of conv 1).
#   n_filters      Output channels of both conv layers and width of the
#                  hidden dense layer (default 128, as in the original).
#   kernel_size    Convolution window size (default 7).
#   stride         Convolution step (default 3).
#   dropout        Dropout probability after the embedding and in the head.
#
# forward(x): x is a tensor of token ids, presumably (batch, length) — TODO
# confirm against the dataloader; returns one logit per example, shape (batch).
model <- nn_module(
  initialize = function(vocab_size, embedding_dim, n_filters = 128,
                        kernel_size = 7, stride = 3, dropout = 0.5) {
    self$embedding <- nn_sequential(
      nn_embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim),
      nn_dropout(dropout)
    )

    self$convs <- nn_sequential(
      nn_conv1d(embedding_dim, n_filters, kernel_size = kernel_size,
                stride = stride, padding = "valid"),
      nn_relu(),
      nn_conv1d(n_filters, n_filters, kernel_size = kernel_size,
                stride = stride, padding = "valid"),
      nn_relu(),
      # Global max pooling over the length dimension: (B, C, L) -> (B, C, 1).
      nn_adaptive_max_pool1d(1)
    )

    self$classifier <- nn_sequential(
      nn_flatten(),
      nn_linear(n_filters, n_filters),
      nn_relu(),
      nn_dropout(dropout),
      nn_linear(n_filters, 1)
    )
  },
  forward = function(x) {
    emb <- self$embedding(x)
    # Swap dims 2 and 3 so the embedding dimension becomes conv1d's channels.
    out <- emb$transpose(2, 3) %>%
      self$convs() %>%
      self$classifier()
    # Drop the trailing singleton dimension: (B, 1) -> (B).
    out$squeeze(2)
  }
)

# test the model for a single example batch
# m <- model(vocab_size, embedding_dim)
# x <- torch_randint(1, 20000, size = c(32, 500), dtype = "int")
# m(x)

We can finally train the model:

# Configure the luz module (loss, optimizer, metric), bind the
# hyper-parameters passed to model's initialize(), then train for 3 epochs.
# local() keeps the intermediate object out of the global environment.
fitted_model <- local({
  configured <- setup(
    model,
    loss = nnf_binary_cross_entropy_with_logits,
    optimizer = optim_adam,
    metrics = luz_metric_binary_accuracy_with_logits()
  )
  configured <- set_hparams(
    configured,
    vocab_size = vocab_size,
    embedding_dim = embedding_dim
  )
  fit(configured, train_ds, epochs = 3)
})

Finally, we compute the metrics on the test dataset:

# Score the trained model on the held-out test split (result auto-prints).
evaluate(fitted_model, test_ds)

Remember that to make predictions on new texts, we need to apply the same pre-processing used in the dataset definition: lower-casing, removing `<br />` tags and punctuation, and encoding with the same tokenizer.