Skip to contents

This example is inspired by the chargpt project by Andrej Karpathy. We are going to train a character-level language model on Shakespeare's texts.

We first load the libraries that we plan to use:

Next, we define the torch dataset that pre-processes the data for the model. It splits the text into a character vector in which each element is exactly one character.

It then stores all unique characters in the vocab attribute. The position of each character in the vocabulary is used to encode it as an integer value, which is what the embedding layer consumes.

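The .getitem() method takes a chunk of block_size + 1 consecutive characters and encodes them as integers. The first block_size of them form the input x, and the last block_size (the same sequence shifted by one character) form the target y.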

# Location of the Tiny Shakespeare corpus in the char-rnn repository.
url <- paste0(
  "https://raw.githubusercontent.com/karpathy/char-rnn/master/",
  "data/tinyshakespeare/input.txt"
)

# Character-level torch dataset.
#
# `data` is a single string. It is split into individual characters, and
# the vocabulary is the unique characters in order of first appearance.
# Each element is a window of `block_size + 1` consecutive characters,
# encoded as their positions in `vocab`:
#   x = first `block_size` codes (model input)
#   y = last `block_size` codes (next-character targets, x shifted by one)
char_dataset <- torch::dataset(
    initialize = function(data, block_size = 128) {
        self$block_size <- block_size
        self$data <- stringr::str_split_1(data, "")

        self$data_size <- length(self$data)
        self$vocab <- unique(self$data)
        self$vocab_size <- length(self$vocab)
    },
    .getitem = function(i) {
        # Window covering characters i, i + 1, ..., i + block_size.
        # Starting at i (not i + 1) so the very first character is used.
        chunk <- self$data[i - 1L + seq_len(self$block_size + 1L)]
        idx <- match(chunk, self$vocab)
        list(
            x = head(idx, self$block_size),
            y = tail(idx, self$block_size)
        )
    },
    .length = function() {
        # A window starting at i needs i + block_size <= data_size, so there
        # are data_size - block_size windows; never report a negative length
        # for texts shorter than block_size + 1.
        max(0L, self$data_size - self$block_size)
    }
)

# Download the raw corpus and wrap it in the character-level dataset.
dataset <- url |>
  readr::read_file() |>
  char_dataset()
dataset[1] # peek at one (x, y) training example

We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we are going to use the minhub implementation directly. You can find the full model definition here; that code is entirely self-contained, so you can copy it instead of installing minhub if you prefer. As written below, the code calls minhub::gpt2(), which requires minhub to be installed (see the remotes::install_github() comment).

We also implement a generate method for the model, which produces completions of a given context. It applies the model in a loop, predicting the next character at each iteration and appending it to the sequence.

# GPT-2 style character model built on minhub::gpt2().
#
# forward(x): x holds integer character codes. The gpt output is transposed
#   so the vocabulary sits in dimension 2, which is the layout
#   nn_cross_entropy_loss() expects for (batch, classes, seq) inputs.
# generate(x, temperature, iter, top_k): appends `iter` sampled characters
#   to x using temperature scaling and top-k filtering, and returns the
#   extended sequence (context included).
model <- torch::nn_module(
    initialize = function(vocab_size) {
        # remotes::install_github("mlverse/minhub")
        self$gpt <- minhub::gpt2(
            vocab_size = vocab_size,
            n_layer = 6,
            n_head = 6,
            n_embd = 192
        )
    },
    forward = function(x) {
        # Assumes gpt2 returns (batch, seq, vocab) -- consistent with the
        # [,,-1] indexing in generate(). Swapping dims 2 and 3 moves the
        # vocabulary to dimension 2 for the cross-entropy loss.
        self$gpt(x)$transpose(2,3)
    },
    generate = function(x, temperature = 1, iter = 50, top_k = 10) {
        # Samples from the model given a context tensor `x`.
        for (i in seq_len(iter)) {
            # Logits at the last sequence position: (batch, vocab).
            logits <- self$forward(x)[,,-1]
            logits <- logits / temperature
            # topk() errors if k exceeds the number of classes, so clamp it.
            k <- min(top_k, tail(logits$shape, 1))
            c(top_values, top_idx) %<-% logits$topk(k)
            # Keep only the k largest logits; everything else gets -Inf
            # so its softmax probability is zero.
            logits <- torch_full_like(logits, -Inf)$scatter_(-1, top_idx, top_values)
            probs <- nnf_softmax(logits, dim = -1)
            id_next <- torch_multinomial(probs, num_samples = 1)
            x <- torch_cat(list(x, id_next), dim = 2)
        }
        x
    }
)

Next, we implement a generate() helper and a callback that nicely displays generated samples during model training:

# samples from the model using the context.
generate <- function(model, vocab, context, ...) {
  local_no_grad() # disables gradient for sampling
  x <- match(stringr::str_split_1(context, ""), vocab)
  x <- torch_tensor(x)[NULL,]$to(device = model$device)
  content <- as.integer(model$generate(x, ...)$cpu())
  paste0(vocab[content], collapse = "")
}

# luz callback that, every `iter` training iterations, samples `n_chars`
# characters continuing `context` and prints them with a header.
# It relies on the global `dataset$vocab` and the generate() helper.
display_cb <- luz_callback(
  initialize = function(iter = 500, context = "O God, O God!", n_chars = 100) {
    self$iter <- iter         # sample every `iter` training iterations
    self$context <- context   # prompt used for each sample
    self$n_chars <- n_chars   # number of characters to generate
  },
  on_train_batch_end = function() {
    if (ctx$iter %% self$iter != 0) {
      return()
    }

    ctx$model$eval()
    # Sampling switches the model to eval mode; switch it back to training
    # mode when we leave (even on error) so the remaining batches are not
    # trained in eval mode.
    on.exit(ctx$model$train(), add = TRUE)
    with_no_grad({
      text <- generate(ctx$model, dataset$vocab, self$context, iter = self$n_chars)
      cli::cli_h3(paste0("Iter ", ctx$iter))
      # Interpolate instead of passing `text` as the format string, so any
      # braces in the generated text are not treated as glue expressions.
      cli::cli_text("{text}")
    })
  }
)

Finally, you can train the model using fit:

# Train for one epoch with Adam (lr = 5e-4), gradient-norm clipping at 1,
# and a text sample printed every 500 iterations.
fitted <- model |>
  setup(
    optimizer = optim_adam,
    loss = nn_cross_entropy_loss()
  ) |>
  set_opt_hparams(lr = 5e-4) |>
  set_hparams(vocab_size = dataset$vocab_size) |>
  fit(
    data = dataset,
    epochs = 1,
    dataloader_options = list(batch_size = 128, shuffle = TRUE),
    callbacks = list(
      display_cb(iter = 500),
      luz_callback_gradient_clip(max_norm = 1)
    )
  )

One epoch is reasonable for this dataset and takes about 1 hour on an M1 MacBook Pro. You can generate new samples with:

# Continue the prompt with 100 sampled characters and print the result.
context <- "O God, O God!"
text <- fitted$model |>
  generate(vocab = dataset$vocab, context = context, iter = 100)
cat(text)