Computes the hierarchy-constrained loss for multi-label classification. Enforces that if a class is predicted positive, all its ancestors must also be positive, using the ancestor matrix R.
Usage
nnf_mc_loss(
output,
target,
R,
to_eval = NULL,
criterion = nnf_binary_cross_entropy_with_logits
)
Arguments
- output
A torch_tensor of raw network outputs (pre-sigmoid), shape (batch_size, n_classes).
- target
Binary target labels, shape (batch_size, n_classes).
- R
Ancestor matrix tensor of shape (1, n_classes, n_classes), where R[1, i, j] = 1 iff class i is a descendant of class j.
- to_eval
Optional logical tensor of shape (n_classes,) indicating which classes to include in the loss computation. If NULL, all classes are evaluated.
- criterion
Loss function to apply after constraint propagation. Default: nnf_binary_cross_entropy_with_logits (expects raw logits).
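As an illustration of the layout described for the R argument, the following base-R sketch builds an ancestor matrix for a small made-up hierarchy. The parent vector, the variable names, and the choice to set the diagonal to 1 (treating each class as its own descendant) are assumptions for this example, not something the function is documented to require.

```r
# Hypothetical hierarchy, given as the parent index of each class (NA marks a root):
# class 1 is the root, classes 2 and 3 are children of 1, class 4 is a child of 2.
parent <- c(NA, 1, 1, 2)
n_classes <- length(parent)

# anc[i, j] = 1 iff class i is a descendant of class j.
# Assumption: the diagonal is 1, i.e. each class counts as its own descendant.
anc <- diag(n_classes)
for (i in seq_len(n_classes)) {
  p <- parent[i]
  while (!is.na(p)) {      # walk up from class i to the root
    anc[i, p] <- 1
    p <- parent[p]
  }
}

# Add the leading dimension so the shape is (1, n_classes, n_classes).
R_array <- array(anc, dim = c(1, n_classes, n_classes))

# With the torch package attached, torch_tensor(R_array) would convert this
# array to a tensor of the same shape.
```

Here R_array[1, 4, 2] is 1 because class 4 descends from class 2, while R_array[1, 2, 4] is 0.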
Value
A scalar torch_tensor containing the computed loss, or a tensor of shape (batch_size, n_classes) if the criterion is applied with reduction = "none".
Details
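The loss combines constrained outputs differently for positive and negative labels:
- For positive labels: the predictions are first multiplied by the targets and then propagated through R, so only classes labelled positive contribute to the constrained output.
- For negative labels: the raw predictions are propagated through R, which penalizes predictions that violate the ancestor constraint.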
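The two branches above can be sketched in base R on a single example. This is an illustration of the idea, not the package's implementation. It makes several assumptions: scores are shown as probabilities in [0, 1] rather than raw logits, the constrained score of a class is taken as the maximum over the class and its descendants, and the branches are merged as (1 - target) * A + target * B. The names constrain, anc, x and y are hypothetical.

```r
# Hypothetical 3-class chain: class 3 is a child of class 2, class 2 a child of class 1.
# anc[i, j] = 1 iff class i is a descendant of class j (diagonal set to 1 by assumption).
anc <- matrix(c(1, 0, 0,
                1, 1, 0,
                1, 1, 1), nrow = 3, byrow = TRUE)

# Constrained score of class j: maximum score over class j and its descendants.
constrain <- function(scores, anc) {
  out <- scores
  for (j in seq_len(ncol(scores))) {
    members <- which(anc[, j] == 1)
    out[, j] <- apply(scores[, members, drop = FALSE], 1, max)
  }
  out
}

x <- matrix(c(0.2, 0.4, 0.9), nrow = 1)  # predicted probabilities, one sample
y <- matrix(c(1, 1, 0), nrow = 1)        # binary targets

neg_branch <- constrain(x, anc)          # constrained raw predictions
pos_branch <- constrain(y * x, anc)      # constrained label-weighted predictions

# Positive labels take the label-weighted branch, negative labels the raw branch.
combined <- (1 - y) * neg_branch + y * pos_branch
# combined is 0.4, 0.4, 0.9

# Binary cross-entropy on probabilities, averaged over classes.
loss <- -mean(y * log(combined) + (1 - y) * log(1 - combined))
```

Class 3 is labelled negative but scores 0.9, so it is penalized directly. Classes 1 and 2 are labelled positive, and the high score of their negative descendant does not count in their favour: their constrained score stays at 0.4.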