
When fitting models takes a long time, you might want to save intermediate state to disk. If something goes wrong during training (e.g. the process is killed, the network fails, etc.), you can then recover from where it stopped.

You might also want to recover intermediate results to evaluate the model at different points during training, for example comparing results after 10 epochs and after 30 epochs.

This article describes luz features that are built to handle those cases. These features are optional and are enabled once you add specific callbacks to your fit call.

Resuming training runs that crashed

If you have a long training run that can crash for whatever reason (computer turned off, process killed on the cluster, etc.), we recommend adding luz_callback_auto_resume() to your list of callbacks.

luz_callback_auto_resume() will automatically checkpoint the whole state of your model at the end of each epoch. If something fails during training, you can simply rerun the same script without any code changes: the checkpoint will be reloaded and training will start from where it stopped.

For example, let’s take a randomly generated training dataset and a linear model to show how auto-resume works.

Here’s the training data:

library(torch)
library(luz)

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

And the model definition:

model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)

Let’s now create a callback that simulates a random failure. This callback will simply raise an R error on the 5th epoch.

interrupt <- luz_callback(
  "interrupt",
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  }
)

Let’s now start training, adding luz_callback_auto_resume():

autoresume <- luz_callback_auto_resume(path = "state.pt")
inter <- interrupt()

# An error will happen in the 5th epoch and the model will be stopped.
results <- model %>% fit(
  list(x, y),
  callbacks = list(inter, autoresume),
  verbose = FALSE
)
#> Error in `FUN()`:
#> ! Error while calling callback with class <interrupt/LuzCallback/R6> at
#>   on_epoch_end.
#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5

To resume model training exactly from where it stopped, you just need to restart fitting using the exact same model, callbacks, etc.:

results <- model %>% fit(
  list(x, y),
  callbacks = list(inter, autoresume),
  verbose = FALSE
)

With this, the model fitting process continues exactly from where it stopped. The records, optimizer, and model state are recovered from the previous run, so you get the full results:

plot(results)
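
To double check that the interrupted epochs were indeed recovered, you can inspect the per-epoch metrics of the fitted object. A minimal sketch using luz’s get_metrics():

# All epochs should appear, including those that ran before the
# simulated failure on epoch 5.
metrics <- get_metrics(results)
head(metrics, 10)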

Checkpointing

Sometimes you want to have more control over how checkpoints are handled. In this case you can use luz_callback_model_checkpoint() to save checkpoints to a specified file or directory.

Let’s use the same example as in the resuming section: we first generate some data.

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

Then define our model:

model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)

Let’s now fit the model using luz_callback_model_checkpoint().

checkpoint <- luz_callback_model_checkpoint(
  path = "checkpoints/", 
  monitor = "train_loss"
)

results <- model %>% fit(
  list(x, y),
  callbacks = list(checkpoint),
  verbose = FALSE
)

You can see that the checkpoints directory now contains files with state dumps for each epoch. By default, luz_callback_model_checkpoint() will save the state for each epoch and format the file name to include the resulting loss. This can be configured within the path parameter; see ?luz_callback_model_checkpoint for details.

fs::dir_ls("checkpoints")
#> checkpoints/epoch-01-train_loss-1.237.pt
#> checkpoints/epoch-02-train_loss-1.065.pt
#> checkpoints/epoch-03-train_loss-1.026.pt
#> checkpoints/epoch-04-train_loss-1.004.pt
#> checkpoints/epoch-05-train_loss-1.004.pt
#> checkpoints/epoch-06-train_loss-1.005.pt
#> checkpoints/epoch-07-train_loss-0.999.pt
#> checkpoints/epoch-08-train_loss-0.998.pt
#> checkpoints/epoch-09-train_loss-1.001.pt
#> checkpoints/epoch-10-train_loss-1.002.pt
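
If you only want to keep the best checkpoint instead of one file per epoch, the callback also exposes arguments such as save_best_only and mode. The sketch below assumes those argument names; check ?luz_callback_model_checkpoint for the exact signature in your luz version.

# Keep only the checkpoint with the lowest training loss seen so far.
checkpoint_best <- luz_callback_model_checkpoint(
  path = "checkpoints-best/",
  monitor = "train_loss",
  save_best_only = TRUE,
  mode = "min"
)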

Finally, you can load a specific checkpoint into the fitted result using luz_load_checkpoint(). Note that loading a checkpoint into a luz_fitted_module modifies the model weights in-place.

luz_load_checkpoint(results, fs::dir_ls("checkpoints")[1])

You can then start making predictions or evaluate your model using the reloaded weights.
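
For instance, a minimal sketch of using the reloaded weights on the training data, via luz’s standard predict() and evaluate() methods:

# Predictions with the weights loaded from the first checkpoint
preds <- predict(results, x)

# Loss (and any metrics) computed with the reloaded weights
evaluation <- evaluate(results, list(x, y))
evaluation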

You might also want to start a new training run from a checkpoint. For this, you can use luz_callback_resume_from_checkpoint(). By default, it only recovers the model weights from the checkpoint file, but you can configure it to restore records, callback state, and optimizer state too. If a checkpoint directory is passed, training will resume from the last checkpoint file as returned by fs::dir_ls().

Here’s how you would use this callback:

resume <- luz_callback_resume_from_checkpoint(path = "checkpoints/")
results <- model %>% fit(
  list(x, y),
  callbacks = list(resume),
  verbose = FALSE
)
plot(results)
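
If you also want the new run to carry over the records and the optimizer state from the checkpoint, instead of only the model weights, the callback has restore_* flags for that. The argument names below are a sketch based on the callback’s documentation; double check ?luz_callback_resume_from_checkpoint in your installed version:

# Resume from the last checkpoint, restoring records and optimizer
# state in addition to the model weights.
resume_full <- luz_callback_resume_from_checkpoint(
  path = "checkpoints/",
  restore_records = TRUE,
  restore_optimizer_state = TRUE
)

results <- model %>% fit(
  list(x, y),
  callbacks = list(resume_full),
  verbose = FALSE
)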

Custom callbacks state

Sometimes callbacks also need to keep their internal state in order to allow continuing training exactly from where it stopped. In this case, callbacks can implement the state_dict() and load_state_dict() methods, which are automatically called when saving and reloading checkpoints.

For example, suppose that you have a callback that tracks gradients for the weights at every epoch, and you want to use the tracked gradients to further analyse the training procedure. It could be implemented like this:

cb_weight_grad <- luz_callback(
  "weight_grad",
  gradients = list(),
  initialize = function(track_weights) {
    self$track_weights <- track_weights
  },
  on_train_batch_before_step = function() {
    # backward() has already run at this point, so gradients are available.
    self$gradients[[ctx$epoch]] <- list()
    for (w in self$track_weights) {
      self$gradients[[ctx$epoch]][[w]] <- ctx$model$parameters[[w]]$grad$clone()
    }
  }
)

In the above example, the gradients field holds the callback’s state. If training fails for some reason, those gradients will be lost. If it’s important for you to also checkpoint the callback state, you can implement the state_dict() method, which must return a named list of objects that compose the state of the callback, and load_state_dict(), which takes the same named list returned by state_dict() and restores the callback state.

The callback above could be reimplemented with:

cb_weight_grad <- luz_callback(
  "weight_grad",
  gradients = list(),
  initialize = function(track_weights) {
    self$track_weights <- track_weights
  },
  on_train_batch_before_step = function() {
    self$gradients[[ctx$epoch]] <- list()
    for (w in self$track_weights) {
      self$gradients[[ctx$epoch]][[w]] <- ctx$model$parameters[[w]]$grad$clone()
    }
  },
  state_dict = function() {
    # Everything returned here is saved in the checkpoint file.
    list(gradients = self$gradients)
  },
  load_state_dict = function(d) {
    # Called with the list returned by state_dict() when a checkpoint is loaded.
    self$gradients <- d$gradients
  }
)
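
With the two methods implemented, the callback state travels with any checkpoint produced by luz_callback_auto_resume() or luz_callback_model_checkpoint(). A minimal usage sketch (the parameter names "weight" and "bias" are the ones exposed by nn_linear):

grad_tracker <- cb_weight_grad(track_weights = c("weight", "bias"))

results <- model %>% fit(
  list(x, y),
  callbacks = list(
    grad_tracker,
    luz_callback_auto_resume(path = "state.pt")
  ),
  verbose = FALSE
)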