
Evaluates a fitted model on a dataset

Usage

evaluate(
  object,
  data,
  ...,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)

Arguments

object

A fitted model to evaluate.

data

(dataloader, dataset or list) The data to evaluate the model on: a dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list such as list(x, y). Dataloaders and datasets must return a list with at most 2 items. The first item is used as the input to the module and the second as the target for the loss function.
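The three accepted forms of data can be sketched with torch alone. This is a minimal sketch that assumes random example tensors; paired_dataset is a hypothetical name used only for illustration, not part of luz or torch. Each form yields (input, target) pairs, and any of data_list, ds or dl could be passed as the data argument.

```r
library(torch)

# Hypothetical example data: 100 observations, 10 features, 1 target column.
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# 1. A plain list: the first element is the input, the second the target.
data_list <- list(x, y)

# 2. A dataset whose .getitem() returns a list of (input, target).
#    paired_dataset is a hypothetical, illustrative dataset generator.
paired_dataset <- dataset(
  name = "paired_dataset",
  initialize = function(x, y) {
    self$x <- x
    self$y <- y
  },
  .getitem = function(i) {
    # Two items: input for the module, target for the loss function.
    list(self$x[i, ], self$y[i, ])
  },
  .length = function() {
    self$x$size(1)
  }
)
ds <- paired_dataset(x, y)

# 3. A dataloader over the dataset; each batch is a list of two tensors
#    stacked along the first (batch) dimension.
dl <- dataloader(ds, batch_size = 32)
batch <- dataloader_next(dataloader_make_iter(dl))
```

For simple in-memory tensors, torch::tensor_dataset(x, y) builds an equivalent dataset in a single call.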

...

Currently unused.

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during the evaluation procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.

accelerator

(accelerator, optional) An optional accelerator() object used to configure device placement of components such as nn_modules, optimizers and batches of data.

verbose

(logical, optional) An optional boolean value indicating whether the evaluation procedure should emit output to the console. By default, it produces output if interactive() is TRUE and prints nothing otherwise.

dataloader_options

(list, optional) Options used when creating a dataloader from data. See torch::dataloader(). By default, batch_size = 32 (and shuffle = TRUE for training data). An error is raised if this is not NULL and data is already a dataloader.

Details

Once a model has been trained, you might want to evaluate its performance on a different dataset. For that reason, luz provides the evaluate() function, which takes a fitted model and a dataset and computes the metrics attached to the model.

evaluate() returns a luz_module_evaluation object that you can query for metrics with get_metrics(), or simply print to see the results.

For example, where fitted is a model returned by fit() and valid_dl is a validation dataloader:

evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)

## A `luz_module_evaluation`
## -- Results ---------------------------------------------------------------------
## loss: 1.8892
## mae: 1.0522
## mse: 1.645
## rmse: 1.2826
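The example above relies on objects created elsewhere. A self-contained version might look like the following minimal sketch, assuming random regression data and an nn_linear model fitted with the mae, mse and rmse metrics that appear in the printed output. The data, model and hyperparameters are illustrative only, so the resulting numbers will differ from those shown above. It also exercises dataloader_options with a list, and the documented error when options are combined with an existing dataloader.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Illustrative regression data: 1000 training and 200 validation samples.
x_train <- torch_randn(1000, 10)
y_train <- torch_randn(1000, 1)
x_valid <- torch_randn(200, 10)
y_valid <- torch_randn(200, 1)

# Fit a linear model with the same metrics as in the printed output above.
fitted <- nn_linear %>%
  setup(
    loss = nnf_mse_loss,
    optimizer = optim_adam,
    metrics = list(luz_metric_mae(), luz_metric_mse(), luz_metric_rmse())
  ) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(list(x_train, y_train), epochs = 2, verbose = FALSE)

# Evaluate on a list; dataloader_options is used when luz builds the
# dataloader from it.
evaluation <- fitted %>%
  evaluate(
    data = list(x_valid, y_valid),
    dataloader_options = list(batch_size = 50),
    verbose = FALSE
  )

# Evaluate on an existing dataloader; dataloader_options must stay NULL here.
valid_dl <- dataloader(tensor_dataset(x_valid, y_valid), batch_size = 50)
evaluation_dl <- fitted %>% evaluate(data = valid_dl, verbose = FALSE)

# Passing dataloader_options together with a dataloader is documented to error.
err <- tryCatch(
  evaluate(fitted, data = valid_dl,
           dataloader_options = list(batch_size = 10), verbose = FALSE),
  error = function(e) e
)

metrics <- get_metrics(evaluation)
print(metrics)
```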

See also