The setup function is used to set important attributes and methods for nn_modules so they can be used with luz.
Arguments
- module (nn_module) The nn_module that you want to set up.
- loss (function, optional) An optional function with the signature function(input, target). It is only required if your nn_module doesn't implement a method called loss.
- optimizer (torch_optimizer, optional) A function with the signature function(parameters, ...) that is used to initialize an optimizer given the model parameters.
- metrics (list, optional) A list of metrics to be tracked during the training procedure. Sometimes you want some metrics to be evaluated only during training or validation; in that case you can pass a luz_metric_set() object to specify the metrics used in each stage (see the sketch after this list).
- backward (function) A function that takes the loss scalar value as its argument. It must call $backward() or torch::autograd_backward(). In general you don't need to set this parameter unless you need to customize how luz calls backward(), for example to pass additional arguments to the backward call. Note that this becomes a method of the nn_module and can therefore be used by your custom step() if you override it.
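As a minimal sketch of a typical setup() call: the network below, its name, and the chosen loss, optimizer, and metrics are assumptions made for illustration, not part of this page.

library(torch)
library(luz)

# An illustrative one-hidden-layer regression network; the name and sizes
# are made up for this sketch.
net <- nn_module(
  "Net",
  initialize = function(input_size, hidden_size) {
    self$fc1 <- nn_linear(input_size, hidden_size)
    self$fc2 <- nn_linear(hidden_size, 1)
  },
  forward = function(x) {
    self$fc2(nnf_relu(self$fc1(x)))
  }
)

# setup() attaches the loss, optimizer, and metrics to the module generator.
# luz_metric_set() tracks different metrics in training and validation.
modnn <- setup(
  net,
  loss = nnf_mse_loss,
  optimizer = optim_adam,
  metrics = luz_metric_set(
    metrics = list(luz_metric_mae()),
    valid_metrics = list(luz_metric_rmse())
  )
)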
Value
A luz module that can be trained with fit().
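Continuing the sketch above, the module returned by setup() can be configured with set_hparams() and trained with fit(). The data below is synthetic and purely illustrative, and the 0.2 validation split is one possible choice.

# Synthetic data: 100 observations of 10 predictors and a numeric response.
x <- matrix(runif(100 * 10), ncol = 10)
y <- matrix(runif(100), ncol = 1)

# set_hparams() supplies the arguments of initialize(); fit() runs training,
# holding out 20% of the data for validation.
modnn <- set_hparams(modnn, input_size = 10, hidden_size = 32)
fitted <- fit(modnn, data = list(x, y), valid_data = 0.2, epochs = 5)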
Note
It also adds a device active field that can be used to query the current module device within methods, e.g. with self$device. This is useful when ctx() is not available, for example when calling methods from outside the luz wrappers. Users can override the default by implementing a device active method in the input module.
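A sketch of how the device field might be used inside a method; the module below is hypothetical and assumes it is later passed through setup() so that self$device exists.

drop_in <- nn_module(
  "DropIn",
  initialize = function() {
    self$fc <- nn_linear(8, 1)
  },
  forward = function(x) {
    # self$device is the active field added by setup(); tensors created here
    # land on the same device (CPU or GPU) as the module's parameters.
    noise <- torch_randn(x$shape, device = self$device)
    self$fc(x + noise)
  }
)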
See also
Other training: evaluate(), fit.luz_module_generator(), predict.luz_module_fitted()