Fit an nn_module
Usage
# S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)
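As a minimal sketch of the call above, an nn_module is first passed through setup() and the result is then fitted. This assumes the torch and luz packages are installed. The single linear layer and the random tensors are illustrative assumptions, not taken from this page.

```r
library(torch)
library(luz)

# A small regression module; the layer sizes are arbitrary choices for illustration.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Synthetic in-memory data: 100 samples with 10 features and one target each.
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# setup() attaches a loss and an optimizer, producing an object that fit() accepts.
model <- setup(net, loss = nn_mse_loss(), optimizer = optim_adam)

# data is a list of (input, target); train for 3 epochs without console output.
fitted <- fit(model, list(x, y), epochs = 3, verbose = FALSE)
```

Here the optimizer is passed as a generator (optim_adam) rather than an instance, so that luz can create it from the module's parameters.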
Arguments
- object
An nn_module that has been set up with setup().
- data
(dataloader, dataset or list) The data used for training the model: a dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list. Dataloaders and datasets must return a list with at most 2 items: the first is used as input to the module and the second as the target for the loss function.
- epochs
(int) The number of epochs for training the model. If a single value is provided, it is taken as max_epochs and min_epochs is set to 0. If a vector of two numbers is provided, the first value is min_epochs and the second is max_epochs. Both values are available in the context object as ctx$min_epochs and ctx$max_epochs, respectively.
- callbacks
(list, optional) A list of callbacks defined with luz_callback() that will be called during training. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.
- valid_data
(dataloader, dataset, list or scalar value; optional) A dataloader created with torch::dataloader() or a dataset created with torch::dataset() that will be used during validation. It must return a list of (input, target). If data is a torch dataset or a list, you can instead supply a numeric value between 0 and 1; in that case, a random sample of data with that proportion is used for validation.
- accelerator
(accelerator, optional) An accelerator() object used to configure device placement of components such as nn_modules, optimizers and batches of data.
- verbose
(logical, optional) Whether the fitting procedure should print output to the console during training. By default, output is printed if interactive() is TRUE and suppressed otherwise.
- ...
Currently unused.
- dataloader_options
Options used when creating a dataloader; see torch::dataloader(). By default, shuffle = TRUE for the training data and batch_size = 32. An error is raised if this is not NULL and data is already a dataloader.
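The sketch below exercises several of these arguments together: a c(min_epochs, max_epochs) range, an extra callback, a proportion for valid_data, an explicit accelerator and custom dataloader_options. It assumes torch and luz are installed. The network, the random data and the specific values (patience, batch size, 20% split) are illustrative assumptions only.

```r
library(torch)
library(luz)

# Hypothetical toy network and synthetic data, used only to show the arguments.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)
x <- torch_randn(200, 10)
y <- torch_randn(200, 1)

model <- setup(net, loss = nn_mse_loss(), optimizer = optim_adam)

fitted <- fit(
  model,
  data = list(x, y),
  # Train for at least 2 and at most 20 epochs.
  epochs = c(2, 20),
  # Added on top of the default metrics, progress and train/valid callbacks:
  # stop when valid_loss has not improved for 3 epochs.
  callbacks = list(
    luz_callback_early_stopping(monitor = "valid_loss", patience = 3)
  ),
  # data is a list, so a proportion can be given: a random 20% is used for validation.
  valid_data = 0.2,
  # Place the module, optimizer and batches on the CPU.
  accelerator = accelerator(cpu = TRUE),
  verbose = FALSE,
  # Passed on when luz builds the dataloader; shuffle keeps its default of TRUE.
  dataloader_options = list(batch_size = 16)
)
```

Because data is a list here rather than a dataloader, supplying dataloader_options is allowed; passing a dataloader together with non-NULL dataloader_options would raise an error.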
Value
A fitted object that can be saved with luz_save(), printed with print() and plotted with plot().
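A minimal sketch of working with the returned object: it is printed, saved with luz_save() and read back with luz_load(). It assumes torch and luz are installed. The toy model, the random data and the temporary file path are illustrative assumptions.

```r
library(torch)
library(luz)

# Hypothetical toy model fitted on synthetic data, just to obtain a fitted object.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)
fitted <- fit(
  setup(net, loss = nn_mse_loss(), optimizer = optim_adam),
  list(x, y),
  epochs = 2,
  verbose = FALSE
)

# Print a summary of the fitted object.
print(fitted)
# plot(fitted) would draw the metrics recorded during training.

# Save to a temporary file and load it back.
path <- tempfile(fileext = ".luz")
luz_save(fitted, path)
reloaded <- luz_load(path)
```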
See also
- predict.luz_module_fitted() for how to create predictions.
- setup() to find out how to create modules that can be trained with fit().
Other training: evaluate(), predict.luz_module_fitted(), setup()