In the training phase, computes individual losses with regard to two targets, weights them item-wise, and averages the linear combinations to yield the mean batch loss. For validation and testing, defers to the passed-in loss.
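The training-phase arithmetic described above can be sketched in base R. This is only an illustration with made-up numbers, not the package's implementation. The variable names are hypothetical, and the convention that the weight applies to the second target is an assumption.

```r
# Per-item losses against each of the two targets (made-up values).
loss_target1 <- c(0.20, 1.50, 0.70)
loss_target2 <- c(0.90, 0.30, 0.70)

# Per-item mixing weights in [0, 1].
# Assumed convention: the weight is applied to the second target.
weight <- c(0.1, 0.5, 0.8)

# Item-wise linear combination of the two losses.
mixed <- (1 - weight) * loss_target1 + weight * loss_target2

# Mean batch loss.
batch_loss <- mean(mixed)
batch_loss
```

Because the weights differ per item, the combination has to happen before averaging, which is why the underlying loss must return one value per observation during training.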
Arguments
- loss
the underlying loss nn_module to call. It must support the reduction field. During training, the attribute will be changed to 'none' so that we get the loss for individual observations. See, for example, the documentation for the reduction argument in torch::nn_cross_entropy_loss().
Details
It should be used together with luz_callback_mixup().
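A minimal usage sketch, assuming the torch and luz packages are installed, might look like the following. The toy network, the random data, and the choice of alpha = 0.4 are hypothetical and only serve to show how the loss and the callback are paired.

```r
library(torch)
library(luz)

# Hypothetical toy classifier: 10 input features, 3 classes.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 3)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Made-up data; class labels are 1-based, as torch for R expects.
x <- torch_randn(64, 10)
y <- torch_tensor(sample(1:3, 64, replace = TRUE), dtype = torch_long())
dl <- dataloader(tensor_dataset(x, y), batch_size = 16)

# Wrap the underlying loss so that, during training, it can combine
# the per-item losses for the two targets produced by the callback.
model <- setup(
  net,
  loss = nn_mixup_loss(nn_cross_entropy_loss()),
  optimizer = optim_adam
)

fitted <- fit(
  model,
  dl,
  epochs = 1,
  callbacks = list(luz_callback_mixup(alpha = 0.4)),
  verbose = FALSE
)
```

The callback prepares the mixed inputs together with the two targets and weights, and the wrapped loss consumes them, so neither piece is expected to be useful on its own.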