Skip to contents

This example is an adaptation of the ‘Training a causal language model from scratch’ class from the Hugging Face NLP course.

library(torch)
library(tok)
library(luz)
library(minhub) # remotes::install_github("mlverse/minhub")
#library(tidyverse)
options(arrow.skip_nul = TRUE)
library(arrow)

Data

The first step is to implement a torch dataset that gathers the data and pre-processes it into a format that is suitable for training the model.

That means that we need to:

  1. Download data
  2. Train a tokenizer for this dataset
  3. Be able to produce sequences of tokens in the format expected by the model

We are going to use 2 datasets available in Hugging Face Hub. The first contains the source code of all R packages available on CRAN. The second contains all R code that is available in GitHub data dumps. Both datasets are in the Parquet format. Below we implement a function that downloads and caches the data and then returns a single arrow table containing all data.

# Reads the R source files of one Hugging Face Hub dataset.
#
# Snapshots the repository's Parquet files through `hfhub::hub_snapshot()`,
# opens the `data/r` directory as an arrow dataset and keeps only the
# `content` of files whose `path` ends in `.r` / `.R`.
#
# @param source Hub dataset repository id, e.g. "dfalbel/cran-packages".
# @return A data frame with a single string column `content`, without
#   missing values and without entries that are not valid UTF-8.
read_dataset <- function(source) {
  # The pipeline value is returned directly (and therefore visibly) instead
  # of being bound to an unused local first.
  source |>
    hfhub::hub_snapshot(repo_type = "dataset", allow_patterns = "parquet$") |>
    fs::path("data/r") |>
    arrow::open_dataset() |>
    dplyr::filter(stringr::str_detect(path, ".*\\.[rR]$")) |>
    dplyr::select(content) |>
    dplyr::mutate(content = arrow::cast(content, arrow::string())) |>
    dplyr::filter(!is.na(content)) |>
    dplyr::collect() |>
    # the dataset contains invalid utf8 characters...
    # we need to remove them, otherwise we get an error from tokenizers
    dplyr::filter(utf8::utf8_valid(content))
}

# Loads every source corpus with `read_dataset()` and stacks the results,
# in the order listed, into one data frame.
read_datasets <- function() {
  sources <- c("dfalbel/cran-packages", "dfalbel/github-r-repos")
  tables <- lapply(sources, read_dataset)
  dplyr::bind_rows(tables)
}

Next we implement a function that trains a tokenizer for our dataset.

# Trains a byte-level BPE tokenizer on in-memory text.
#
# @param text Character vector used as the training corpus.
# @param vocab_size Vocabulary size handed to the BPE trainer.
# @param special_tokens Character vector of special tokens handed to the
#   BPE trainer.
# @return The trained `tok::tokenizer` object.
create_tokenizer <- function(text, vocab_size, special_tokens) {
  trainer <- tok::trainer_bpe$new(
    vocab_size = vocab_size,
    special_tokens = special_tokens
  )

  tokenizer <- tok::tokenizer$new(tok::model_bpe$new())

  # byte-level components for pre-tokenization, decoding and post-processing
  tokenizer$pre_tokenizer <- tok::pre_tokenizer_byte_level$new(add_prefix_space = FALSE)
  tokenizer$decoder <- tok::decoder_byte_level$new()
  tokenizer$post_processor <- tok::processor_byte_level$new(trim_offsets = FALSE)

  tokenizer$train_from_memory(text, trainer)
  tokenizer
}

# test code to debug the tokenizer
# data <- read_datasets()
# tok <- create_tokenizer(data$content, 20000, c("<fbegin>", "<fend>"))

We can finally implement the torch dataset that we are going to use for training the model. We are going to use torch::iterable_dataset instead of torch::dataset. The main motivation is that we can’t really know the total number of samples in the dataset, so we can’t implement a .getitem() method to get any arbitrary sample. Instead, we implement the .iter method, which returns a function that yields a new sample every time it’s called.

