
Fit an nn_module

Usage

# S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)

Arguments

object

An nn_module that has been passed through setup().

data

(dataloader, dataset or list) The training data: a dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list. Dataloaders and datasets must return a list with at most 2 items: the first item is used as input to the module and the second as the target for the loss function.
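A minimal sketch of the three accepted forms of data, assuming the torch and luz packages are installed and R >= 4.1 (for the native pipe `|>`). The toy regression module `net` and the random tensors `x` and `y` are hypothetical and exist only for illustration.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module: one linear layer mapping 10 features to 1 output
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# setup() attaches the loss and optimizer; fit() instantiates the module
model <- net |> setup(loss = nn_mse_loss(), optimizer = optim_adam)

# Random inputs and targets, for illustration only
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# 1. A list: the first element is the input, the second the target
fit_list <- fit(model, list(x, y), epochs = 1, verbose = FALSE)

# 2. A dataset whose items are lists of (input, target)
ds <- tensor_dataset(x, y)
fit_ds <- fit(model, ds, epochs = 1, verbose = FALSE)

# 3. A ready-made dataloader (batching and shuffling are configured here)
dl <- dataloader(ds, batch_size = 20, shuffle = TRUE)
fit_dl <- fit(model, dl, epochs = 1, verbose = FALSE)
```

With a list or a dataset, luz builds the dataloader itself; with a dataloader, batch size and shuffling are whatever the dataloader already uses.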

epochs

(int) The number of epochs to train the model for. If a single value is provided, it is taken as max_epochs and min_epochs is set to 0. If a vector of two numbers is provided, the first value is min_epochs and the second is max_epochs. Both values are available in the context object as ctx$min_epochs and ctx$max_epochs, respectively.
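A sketch of the two-value form of epochs, under the same assumptions as a toy setup (torch and luz installed, hypothetical module `net` and random data). It uses get_metrics() to count how many epochs were actually recorded; with no callback stopping training early, training should run until max_epochs.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# epochs = c(2, 4): min_epochs = 2, max_epochs = 4
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = c(2, 4), verbose = FALSE)

# get_metrics() returns a data.frame with columns set, metric, epoch and value
n_epochs <- max(get_metrics(fitted)$epoch)
```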

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during the training procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added.
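A sketch of passing a callback, assuming luz's built-in luz_callback_early_stopping() and a hypothetical toy module and dataset. Early stopping here monitors the validation loss, so valid_data is supplied as well; training may stop before the 50-epoch maximum.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(
    list(x, y),
    epochs = 50,
    valid_data = 0.2,
    # Stop once valid_loss has not improved for 2 consecutive epochs
    callbacks = list(
      luz_callback_early_stopping(monitor = "valid_loss", patience = 2)
    ),
    verbose = FALSE
  )

n_epochs <- max(get_metrics(fitted)$epoch)
```

The default callbacks listed above are added on top of whatever is passed in callbacks, so they do not need to be listed.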

valid_data

(dataloader, dataset, list or scalar value; optional) A dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list, used during the validation procedure. It must return a list with (input, target). If data is a torch dataset or a list, you can instead supply a numeric value between 0 and 1; in that case, a random sample of data of that proportion is used for validation.
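A sketch of both ways to supply validation data, assuming a hypothetical toy module and random tensors (`x_val` and `y_val` are likewise made up). The first fit holds out a proportion of data; the second passes a separate dataset. In both cases get_metrics() should report a "valid" set alongside "train".

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)
model <- net |> setup(loss = nn_mse_loss(), optimizer = optim_adam)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# A proportion: 25% of the list data is sampled for validation
fit_prop <- fit(model, list(x, y), epochs = 2, valid_data = 0.25, verbose = FALSE)

# A separate validation dataset returning (input, target)
x_val <- torch_randn(30, 10)
y_val <- torch_randn(30, 1)
fit_sep <- fit(
  model, list(x, y),
  epochs = 2,
  valid_data = tensor_dataset(x_val, y_val),
  verbose = FALSE
)

sets_prop <- unique(get_metrics(fit_prop)$set)
sets_sep <- unique(get_metrics(fit_sep)$set)
```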

accelerator

(accelerator, optional) An accelerator() object used to configure device placement of components such as nn_modules, optimizers and batches of data.
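A sketch of forcing training onto the CPU with accelerator(cpu = TRUE), assuming a hypothetical toy module and random data. It then checks the device of the first parameter of the trained module, accessed through the fitted object's model element.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# cpu = TRUE keeps the module, optimizer and batches on the CPU even if a GPU exists
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 1, accelerator = accelerator(cpu = TRUE), verbose = FALSE)

device_type <- fitted$model$parameters[[1]]$device$type
```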

verbose

(logical, optional) Whether the fitting procedure should print output to the console during training. By default, output is printed if interactive() is TRUE and suppressed otherwise.

...

Currently unused.

dataloader_options

Options used when creating a dataloader; see torch::dataloader(). By default, batch_size = 32, and shuffle = TRUE for the training data. An error is raised if this argument is not NULL and data is already a dataloader.
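A sketch of overriding the default dataloader options, assuming a hypothetical toy module and random data. The second call illustrates the documented error: dataloader_options cannot be combined with a data argument that is already a dataloader.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)
model <- net |> setup(loss = nn_mse_loss(), optimizer = optim_adam)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Override the defaults (batch_size = 32, shuffle = TRUE) for list data
fitted <- fit(
  model, list(x, y),
  epochs = 1,
  verbose = FALSE,
  dataloader_options = list(batch_size = 10, shuffle = FALSE)
)

# A ready-made dataloader plus dataloader_options should raise an error
dl <- dataloader(tensor_dataset(x, y), batch_size = 10)
res <- tryCatch(
  fit(model, dl, epochs = 1, verbose = FALSE,
      dataloader_options = list(batch_size = 10)),
  error = function(e) e
)
```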

Value

A fitted object that can be saved with luz_save(), printed with print() and plotted with plot().
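A sketch of working with the returned object, assuming a hypothetical toy module and random data. It prints the fitted object, saves it with luz_save() to a temporary file, reloads it with luz_load(), and checks that both copies give the same predictions from predict(). plot(fitted) is left out here because it needs a graphics device.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy module
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 2, verbose = FALSE)

print(fitted)

# Round-trip through disk
path <- tempfile(fileext = ".luz")
luz_save(fitted, path)
reloaded <- luz_load(path)

# Predictions from the original and the reloaded object
preds <- predict(fitted, list(x, y), verbose = FALSE)
preds_reloaded <- predict(reloaded, list(x, y), verbose = FALSE)
```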

See also

predict.luz_module_fitted() for how to create predictions. setup() for how to create modules that can be trained with fit().

Other training: evaluate(), predict.luz_module_fitted(), setup()