Text classification from scratch
Source:vignettes/examples/text-classification.Rmd
text-classification.Rmd
This example is a port of ‘Text classification from scratch’ from Keras documentation by Mark Omerick and François Chollet.
First we implement a torch dataset that downloads and pre-processes the data. The `initialize` method is called when we instantiate a dataset. Our implementation:
- Downloads the IMDB dataset if it doesn't exist in the `root` directory.
- Extracts the files into `root`.
- Creates a tokenizer using the files in the training set.
We also implement the `.getitem` method, which is used to extract a single
element from the dataset and pre-process the file contents.
library(torch)
library(tok)
library(luz)
# Hyper-parameters used further down when building the datasets and the model.

# maximum number of items in the tokenizer vocabulary
vocab_size <- 20000
# every encoded review is padded / truncated to this many tokens
output_length <- 500
# size of the embedding vectors
embedding_dim <- 128
# Torch dataset for the IMDB movie review sentiment data.
#
# Arguments of the constructor:
#   output_length  every encoded review is padded / truncated to this length.
#   vocab_size     vocabulary size used when training the BPE tokenizer.
#   root           directory where the data is extracted ('<root>/aclImdb') and
#                  where the trained tokenizer is cached.
#   split          sub-directory of 'aclImdb' to read reviews from
#                  (e.g. "train" or "test").
#   download       if TRUE, download and extract the archive when
#                  '<root>/aclImdb' does not exist yet.
#
# Fields set on the instance:
#   self$data  data.frame with columns `fname` (file path) and `y` (1 = pos,
#              0 = neg).
#   self$tok   tokenizer with padding and truncation enabled.
#
# Each item is a list with `x` (integer token ids, shifted by one) and `y`.
imdb_dataset <- dataset(
  initialize = function(output_length, vocab_size, root, split = "train", download = TRUE) {
    url <- "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    fpath <- file.path(root, "aclImdb")

    # download if the extracted directory doesn't exist yet
    if (!dir.exists(fpath) && download) {
      # R's default 60 seconds timeout is often too short for large archives,
      # so raise it (never lower it) for the duration of the download.
      withr::with_options(
        list(timeout = max(600, getOption("timeout"))),
        # download into a temporary file, then extract into the root dir
        withr::with_tempfile("file", {
          download.file(url, file, mode = "wb")
          untar(file, exdir = root)
        })
      )
    }

    # fail early, with a clear message, when the requested split is missing
    label_dirs <- file.path(fpath, split, c("pos", "neg"))
    if (!all(dir.exists(label_dirs))) {
      stop(
        "Could not find the 'pos' and 'neg' directories for split '", split,
        "' under '", fpath, "'. Check `root` and `split`, or use ",
        "`download = TRUE`.",
        call. = FALSE
      )
    }

    # list the files of the split: positive reviews first, then negative ones
    pos_files <- list.files(label_dirs[1], full.names = TRUE)
    neg_files <- list.files(label_dirs[2], full.names = TRUE)
    self$data <- data.frame(
      fname = c(pos_files, neg_files),
      y = rep(c(1, 0), times = c(length(pos_files), length(neg_files)))
    )

    # train a tokenizer on the train data (if one doesn't exist yet)
    tokenizer_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (file.exists(tokenizer_path)) {
      self$tok <- tok::tokenizer$from_file(tokenizer_path)
    } else {
      self$tok <- tok::tokenizer$new(tok::model_bpe$new())
      self$tok$pre_tokenizer <- tok::pre_tokenizer_whitespace$new()
      # only use the review texts: the train directory also ships
      # bag-of-words (.feat) files and URL lists that are not reviews.
      files <- list.files(
        file.path(fpath, "train", c("pos", "neg", "unsup")),
        pattern = "\\.txt$",
        full.names = TRUE
      )
      self$tok$train(files, tok::trainer_bpe$new(vocab_size = vocab_size))
      self$tok$save(tokenizer_path)
    }

    self$tok$enable_padding(length = output_length)
    self$tok$enable_truncation(max_length = output_length)
  },
  .getitem = function(i) {
    item <- self$data[i, ]

    # read the file content into a character string, make everything lower
    # case and remove html line breaks + punctuation
    text <- readr::read_file(item$fname)
    text <- stringr::str_to_lower(text)
    text <- stringr::str_replace_all(text, "<br />", " ")
    text <- stringr::str_remove_all(text, "[:punct:]")

    # encode the cleaned text with the tokenizer
    encoding <- self$tok$encode(text)

    list(
      # tokenizer ids start at 0, while torch for R indexes from 1
      x = encoding$ids + 1L,
      y = item$y
    )
  },
  .length = function() {
    nrow(self$data)
  }
)
# Build the train and test datasets; both read from (and cache into) "./imdb".
train_ds <- imdb_dataset(
  output_length = output_length,
  vocab_size = vocab_size,
  root = "./imdb",
  split = "train"
)
test_ds <- imdb_dataset(
  output_length = output_length,
  vocab_size = vocab_size,
  root = "./imdb",
  split = "test"
)
We now define the model we want to train. The model is a 1D convnet starting with an embedding layer and we plug a classifier at the output.
# 1D convnet for binary text classification.
#
# Constructor arguments:
#   vocab_size     number of rows of the embedding table.
#   embedding_dim  size of each embedding vector.
#
# `forward(x)` takes a (B, L) tensor of token ids and returns a (B) tensor of
# logits (no sigmoid is applied).
model <- nn_module(
  initialize = function(vocab_size, embedding_dim) {
    self$embedding <- nn_sequential(
      nn_embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim),
      nn_dropout(0.5)
    )
    self$convs <- nn_sequential(
      nn_conv1d(embedding_dim, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_conv1d(128, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      # global max pooling over the length dimension: (B, 128, L') -> (B, 128, 1)
      nn_adaptive_max_pool1d(1)
    )
    self$classifier <- nn_sequential(
      nn_flatten(),
      nn_linear(128, 128),
      nn_relu(),
      nn_dropout(0.5),
      nn_linear(128, 1)
    )
  },
  forward = function(x) {
    emb <- self$embedding(x)
    # nn_conv1d expects channels before length, so swap the last two dims:
    # (B, L, embedding_dim) -> (B, embedding_dim, L)
    features <- self$convs(emb$transpose(2, 3))
    logits <- self$classifier(features)
    # we drop the last dim so we get (B) instead of (B, 1)
    logits$squeeze(2)
  }
)
# test the model for a single example batch
# m <- model(vocab_size, embedding_dim)
# x <- torch_randint(1, 20000, size = c(32, 500), dtype = "int")
# m(x)
We can finally train the model:
# Train the model: `setup()` attaches the loss, optimizer and metric,
# `set_hparams()` provides the arguments of the model's `initialize()`,
# and `fit()` runs the training loop on the training dataset.
fitted_model <- fit(
  set_hparams(
    setup(
      model,
      loss = nnf_binary_cross_entropy_with_logits,
      optimizer = optim_adam,
      metrics = luz_metric_binary_accuracy_with_logits()
    ),
    vocab_size = vocab_size,
    embedding_dim = embedding_dim
  ),
  train_ds,
  epochs = 3
)
We can finally obtain the metrics on the test dataset by calling luz's `evaluate()` on the fitted model and `test_ds`.
Remember that in order to predict for new texts, we need to apply the same pre-processing as used in the dataset definition.