Computes accuracy for binary classification problems where the model
returns logits. Commonly used together with torch::nn_bce_with_logits_loss().
Details
Probabilities are generated from the logits using torch::nnf_sigmoid(), and the
threshold argument is used to classify each prediction as 0 or 1.
See also
Other luz_metrics:
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
  metric <- metric$new()
  # torch_randint()'s upper bound is exclusive: use 2 so targets are drawn
  # from {0, 1}. With an upper bound of 1, every target would be 0.
  metric$update(torch_randn(100), torch::torch_randint(0, 2, size = 100))
  metric$compute()
}
#> [1] 0.43