# Iterable dataset that streams fixed-length token windows built from the
# R source files returned by `read_datasets()`.
#
# Constructor arguments:
#   root           Directory holding the cached tokenizer file
#                  `tokenizer-{vocab_size}.json`; created if needed.
#   vocab_size     Vocabulary size used when a tokenizer has to be trained;
#                  also part of the cache file name.
#   context_length Number of tokens in each returned `input_ids` / `labels`
#                  vector.
#
# Each call of the function returned by `.iter()` yields a list with
# `input_ids` and `labels` (the same window shifted by one token), or
# `coro::exhausted()` once fewer than `context_length + 1` tokens remain.
r_sources_dataset <- torch::iterable_dataset(
  "r_sources_dataset",
  initialize = function(root = ".", vocab_size = 20000, context_length = 128) {
    self$data <- read_datasets()
    self$context_length <- context_length
    # random document order, drawn once when the dataset is created
    self$index <- sample.int(nrow(self$data))

    # we only create a tokenizer if it doesn't exist, otherwise we just load it
    tok_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tok_path)) {
      self$tok <- create_tokenizer(
        as.character(self$data$content),
        vocab_size,
        c("<fbegin>", "<fend>")
      )
      fs::dir_create(root)
      self$tok$save(tok_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tok_path)
    }
  },
  .iter = function() {
    # `i` (next document to read) and `sequence` (token buffer) live in this
    # enclosing environment so that they persist between calls.
    i <- 1L
    sequence <- c()
    function() {
      # refill the buffer until it holds one full window plus the extra
      # token needed for the shifted labels, or until documents run out
      while (length(sequence) < (self$context_length + 1) && i <= nrow(self$data)) {
        sequence <<- c(
          sequence,
          self$tok$encode(paste("<fbegin>", as.character(self$data$content[self$index[i]]), "<fend>"))$ids
        )
        # Must be `<<-`: a plain `<-` only creates a call-local `i`, leaving
        # the cursor at 1 so every call would re-read the same documents.
        i <<- i + 1L
      }

      if (length(sequence) < (self$context_length + 1)) {
        return(coro::exhausted())
      }

      # after the sample is built, drop the tokens that were served; the
      # last label stays in the buffer and becomes the next first input
      on.exit({
        sequence <<- sequence[-seq_len(self$context_length)]
      })
      # ids are shifted by one here and shifted back in `generate()` --
      # presumably tokenizer ids start at 0 while torch indexes from 1
      list(
        input_ids = sequence[seq_len(self$context_length)] + 1L,
        labels = sequence[2:(self$context_length + 1)] + 1L
      )
    }
  }
)

# debug code for the dataset
# ds <- r_sources_dataset("~/Downloads/")
# it <- ds$.iter()
# it()
# ds$tok$get_vocab_size()

This dataset is likely too large for us to train the model on all documents in this example. It’s also hard to predict how long it will take for it to train until the end. In order to make it easier, we define a wrapper dataset that is used to run the above dataset for a fixed number of steps. This is not required, but makes using luz more pleasant, as we can easily define for how many tokens we want to train our model.

# Wraps another iterable dataset so that one pass over it yields at most
# `steps` samples. When the wrapped dataset runs out before that, a fresh
# iterator is requested from it and iteration continues.
#
# Constructor arguments:
#   dataset  An object exposing `$.iter()`.
#   steps    Number of samples served per pass; also reported by `.length()`.
fixed_steps_iterable_dataset <- iterable_dataset(
  "fixed_steps_dataset",
  initialize = function(dataset, steps) {
    self$steps <- steps
    self$dataset <- dataset
  },
  .length = function() {
    self$steps
  },
  .iter = function() {
    next_step <- 1L
    inner <- NULL
    function() {
      # stop once the step budget of this pass is used up
      if (next_step > self$steps) {
        return(coro::exhausted())
      }
      next_step <<- next_step + 1L

      # no inner iterator yet counts as "exhausted" so that one gets created
      item <- if (is.null(inner)) coro::exhausted() else inner()
      if (coro::is_exhausted(item)) {
        inner <<- self$dataset$.iter()
        item <- inner()
      }

      item
    }
  }
)

We finally define the model we are going to train. We’ll use a small version of GPT2. We also define a generate method allowing us to sample from the model given an initial context.

# GPT-2 based language model with a top-k sampling helper.
#
# Constructor arguments (defaults reproduce the previously hard-coded values):
#   vocab_size  Vocabulary size passed to `minhub::gpt2()`.
#   pdrop       Dropout probability passed to `minhub::gpt2()`.
net <- nn_module(
  initialize = function(vocab_size = 20000, pdrop = 0.1) {
    self$gpt <- minhub::gpt2(
      vocab_size = vocab_size,
      pdrop = pdrop
    )
  },
  forward = function(x) {
    # swaps dimensions 2 and 3 of the GPT-2 output -- presumably so the
    # vocabulary dimension sits where nn_cross_entropy_loss() expects it
    self$gpt(x)$transpose(2, 3)
  },
  # Samples from the model given a context tensor.
  #
  #   x            Tensor of token ids; new ids are appended along dim 2.
  #   temperature  Divisor applied to the logits before sampling.
  #   iter         Number of tokens to generate.
  #   top_k        Only the `top_k` highest logits can be sampled.
  #
  # Returns `x` with `iter` sampled ids appended.
  generate = function(x, temperature = 1, iter = 50, top_k = 10) {
    for (i in seq_len(iter)) {
      # logits for the last position only
      logits <- self$forward(x)[, , -1]
      logits <- logits / temperature
      # plain list indexing instead of `%<-%`, which needs zeallot -- a
      # package this file never attaches explicitly
      top <- logits$topk(top_k)
      top_values <- top[[1]]
      top_indices <- top[[2]]
      # everything outside the top-k gets -Inf, i.e. probability zero
      logits <- torch_full_like(logits, -Inf)$scatter_(-1, top_indices, top_values)
      probs <- nnf_softmax(logits, dim = -1)
      id_next <- torch_multinomial(probs, num_samples = 1)
      x <- torch_cat(list(x, id_next), dim = 2)
    }
    x
  }
)

