This example is inspired by the chargpt project by Andrej Karpathy. We are going to train a character-level language model on Shakespeare's texts.
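We first load the libraries that we plan to use. The code below calls functions from torch and luz, and uses the `%<-%` destructuring operator, which is provided by the zeallot package: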
Next, we define the torch dataset that will pre-process the data for the model. It splits the text into a character vector, each element containing exactly one character.
It then stores all unique characters in the `vocab` attribute. The order of the characters in the vocabulary is used to encode each character as an integer value, which will be used in the embedding layer.
The `.getitem()` method takes chunks of `block_size` characters and encodes them into their integer representation.
# Location of the plain-text tiny-Shakespeare corpus that is downloaded and
# used as training data.
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Character-level dataset: every item is a pair of integer-encoded windows of
# `block_size` characters, where `y` is `x` shifted one character forward.
char_dataset <- torch::dataset(
  # data:       a single string holding the whole corpus.
  # block_size: number of characters in each input/target window.
  initialize = function(data, block_size = 128) {
    chars <- stringr::str_split_1(data, "") # one element per character
    vocab <- unique(chars) # position in `vocab` is the integer code
    self$block_size <- block_size
    self$data <- chars
    self$data_size <- length(chars)
    self$vocab <- vocab
    self$vocab_size <- length(vocab)
  },
  # Returns list(x, y) of integer codes for the window that starts one
  # position after `i` (for i = 1 it covers characters 2 to block_size + 2).
  .getitem = function(i) {
    n <- self$block_size
    window <- self$data[i + seq_len(n + 1)]
    codes <- match(window, self$vocab)
    list(
      x = codes[seq_len(n)], # first `n` codes: model input
      y = codes[-1] # last `n` codes: next-character targets
    )
  },
  # Number of valid start indices: the largest `i` must still leave
  # `block_size + 1` characters after it.
  .length = function() {
    self$data_size - (self$block_size + 1L)
  }
)
# Download the corpus as one string and wrap it in the character dataset.
dataset <- url |>
  readr::read_file() |>
  char_dataset()
dataset[1] # inspect one element (a list with the encoded `x` and `y`)
We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we are going to use the minhub implementation directly. You can find the full model definition here; that code is entirely self-contained, so you don't need to install minhub if you don't want to.
We also implemented the `generate` method for the model, which allows one to generate completions using the model. It applies the model in a loop, predicting the next character at each iteration.
# GPT-2 based character model with a sampling helper.
model <- torch::nn_module(
  # vocab_size: number of distinct tokens, forwarded to the GPT-2 constructor.
  initialize = function(vocab_size) {
    # remotes::install_github("mlverse/minhub")
    self$gpt <- minhub::gpt2(
      vocab_size = vocab_size,
      n_layer = 6,
      n_head = 6,
      n_embd = 192
    )
  },
  # Runs the GPT-2 network and swaps dimensions 2 and 3 of its output.
  # NOTE(review): presumably the network returns (batch, sequence, vocab), so
  # the swap puts the vocabulary in dimension 2 and the sequence last, which
  # is what the `[, , -1]` slice in `generate` relies on - confirm.
  forward = function(x) {
    out <- self$gpt(x)
    out$transpose(2, 3)
  },
  # Extends the context `x` by `iter` sampled tokens and returns the result.
  # temperature: divisor applied to the scores before sampling.
  # iter:        number of tokens to append.
  # top_k:       only the `top_k` highest scores can be sampled.
  generate = function(x, temperature = 1, iter = 50, top_k = 10) {
    for (step in seq_len(iter)) {
      # scores at index -1 of the third dimension, scaled by the temperature
      scores <- self$forward(x)[, , -1] / temperature
      c(top_val, top_idx) %<-% scores$topk(top_k)
      # everything outside the top-k is set to -Inf before the softmax
      masked <- torch_full_like(scores, -Inf)$scatter_(-1, top_idx, top_val)
      probs <- nnf_softmax(masked, dim = -1)
      sampled <- torch_multinomial(probs, num_samples = 1)
      x <- torch_cat(list(x, sampled), dim = 2)
    }
    x
  }
)
Next, we implement a callback that nicely displays generated samples during model training:
# Samples a text continuation from the model, starting from `context`.
#
# model:   module exposing `$generate()` and `$device`.
# vocab:   character vector; the position of a character is its integer code.
# context: a single string used as the prompt.
# ...:     forwarded to `model$generate()` (e.g. `iter`, `top_k`).
#
# Returns one string holding the context followed by the sampled characters.
# Raises an error when `context` contains a character that is not in `vocab`.
generate <- function(model, vocab, context, ...) {
  local_no_grad() # disables gradient for sampling
  chars <- stringr::str_split_1(context, "")
  x <- match(chars, vocab)
  # match() yields NA for unknown characters; fail with a clear message here
  # instead of passing NA codes on to the model.
  if (anyNA(x)) {
    unknown <- unique(chars[is.na(x)])
    stop(
      "`context` contains characters that are not in `vocab`: ",
      paste0("'", unknown, "'", collapse = ", "),
      call. = FALSE
    )
  }
  # `[NULL, ]` adds a leading dimension of size one before the model call
  x <- torch_tensor(x)[NULL, ]$to(device = model$device)
  content <- as.integer(model$generate(x, ...)$cpu())
  paste0(vocab[content], collapse = "")
}
# luz callback that prints a sample generated by the model every `iter`
# training batches.
display_cb <- luz_callback(
  # iter: print a sample whenever `ctx$iter` is a multiple of this value.
  initialize = function(iter = 500) {
    self$iter <- iter # print every 500 iterations by default
  },
  on_train_batch_end = function() {
    if (!(ctx$iter %% self$iter == 0)) {
      return()
    }

    # Sample in evaluation mode, and always switch back to training mode
    # afterwards: this hook runs in the middle of training, so leaving the
    # model in eval mode would affect all remaining batches.
    ctx$model$eval()
    on.exit(ctx$model$train(), add = TRUE)

    with_no_grad({
      # sample from the model...
      context <- "O God, O God!"
      text <- generate(ctx$model, dataset$vocab, context, iter = 100)
      cli::cli_h3(paste0("Iter ", ctx$iter))
      cli::cli_text(text)
    })
  }
)
Finally, you can train the model using `fit()`:
# Configure and train the model with luz; the result is stored in `fitted`.
fitted <- model |>
# cross-entropy loss, optimised with Adam
setup(
loss = nn_cross_entropy_loss(),
optimizer = optim_adam
) |>
# optimizer hyper-parameter: learning rate
set_opt_hparams(lr = 5e-4) |>
# model hyper-parameter: vocabulary size taken from the dataset
set_hparams(vocab_size = dataset$vocab_size) |>
fit(
dataset,
# shuffled mini-batches of 128 items
dataloader_options = list(batch_size = 128, shuffle = TRUE),
epochs = 1,
callbacks = list(
# print a generated sample every 500 iterations
display_cb(iter = 500),
# gradient clipping with max_norm = 1
luz_callback_gradient_clip(max_norm = 1)
)
)
One epoch is reasonable for this dataset and takes about one hour on an M1 MacBook Pro. You can generate new samples with:
# Prompt the trained model and print the context plus 100 sampled characters.
context <- "O God, O God!"
text <- fitted$model |>
  generate(dataset$vocab, context, iter = 100)
cat(text)