The setup function is used to set important attributes and methods for nn_modules
to be used with luz.
Arguments
- module
  (nn_module) The nn_module that you want to set up.
- loss
  (function, optional) An optional function with the signature function(input, target). It's only required if your nn_module doesn't implement a method called loss.
- optimizer
  (torch_optimizer, optional) A function with the signature function(parameters, ...) that is used to initialize an optimizer given the model parameters.
- metrics
  (list, optional) A list of metrics to be tracked during the training procedure. Sometimes you want some metrics to be evaluated only during training or only during validation. In this case you can pass a luz_metric_set() object to specify the metrics used in each stage.
- backward
  (function) A function that takes the loss scalar value as its parameter. It must call $backward() or torch::autograd_backward(). In general you don't need to set this parameter unless you need to customize how luz calls backward(), for example, if you need to add additional arguments to the backward call. Note that this becomes a method of the nn_module and can therefore be used by your custom step() if you override it.
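The arguments above can be combined as in the following minimal sketch. The LinearNet module, its in_features default, and the choice of loss, optimizer and metrics are illustrative assumptions, not part of this page. The luz_metric_set() argument names (metrics, valid_metrics) are written as they are understood to be in current luz releases, so check them against your installed version.

```r
library(torch)
library(luz)

# Hypothetical one-layer regression module, used only for illustration.
net <- nn_module(
  "LinearNet",
  initialize = function(in_features = 4) {
    self$fc <- nn_linear(in_features, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

model <- setup(
  net,
  # Any function with the signature function(input, target) works here.
  loss = nnf_mse_loss,
  # An optimizer constructor; luz calls it later with the model parameters.
  optimizer = optim_sgd,
  # MAE is tracked in every stage, RMSE only during validation.
  metrics = luz_metric_set(
    metrics = list(luz_metric_mae()),
    valid_metrics = list(luz_metric_rmse())
  ),
  # Explicit backward that mirrors a plain call; only needed when you
  # want to customize how the backward pass is invoked.
  backward = function(loss) {
    loss$backward()
  }
)
```

Note that optim_sgd has no default learning rate, so in this sketch the optimizer hyperparameters would still have to be supplied (for example with set_opt_hparams()) before calling fit().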
Value
A luz module that can be trained with fit().
Note
It also adds a device active field that can be used to query the current
module device within methods, e.g. with self$device. This is useful when
ctx() is not available, e.g. when calling methods from outside the luz
wrappers. Users can override the default by implementing a device active
method in the input module.
See also
Other training:
evaluate(),
fit.luz_module_generator(),
predict.luz_module_fitted()