This callback allows you to resume training a model.

Usage

luz_callback_auto_resume(path = "./state.pt")

Arguments

path

Path to save state files for the model.

Details

When this callback is used, the model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will resume training from the epoch right after the last one that was serialized.

Note

In general, you will want to add this callback last in the callbacks list. That way, the serialized state is likely to contain all the changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.

Read the checkpointing article on the pkgdown website for more information.

Customizing serialization

By default, the model, optimizer state and records are serialized. Callbacks can customize serialization by implementing the state_dict() and load_state_dict() methods. If a callback implements those methods, state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
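
As a minimal sketch of the two methods described above, the callback below keeps a counter that should survive an interruption. The callback name `epoch_counter`, the field `epochs_seen` and the argument name `state` are hypothetical, chosen only for illustration. The sketch assumes that load_state_dict() receives the same object that state_dict() returned.

```r
library(luz)

# Hypothetical callback: counts the epochs it has seen, including
# those completed before an interruption.
epoch_counter <- luz_callback(
  "epoch_counter",
  initialize = function() {
    self$epochs_seen <- 0
  },
  on_epoch_end = function() {
    self$epochs_seen <- self$epochs_seen + 1
  },
  # Returns the state that should be serialized at the end of each epoch.
  state_dict = function() {
    list(epochs_seen = self$epochs_seen)
  },
  # Restores the state when training is resumed (assumed to receive
  # the list produced by state_dict()).
  load_state_dict = function(state) {
    self$epochs_seen <- state$epochs_seen
  }
)

counter <- epoch_counter()
```

An instance such as `counter` would then be passed to fit() in the callbacks list together with luz_callback_auto_resume().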

Examples

if (torch::torch_is_installed()) {
library(torch)
library(luz)

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)


# simulate a failure at the end of epoch 5 that happens only once.
callback_stop <- luz_callback(
  "interrupt",
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  }
)

path <- tempfile()
autoresume <- luz_callback_auto_resume(path = path)
interrupt <- callback_stop()

# try once and the model fails
try({
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )
})

# model resumes and completes
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume, interrupt),
  verbose = FALSE
)

get_metrics(results)

}
#> Error in FUN(X[[i]], ...) : 
#>   Error while calling callback with class <interrupt/LuzCallback/R6> at
#> on_epoch_end.
#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
#>      set metric epoch    value
#> 1  train   loss     1 1.295996
#> 2  train   loss     2 1.154590
#> 3  train   loss     3 1.110767
#> 4  train   loss     4 1.104347
#> 5  train   loss     5 1.097942
#> 6  train   loss     6 1.083621
#> 7  train   loss     7 1.104237
#> 8  train   loss     8 1.098173
#> 9  train   loss     9 1.100960
#> 10 train   loss    10 1.102442