Sets up an nn_module to use with luz
The setup() function is used to set important attributes and methods of an nn_module so it can be used with luz.
module: (nn_module) The nn_module that you want to set up.
loss: (function, optional) An optional function with the signature function(input, target). It's only required if your nn_module doesn't implement a method called loss.
optimizer: (torch_optimizer, optional) A function with the signature function(parameters, ...) that is used to initialize an optimizer given the model parameters.
metrics: (list, optional) A list of metrics to be tracked during the training procedure. If you want some metrics to be evaluated only during training or only during validation, you can pass a luz_metric_set() object to specify the metrics used in each stage.
backward: (function) A function that takes the loss scalar value as its parameter. It must call torch::autograd_backward(). In general you don't need to set this parameter unless you need to customize how luz calls backward(), for example, if you need to pass additional arguments to the backward call. Note that this becomes a method of the nn_module and can therefore be used by your custom step() if you override it.
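A minimal sketch of a typical setup() call. The network definition, optimizer, metric, and the train_dl/valid_dl dataloaders below are illustrative assumptions, not part of this reference; adapt them to your own model and data:

```r
library(torch)
library(luz)

# A hypothetical single-layer classifier; any nn_module works here.
net <- nn_module(
  "Net",
  initialize = function(input_size, output_size) {
    self$fc <- nn_linear(input_size, output_size)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# setup() attaches the loss, optimizer, and metrics to the module.
# Hyperparameters and data are supplied afterwards.
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(luz_metric_accuracy())
  ) %>%
  set_hparams(input_size = 10, output_size = 3) %>%
  fit(train_dl, epochs = 5, valid_data = valid_dl)
```

Because the loss here is an nn loss module, it is callable with the documented function(input, target) signature; a plain R function with that signature works as well.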
Returns a luz module that can be trained with fit().
setup() also adds a device active field that can be used to query the current device within methods, e.g. with self$device. This is useful when ctx() is not available, e.g. when calling methods from outside the luz wrappers. Users can override the default by implementing a device active method in the input nn_module.
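As a sketch of overriding the default, the hypothetical module below supplies its own device active binding (assuming torch for R's nn_module, which accepts an active argument for active bindings):

```r
library(torch)

# Hypothetical module that pins its reported device via an active
# binding, overriding the device field that setup() would add.
net <- nn_module(
  "Net",
  initialize = function() {
    self$fc <- nn_linear(4, 2)
  },
  forward = function(x) {
    # self$device can be queried inside methods even when ctx() is
    # not available, e.g. when a method is called outside luz wrappers.
    self$fc(x$to(device = self$device))
  },
  active = list(
    device = function() {
      "cpu"  # always report cpu, regardless of where luz placed the model
    }
  )
)
```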