Computes accuracy for multi-class classification problems.
Details
This metric expects logits or probabilities at every update. It then takes the argmax along the class dimension (the second dimension, giving one predicted class per row) and compares the predicted class indices with the target.
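The argmax-then-compare step can be sketched in base R without torch. This is a minimal illustration, not luz's implementation. `argmax_accuracy()` and `softmax_rows()` are hypothetical helpers, and `ties.method = "first"` is an assumption (torch's tie-breaking may differ). The sketch also shows why logits and probabilities give the same result: softmax is monotonic within each row, so it leaves the argmax unchanged.

```r
# Hypothetical helper: argmax-based accuracy in base R (no torch needed).
# preds:  n x k matrix of logits or probabilities, one row per observation
# target: integer vector of length n holding 1-based class indices
argmax_accuracy <- function(preds, target) {
  # max.col() returns, for each row, the column index of its largest value
  predicted <- max.col(preds, ties.method = "first")
  mean(predicted == target)
}

# Hypothetical helper: row-wise softmax, turning logits into probabilities
softmax_rows <- function(x) {
  # subtract each row's max for numerical stability (recycles down rows)
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

logits <- matrix(c( 2.0, 0.1, -1.0,
                   -0.5, 1.5,  0.3,
                    0.2, 0.4,  3.0,
                    1.0, 2.0,  0.0), nrow = 4, byrow = TRUE)
target <- c(1L, 2L, 3L, 3L)  # the last row is predicted as class 2, so it misses

argmax_accuracy(logits, target)                # 0.75
argmax_accuracy(softmax_rows(logits), target)  # 0.75, same argmax per row
```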
See also
Other luz_metrics: luz_metric_binary_accuracy_with_logits(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse(), luz_metric()
Examples
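if (torch::torch_is_installed()) {
  library(torch)
  library(luz)
  metric <- luz_metric_accuracy()
  metric <- metric$new()  # instantiate the metric object
  # 100 observations x 10 classes of random logits, plus random integer targets
  metric$update(torch_randn(100, 10), torch_randint(1, 10, size = 100))
  metric$compute()  # inputs are random, so the value changes between runs
}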
#> [1] 0.08
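Because `update()` can be called once per batch before a single `compute()`, it helps to see how per-batch results combine. The base R sketch below mimics that update/compute pattern under a stated assumption: correct predictions and observation counts are pooled across updates, so `compute()` returns total correct divided by total observations rather than an average of per-batch accuracies. `make_accuracy_accumulator()` is a hypothetical name, not part of luz.

```r
# Hypothetical stateful accumulator mirroring the update()/compute() pattern.
# Assumption: counts are pooled across updates, not averaged per batch.
make_accuracy_accumulator <- function() {
  correct <- 0
  total <- 0
  list(
    update = function(preds, target) {
      # one predicted class per row: the column holding the row's maximum
      predicted <- max.col(preds, ties.method = "first")
      correct <<- correct + sum(predicted == target)
      total <<- total + length(target)
      invisible(NULL)
    },
    compute = function() correct / total
  )
}

acc <- make_accuracy_accumulator()
# batch 1: 3 observations, 2 predicted correctly
acc$update(matrix(c(0.9, 0.1,
                    0.2, 0.8,
                    0.6, 0.4), ncol = 2, byrow = TRUE), c(1L, 2L, 2L))
# batch 2: 1 observation, predicted correctly
acc$update(matrix(c(0.3, 0.7), ncol = 2, byrow = TRUE), 2L)
acc$compute()  # 0.75 (3 of 4), not the mean of batch accuracies (5/6)
```

Pooling counts weights every observation equally, so a small final batch does not distort the overall accuracy the way averaging batch-level accuracies would.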