Computes accuracy for multi-class classification problems.
Details
This metric expects logits or probabilities at every update. For each row (observation), it takes the argmax across the columns (classes) and compares the resulting class index with the target.
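To make the update/compute cycle concrete, here is a minimal base-R sketch of the computation described above. It is not luz's implementation. It assumes that rows are observations, columns are classes, targets are 1-based class indices, and that the metric accumulates correct predictions and totals across updates before `compute()` divides them. `accuracy_sketch` is a hypothetical helper name.

```r
# Hypothetical helper, not part of luz: a plain-R sketch of argmax accuracy.
# Assumes rows = observations, columns = classes, 1-based integer targets,
# and that counts accumulate across update() calls.
accuracy_sketch <- function() {
  correct <- 0
  total <- 0
  list(
    update = function(preds, target) {
      # Index of the largest logit/probability in each row
      pred <- apply(preds, 1, which.max)
      correct <<- correct + sum(pred == target)
      total <<- total + length(target)
      invisible(NULL)
    },
    # Fraction of predictions that matched their target so far
    compute = function() correct / total
  )
}

m <- accuracy_sketch()
logits <- rbind(c(2.0, 0.1, -1.0),  # argmax is class 1
                c(0.3, 0.2, 0.9),   # argmax is class 3
                c(-0.5, 1.5, 0.0))  # argmax is class 2
m$update(logits, c(1, 2, 2))        # rows 1 and 3 are correct
m$update(rbind(c(0, 5, 1)), 2)      # a second batch: 1 of 1 correct
m$compute()                         # 3 correct out of 4 observations: 0.75
```

Because counts rather than per-batch accuracies are accumulated, batches of different sizes are weighted by the number of observations they contain.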
Examples
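if (torch::torch_is_installed()) {
library(torch)
library(luz)
# Create the metric and instantiate it
metric <- luz_metric_accuracy()
metric <- metric$new()
# 100 observations with random logits over 10 classes; integer targets drawn from [1, 10)
metric$update(torch_randn(100, 10), torch::torch_randint(1, 10, size = 100))
# Fraction of row-wise argmax predictions that match the targets
metric$compute()
}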
#> [1] 0.11