Luz is a higher-level API for torch. It is designed to be highly flexible: its layered API stays useful no matter how much control you need over your training loop.
In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will describe how luz allows the user to get fine-grained control of the training loop.
Apart from the use of callbacks, there are three more ways that you can use luz (depending on how much control you need):
Multiple optimizers or losses: You might be optimizing two loss functions, each with its own optimizer, but you still don’t want to modify the backward(), zero_grad() and step() calls. This is common in models like GANs (Generative Adversarial Networks), where you have competing neural networks trained with different losses and optimizers.
Fully flexible steps: You might want to be in control of how to call backward(), zero_grad() and step(). You might also want to have more control of gradient computation. For example, you might want to use ‘virtual batch sizes’, where you accumulate the gradients for a few steps before updating the weights.
Completely flexible loops: Your training loop can be anything you want, but you still want to use luz to handle device placement of the dataloaders, optimizers and models. See vignette("accelerator").
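The ‘virtual batch sizes’ idea can be illustrated outside of luz with plain torch. The following is a minimal sketch, not luz's own implementation: the toy data, the single linear layer and the value of accumulate are all assumptions made for illustration.

```r
library(torch)

torch_manual_seed(1)

# Hypothetical toy data: 8 mini-batches of 4 observations with 10 features.
batches <- lapply(1:8, function(i) {
  list(x = torch_randn(4, 10), y = torch_randn(4, 1))
})

model <- nn_linear(10, 1)
opt <- optim_sgd(model$parameters, lr = 0.01)

accumulate <- 4  # mini-batches per weight update (assumed value)
n_updates <- 0

opt$zero_grad()
for (i in seq_along(batches)) {
  b <- batches[[i]]
  # Scale the loss so the accumulated gradient is an average over the
  # virtual batch rather than a sum.
  loss <- nnf_mse_loss(model(b$x), b$y) / accumulate
  loss$backward()  # gradients add up across calls until zero_grad()

  if (i %% accumulate == 0) {
    opt$step()       # one weight update per virtual batch
    opt$zero_grad()  # reset the accumulated gradients
    n_updates <- n_updates + 1
  }
}
```

With 8 mini-batches and accumulate set to 4, the weights are updated twice, as if the batch size were 16 instead of 4.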
Let’s consider a simplified version of the net that we implemented in the getting started vignette:
library(torch)
library(luz)

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  }
)
Using the highest-level luz API, we would fit it with:
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
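The train_dl and test_dl dataloaders are not defined in this document. As a minimal sketch, assuming synthetic data with 100 features and 10 classes to match the layer sizes of net, they could be built as follows. The make_dl() helper, the dataset sizes and the batch size are hypothetical and are not the data used in the getting started vignette.

```r
library(torch)

torch_manual_seed(1)
set.seed(1)

# Hypothetical helper: random inputs with 100 features and random class
# labels in 1..10 (torch for R uses 1-based class indices).
make_dl <- function(n, batch_size = 16) {
  x <- torch_randn(n, 100)
  y <- torch_tensor(sample(1:10, n, replace = TRUE), dtype = torch_long())
  dataloader(tensor_dataset(x, y), batch_size = batch_size)
}

train_dl <- make_dl(64)
test_dl <- make_dl(32)
```

Each batch produced this way is a list with the input tensor first and the target tensor second, which should match the layout luz expects by default.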
Multiple optimizers
Suppose we want to run an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss() for both, but for the first layer we want to add L1 regularization on the weights.
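For reference, the L1 penalty is the sum of the absolute values of the weights. A small sketch with a made-up weight matrix (an assumption for illustration only) suggests that torch_norm() with p = 1 computes that quantity:

```r
library(torch)

# Hypothetical 2 x 2 weight matrix, chosen only for illustration.
w <- torch_tensor(matrix(c(1, -2, 3, -4), nrow = 2))

l1_norm <- torch_norm(w, p = 1)  # 1-norm over all entries
l1_sum <- w$abs()$sum()          # the same quantity, written out

# Both values should agree.
l1_norm$item()
l1_sum$item()
```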
In order to use luz for this, we will implement two methods in the net module:
set_optimizers: returns a named list of optimizers depending on the ctx.
loss: computes the loss depending on the selected optimizer.
Let’s go to the code:
net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  },
  # one optimizer per layer, each with its own learning rate
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  loss = function(input, target) {
    pred <- ctx$model(input)

    if (ctx$opt_name == "opt_fc1") {
      # cross entropy plus L1 regularization on the first layer's weights
      nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1)
    } else if (ctx$opt_name == "opt_fc2") {
      nnf_cross_entropy(pred, target)
    }
  }
)
Notice that the model optimizers will be initialized according to the set_optimizers() method’s return value (a list). In this case, we are initializing the optimizers using different model parameters and learning rates.
The loss() method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss() method can access the ctx object, which will contain an opt_name field describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. See help("ctx") for complete information about the context object.
We can finally setup and fit this module; however, we no longer need to specify optimizers and loss functions.
fitted <- net %>%
  setup(metrics = list(luz_metric_accuracy)) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.
Fully flexible step
Instead of implementing the loss() method, we can implement the step() method. This allows us to flexibly modify what happens when training and validating for each batch in the dataset. You are now responsible for updating the weights by stepping the optimizers and back-propagating the loss.
net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {
      # store predictions in ctx$pred so that luz metrics can use them
      ctx$pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(ctx$pred, ctx$target)

      if (opt_name == "opt_fc1") {
        # we have L1 regularization in layer 1
        loss <- loss + torch_norm(self$fc1$weight, p = 1)
      }

      # only update the weights when training, not when validating
      if (ctx$training) {
        opt$zero_grad()
        loss$backward()
        opt$step()
      }

      ctx$loss[[opt_name]] <- loss$detach()
    }
  }
)
The important things to notice here are:
The step() method is used for both training and validation. You need to be careful to only modify the weights when training. Again, you can get complete information regarding the context object using help("ctx").
ctx$optimizers is a named list holding each optimizer that was created when the set_optimizers() method was called.
You need to manually track the losses by saving them in a named list in ctx$loss. By convention, we use the same name as the optimizer it refers to. It is good practice to detach() them before saving to reduce memory usage.
Callbacks that would be called inside the default step() method, like on_train_batch_after_pred, on_train_batch_after_loss, etc., won’t be automatically called. You can still call them manually by adding ctx$call_callbacks("<callback name>") inside your training step. See the code for fit_one_batch() and valid_one_batch() to find all the callbacks that won’t be called.
If you want luz metrics to work with your custom step() method, you must assign ctx$pred with the model predictions, as metrics will always be called with metric$update(ctx$pred, ctx$target).
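To see why detaching the losses before saving them is recommended, here is a small plain-torch sketch. The tensor w and the toy loss are assumptions for illustration and are unrelated to the net module. A detached tensor no longer tracks gradients, so storing it does not keep the computation graph alive.

```r
library(torch)

# Hypothetical parameter and loss, for illustration only.
w <- torch_randn(3, requires_grad = TRUE)
loss <- (w^2)$sum()

# The loss is still attached to the graph that produced it.
attached <- loss$requires_grad

# detach() returns a tensor holding the same value without the graph.
saved <- loss$detach()
detached <- saved$requires_grad
```

Here attached is TRUE and detached is FALSE, while saved and loss hold the same numeric value.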
Next steps
In this article you learned how to customize the step() of your training loop using luz’s layered functionality.
Luz also allows more flexible modifications of the training loop, described in the Accelerator vignette (vignette("accelerator")).
You should now be able to follow the examples marked with the ‘intermediate’ and ‘advanced’ category in the examples gallery.