# debug code for the model
# ds <- torch::dataloader(r_sources_dataset("~/Downloads/"), batch_size = 32)
# batch <- coro::collect(ds, 1)[[1]]
# str(batch)
# m <- net()
# str(m(batch$input_ids))

To make it easier to inspect training, we will also define a callback that prints a sample from the model every epoch.

# Generates text from `model`, continuing the string `context`.
#
# @param model Module exposing `$generate()` and `$device`.
# @param tok Tokenizer used to encode the prompt and decode the result.
# @param context Prompt string.
# @param ... Forwarded to `model$generate()`.
# @return The decoded text, prompt included.
generate <- function(model, tok, context, ...) {
  local_no_grad() # disables gradient for sampling

  # tokenizer ids are shifted by one before being handed to the model ...
  prompt_ids <- tok$encode(context)$ids + 1L
  prompt <- torch_tensor(prompt_ids)[NULL,]
  prompt <- prompt$to(device = model$device)

  sampled <- model$generate(prompt, ...)
  sampled_ids <- as.integer(sampled$cpu())

  # ... and shifted back before decoding
  tok$decode(sampled_ids - 1L)
}

# luz callback that prints a 100-token sample from the current model at the
# end of every epoch. The tokenizer is taken from the global `dataset`
# object (`dataset$dataset$tok`).
display_cb <- luz_callback(
  initialize = function() {},
  on_epoch_end = function() {
    local_no_grad()
    prompt <- "# creates a linear model"
    generated <- generate(ctx$model, dataset$dataset$tok, prompt, iter = 100)
    # frame the sample between two horizontal rules
    cli::cli_rule()
    cat(generated, "\n")
    cli::cli_rule()
  }
)

We can finally train the model. We define that we want to train the model for half a billion tokens in a total of 100 epochs.

# Training budget: total number of tokens, spread over `epochs` epochs.
n_tokens <- 500e6
batch_size <- 16
epochs <- 100
context_length <- 256L

# Number of samples (windows of `context_length` tokens) served per epoch.
# Rounded down to a whole number: the raw quotient is fractional, and it is
# used both as the dataset length and to size the LR schedule.
steps <- floor(n_tokens / context_length / epochs)

# Number of optimizer steps per epoch. Rounded up because the last batch of
# an epoch may be partial (assuming the dataloader keeps it, i.e. the default
# `drop_last = FALSE`); under-counting would make the one-cycle scheduler
# run past its total number of steps before training ends.
batches_per_epoch <- ceiling(steps / batch_size)

dataset <- fixed_steps_iterable_dataset(
  r_sources_dataset(context_length = context_length),
  steps = steps
)

fitted <- net |>
  setup(
    optimizer = optim_adam,
    loss = nn_cross_entropy_loss()
  ) |>
  set_opt_hparams(lr = 3e-4) |>
  fit(
    dataset,
    epochs = epochs,
    dataloader_options = list(batch_size = batch_size),
    callbacks = list(
      # the scheduler is stepped after every batch
      luz_callback_lr_scheduler(
        torch::lr_one_cycle,
        max_lr = 0.1,
        epochs = epochs,
        steps_per_epoch = batches_per_epoch,
        call_on = "on_batch_end"
      ),
      luz_callback_gradient_clip(max_norm = 1),
      display_cb()
    ),
    verbose = TRUE
  )

luz::luz_save(fitted, "model.pt")

We can then use the model to generate text given a prompt with:

# Reload the fitted model saved by the training script above.
fitted <- luz::luz_load("model.pt")
# Load the tokenizer file; the name matches what r_sources_dataset writes for
# root = "." and vocab_size = 20000.
tok <- tok::tokenizer$from_file("tokenizer-20000.json")
# Prompt: a roxygen title plus an opening function definition for the model
# to continue.
context <- "#' Creates a linear model
linear_model <- function(x, y) {
"
# Sample 100 new tokens and print the decoded text (prompt included).
text <- generate(fitted$model, tok, context, iter = 100)
cat(text)