During training, computes the per-observation loss with respect to each of the two targets, weights the two losses item-wise, and averages the resulting linear combinations to yield the mean batch loss. During validation and testing, it defers to the passed-in loss.
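As a rough illustration of the training-phase computation, the base-R sketch below reproduces the arithmetic on plain matrices, so no torch installation is needed. The helpers cross_entropy_none() and mixup_loss_sketch() are hypothetical and not part of luz. The sketch assumes the per-item weights are lambda for the first target and 1 - lambda for the second, as in the standard mixup formulation; luz's internal representation of the weights may differ.

```r
# Hypothetical helper (not part of luz): per-observation cross-entropy,
# i.e. what a loss module returns when its reduction field is 'none'.
cross_entropy_none <- function(logits, target) {
  # Numerically stable row-wise log-softmax
  m <- apply(logits, 1, max)
  log_probs <- logits - m - log(rowSums(exp(logits - m)))
  # Pick the log-probability of each row's target class
  -log_probs[cbind(seq_len(nrow(logits)), target)]
}

# Hypothetical sketch of the training-phase computation. Assumes
# `weight` holds lambda per observation, with 1 - lambda applied to
# the second target (standard mixup; luz's internals may differ).
mixup_loss_sketch <- function(logits, target1, target2, weight) {
  l1 <- cross_entropy_none(logits, target1)  # losses w.r.t. first target
  l2 <- cross_entropy_none(logits, target2)  # losses w.r.t. second target
  # Item-wise linear combination, then the mean over the batch
  mean(weight * l1 + (1 - weight) * l2)
}

logits <- matrix(c(2, 0, 1,
                   0, 3, 1), nrow = 2, byrow = TRUE)
mixup_loss_sketch(logits, target1 = c(1, 2), target2 = c(3, 1),
                  weight = c(0.7, 0.4))
```

The per-observation losses are what make the item-wise weighting possible: if the underlying loss had already reduced to a single batch mean, the two losses could no longer be combined observation by observation.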
Arguments
- loss: the underlying loss nn_module to call. It must support the reduction field. During training, this attribute is set to 'none' so that the loss is returned for each individual observation. See, for example, the documentation for the reduction argument in torch::nn_cross_entropy_loss().
Details
It should be used together with luz_callback_mixup().