Create predictions for a fitted model

Usage

# S3 method for class 'luz_module_fitted'
predict(
  object,
  newdata,
  ...,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)

Arguments

object

(fitted model) The fitted model object returned from fit.luz_module_generator().

newdata

(dataloader, dataset, list or array) The data to generate predictions for. It must return a list with at least 1 element. The other elements aren't used.

...

Currently unused.

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during prediction. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.

accelerator

(accelerator, optional) An optional accelerator() object used to configure device placement of the components like nn_modules, optimizers and batches of data.

verbose

(logical, optional) Whether the procedure should emit output to the console. By default, output is produced if interactive() is TRUE; otherwise nothing is printed to the console.

dataloader_options

Options used when creating a dataloader. See torch::dataloader(). By default, shuffle = TRUE for the training data and batch_size = 32. An error is raised if this is not NULL and newdata is already a dataloader.
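
The following is a minimal sketch of how the arguments above might fit together, not an official example. The toy module `net`, the random tensors `x` and `y`, and the chosen batch size are all hypothetical and exist only for illustration. It assumes the torch and luz packages are installed, and that a plain list of tensors is accepted both as training data and as newdata.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer mapping 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Random data used purely for illustration.
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Configure and fit the model for one epoch, without console output.
model <- setup(net, loss = nnf_mse_loss, optimizer = optim_adam)
fitted <- fit(model, list(x, y), epochs = 1, verbose = FALSE)

# newdata is a list, so dataloader_options may be supplied.
# Passing dataloader_options together with a ready-made dataloader would error.
preds <- predict(
  fitted,
  list(x),
  verbose = FALSE,
  dataloader_options = list(batch_size = 16)
)

dim(preds)
```

Passing the target as well, as in `predict(fitted, list(x, y))`, should behave the same way, since elements of newdata beyond the first are not used.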

See also

Other training: evaluate(), fit.luz_module_generator(), setup()