Training a causal language model from scratch
Source:vignettes/examples/text-generation.Rmd
text-generation.Rmd
This example is an adaptation of the ‘Training a causal language model from scratch’ chapter of the Hugging Face NLP course.
library(torch)
library(tok)
library(luz)
library(minhub) # remotes::install_github("mlverse/minhub")
#library(tidyverse)
options(arrow.skip_nul = TRUE)
library(arrow)
Data
First step is to implement a torch dataset that gathers data and pre-process it into a format that is suitable for training the model.
That means that we need to:
- Download data
- Train a tokenizer for this dataset
- Be able to produce sequences of tokens in the format expected by the model
We are going to use 2 datasets available in Hugging Face Hub. The first contains the source code of all R packages available on CRAN. The second contains all R code that is available in GitHub data dumps. Both datasets are in the Parquet format. Below we implement a function that downloads and caches the data and then returns a single arrow table containing all of it.
# Downloads (and caches) a Hugging Face dataset snapshot and returns a
# tibble with a single `content` column holding R source files.
#
# @param source Hugging Face Hub repo id, e.g. "dfalbel/cran-packages".
# @return A tibble with one `content` column of valid-UTF-8 strings.
read_dataset <- function(source) {
  source |>
    hfhub::hub_snapshot(repo_type = "dataset", allow_patterns = "parquet$") |>
    fs::path("data/r") |>
    arrow::open_dataset() |>
    # keep only files with a .r/.R extension
    dplyr::filter(stringr::str_detect(path, ".*\\.[rR]$")) |>
    dplyr::select(content) |>
    dplyr::mutate(content = arrow::cast(content, arrow::string())) |>
    dplyr::filter(!is.na(content)) |>
    dplyr::collect() |>
    # the dataset contains invalid utf8 characters...
    # we need to remove them, otherwise we get an error from tokenizers
    dplyr::filter(utf8::utf8_valid(content))
}
# Reads both source-code datasets (CRAN packages and GitHub R repos)
# and stacks them into a single tibble with a `content` column.
read_datasets <- function() {
  sources <- c("dfalbel/cran-packages", "dfalbel/github-r-repos")
  sources |>
    lapply(read_dataset) |>
    dplyr::bind_rows()
}
Next we implement a function that trains a tokenizer for our dataset.
# Trains a byte-level BPE tokenizer on an in-memory character vector.
#
# @param text Character vector of documents to train on.
# @param vocab_size Target vocabulary size for the BPE model.
# @param special_tokens Character vector of special tokens to reserve.
# @return A trained `tok::tokenizer` object.
create_tokenizer <- function(text, vocab_size, special_tokens) {
  tokenizer <- tok::tokenizer$new(tok::model_bpe$new())
  # byte-level pipeline: pre-tokenizer, decoder and post-processor must match
  tokenizer$pre_tokenizer <- tok::pre_tokenizer_byte_level$new(add_prefix_space = FALSE)
  tokenizer$decoder <- tok::decoder_byte_level$new()
  tokenizer$post_processor <- tok::processor_byte_level$new(trim_offsets = FALSE)
  trainer <- tok::trainer_bpe$new(
    vocab_size = vocab_size,
    special_tokens = special_tokens
  )
  tokenizer$train_from_memory(text, trainer)
  tokenizer
}
# test code to debug the tokenizer
# data <- read_datasets()
# tok <- create_tokenizer(data$content)
We can finally implement the torch dataset that we are going to use
for training the model. We are going to use the
torch::iterable_dataset
instead of
torch::dataset
. The main motivation is that we can’t really
know the total number of samples in the dataset, so we can’t implement a
.getitem()
method to get any arbitrary sample. Thus we
implement the .iter
method that returns a new sample every
time it’s called.
# Iterable dataset streaming tokenized R source code.
# Documents are visited in a shuffled order, tokenized with <fbegin>/<fend>
# markers, and concatenated into one long token stream that is sliced into
# (context_length + 1)-sized windows: `input_ids` plus next-token `labels`.
r_sources_dataset <- torch::iterable_dataset(
  "r_sources_dataset",
  # @param root Directory where the trained tokenizer is cached.
  # @param vocab_size Tokenizer vocabulary size (part of the cache filename).
  # @param context_length Number of tokens per training sample.
  initialize = function(root = ".", vocab_size = 20000, context_length = 128) {
    self$data <- read_datasets()
    self$context_length <- context_length
    # random traversal order over the documents
    self$index <- sample.int(nrow(self$data))
    # we only create a tokenizer if it doesn't exist, otherwise we just load it
    tok_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tok_path)) {
      self$tok <- create_tokenizer(
        as.character(self$data$content),
        vocab_size,
        c("<fbegin>", "<fend>")
      )
      fs::dir_create(root)
      self$tok$save(tok_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tok_path)
    }
  },
  .iter = function() {
    # `i` (next document to tokenize) and `sequence` (pending token buffer)
    # live in this closure and must be updated with `<<-` so their state
    # persists across calls to the returned function.
    i <- 1L
    sequence <- c()
    function() {
      # refill the buffer until we have a full window (+1 for the label shift)
      while (length(sequence) < (self$context_length + 1) && i <= nrow(self$data)) {
        sequence <<- c(
          sequence,
          self$tok$encode(paste("<fbegin>", as.character(self$data$content[self$index[i]]), "<fend>"))$ids
        )
        # fix: use `<<-` so the document index advances across calls.
        # The previous `i <- i + 1L` created a call-local copy of `i`,
        # so every refill restarted from the first document and
        # re-tokenized the same files indefinitely.
        i <<- i + 1L
      }
      # not enough tokens left to form a full window: dataset is exhausted
      if (length(sequence) < (self$context_length + 1)) {
        return(coro::exhausted())
      }
      # drop the consumed window after the sample below is returned
      on.exit({
        sequence <<- sequence[-seq_len(self$context_length)]
      })
      # +1L shifts the tokenizer's 0-based ids to torch's 1-based indexing
      list(
        input_ids = sequence[seq_len(self$context_length)] + 1L,
        labels = sequence[2:(self$context_length + 1)] + 1L
      )
    }
  }
)
# debug code for the dataset
# ds <- r_sources_dataset("~/Downloads/")
# it <- ds$.iter()
# it()
# ds$tok$get_vocab_size()
This dataset is likely too large for us to train the model on all documents in this example. It’s also hard to predict how long it will take for it to train until the end. In order to make it easier, we define a wrapper dataset that is used to run the above dataset for a fixed number of steps. This is not required, but makes using luz more pleasant, as we can easily define for how many tokens we want to train our model.
# Wraps an iterable dataset so iteration stops after a fixed number of
# steps. If the wrapped dataset is exhausted before `steps` samples have
# been produced, a fresh iterator is created and iteration restarts from
# the beginning of the wrapped dataset.
fixed_steps_iterable_dataset <- iterable_dataset(
  "fixed_steps_dataset",
  # @param dataset An iterable dataset to draw samples from.
  # @param steps Total number of samples to yield per full iteration.
  initialize = function(dataset, steps) {
    self$dataset <- dataset
    self$steps <- steps
  },
  .iter = function() {
    # `i` counts yielded samples; `iter` holds the current iterator over the
    # wrapped dataset. Both persist across calls to the returned function,
    # hence the `<<-` assignments below.
    i <- 1L
    iter <- NULL
    function() {
      if (i > self$steps) {
        return(coro::exhausted())
      }
      i <<- i + 1L
      # (Re)create the inner iterator on first use or when it runs out.
      # NOTE(review): assumes a fresh iterator yields at least one sample —
      # if the wrapped dataset is empty, `data` stays exhausted and is
      # returned as-is. Confirm the upstream dataset is never empty.
      if (is.null(iter) || coro::is_exhausted(data <- iter())) {
        iter <<- self$dataset$.iter()
        data <- iter()
      }
      data
    }
  },
  # A known length lets luz and the dataloader report progress.
  .length = function() {
    self$steps
  }
)
We finally define the model we are going to train. We’ll use a small
version of GPT2. We also define a generate
method allowing
us to sample from the model given an initial context.
# GPT-2-style causal language model plus an autoregressive `generate`
# method for sampling continuations from a context.
net <- nn_module(
  initialize = function() {
    # small GPT-2 configured for our 20k-token vocabulary
    self$gpt <- minhub::gpt2(
      vocab_size = 20000,
      pdrop = 0.1
    )
  },
  forward = function(x) {
    # transpose so logits come out as (batch, vocab, seq) — the class
    # dimension second, as nn_cross_entropy_loss expects.
    self$gpt(x)$transpose(2,3)
  },
  # @param x Tensor of token ids used as the initial context.
  # @param temperature Softmax temperature; higher values flatten the distribution.
  # @param iter Number of tokens to generate.
  # @param top_k Restrict sampling to the k most likely tokens.
  generate = function(x, temperature = 1, iter = 50, top_k = 10) {
    # samples from the model given a context vector.
    for (i in seq_len(iter)) {
      # logits for the last position only (in R torch, -1 indexes from the end)
      logits <- self$forward(x)[,,-1]
      logits <- logits/temperature
      # keep only the top_k candidates; mask everything else with -Inf so
      # softmax assigns them zero probability
      c(prob, ind) %<-% logits$topk(top_k)
      logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
      logits <- nnf_softmax(logits, dim = -1)
      # sample the next token and append it to the running context
      id_next <- torch_multinomial(logits, num_samples = 1)
      x <- torch_cat(list(x, id_next), dim = 2)
    }
    x
  }
)
# debug code for the model
# ds <- torch::dataloader(r_sources_dataset("~/Downloads/"), batch_size = 32)
# batch <- coro::collect(ds, 1)[[1]]
# str(batch)
# m <- net()
# str(m(batch$input_ids))
To make it easier to inspect training, we will also define a callback that prints a sample from the model every epoch.
# samples from the model using the context.
# Samples text from a fitted model, seeding generation with `context`.
#
# @param model A module exposing `$generate()` and `$device`.
# @param tok The tokenizer used to encode the prompt and decode the output.
# @param context Character prompt used to seed generation.
# @param ... Forwarded to `model$generate()` (e.g. `iter`, `top_k`).
# @return The generated text as a string.
generate <- function(model, tok, context, ...) {
  local_no_grad() # no gradients are needed while sampling
  # shift tokenizer's 0-based ids to torch's 1-based indexing
  ids <- tok$encode(context)$ids + 1L
  input <- torch_tensor(ids)[NULL,]$to(device = model$device)
  sampled <- model$generate(input, ...)
  # shift back before decoding
  tok$decode(as.integer(sampled$cpu()) - 1L)
}
# luz callback that prints a sample generation at the end of every epoch,
# making it easy to eyeball how training is progressing.
display_cb <- luz_callback(
  initialize = function() {},
  on_epoch_end = function() {
    local_no_grad()
    # sample from the model...
    context <- "# creates a linear model"
    # NOTE(review): relies on the global `dataset` object for the tokenizer
    # and on luz providing `ctx` in scope — confirm both exist wherever this
    # callback is used.
    text <- generate(ctx$model, dataset$dataset$tok, context, iter = 100)
    cli::cli_rule()
    cat(text, "\n")
    cli::cli_rule()
  }
)
We can finally train the model. We define that we want to train the model for half a billion tokens in a total of 100 epochs.
# Training budget: ~half a billion tokens spread over 100 epochs.
n_tokens <- 500e6
batch_size <- 16
epochs <- 100
context_length <- 256L
# number of context windows (samples) per epoch.
# NOTE(review): 500e6 / 256 / 100 = 19531.25 is not an integer — consider
# wrapping in ceiling() so the dataset length is whole.
steps <- n_tokens / context_length / epochs
dataset <- fixed_steps_iterable_dataset(
  r_sources_dataset(context_length = context_length),
  steps = steps
)
fitted <- net %>%
  setup(
    optimizer = optim_adam,
    loss = nn_cross_entropy_loss()
  ) %>%
  set_opt_hparams(lr = 3e-4) |>
  fit(
    dataset,
    epochs = epochs,
    dataloader_options = list(batch_size = batch_size),
    callbacks = list(
      # one-cycle LR schedule, stepped after every batch.
      # NOTE(review): steps_per_epoch = steps/batch_size is fractional and
      # the dataloader actually yields ceiling(steps/batch_size) batches —
      # verify the scheduler tolerates the mismatch.
      luz_callback_lr_scheduler(
        torch::lr_one_cycle,
        max_lr = 0.1,
        epochs = epochs,
        steps_per_epoch = steps/batch_size,
        call_on = "on_batch_end"
      ),
      # clip gradient norm to stabilize training
      luz_callback_gradient_clip(max_norm = 1),
      display_cb()
    ),
    verbose = TRUE
  )
luz::luz_save(fitted, "model.pt")
We can then use the model to generate text given a prompt with: