Training

setup()

Sets up an nn_module to use with luz

fit(<luz_module_generator>)

Fit an nn_module

predict(<luz_module_fitted>)

Create predictions for a fitted model

evaluate()

Evaluates a fitted model on a dataset

set_hparams()

Set hyper-parameters of a module

set_opt_hparams()

Set optimizer hyper-parameters

get_metrics()

Get metrics from the object

ctx

Context object

context

Context object

lr_finder()

Learning Rate Finder

as_dataloader()

Creates a dataloader from its input
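The training functions above are usually chained. Below is a minimal sketch that assumes the `torch` package (with its libtorch backend) and `luz` are installed. The toy data, the two-layer `net` module, the `hidden` hyper-parameter and values such as `epochs = 3` and `lr = 0.01` are all made up for illustration. The native pipe `|>` (R >= 4.1) is used instead of `%>%`.

```r
library(torch)
library(luz)

# Hypothetical toy regression data: 100 samples, 10 features, 1 target
torch_manual_seed(1)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Illustrative two-layer module; `hidden` is supplied through set_hparams()
net <- nn_module(
  initialize = function(hidden) {
    self$fc1 <- nn_linear(10, hidden)
    self$fc2 <- nn_linear(hidden, 1)
  },
  forward = function(x) {
    self$fc2(nnf_relu(self$fc1(x)))
  }
)

# setup() attaches the loss and optimizer; set_hparams() and set_opt_hparams()
# store arguments for the module constructor and the optimizer
model <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  set_hparams(hidden = 16) |>
  set_opt_hparams(lr = 0.01)

# Optional: record the loss over a range of learning rates before training
lr_records <- lr_finder(model, list(x, y), steps = 20, start_lr = 1e-4, end_lr = 1)

# as_dataloader() converts the list of tensors into a dataloader
train_dl <- as_dataloader(list(x, y), batch_size = 16)
fitted <- fit(model, train_dl, epochs = 3, verbose = FALSE)

# Fixed-order dataloader (torch's dataloader() does not shuffle by default)
eval_dl <- dataloader(tensor_dataset(x, y), batch_size = 32)
preds <- predict(fitted, eval_dl)          # one prediction per row of x
evaluation <- evaluate(fitted, eval_dl)

train_metrics <- get_metrics(fitted)       # per-epoch training metrics
eval_metrics <- get_metrics(evaluation)    # metrics computed on eval_dl
```

Passing the dataloader with targets to `predict()` assumes luz only uses the first element of each batch as model input, which is how the training loop treats batches.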

Metrics

luz_metric()

Creates a new luz metric

luz_metric_accuracy()

Accuracy

luz_metric_binary_accuracy()

Binary accuracy

luz_metric_binary_accuracy_with_logits()

Binary accuracy with logits

luz_metric_binary_auroc()

Computes the area under the ROC curve

luz_metric_mae()

Mean absolute error

luz_metric_mse()

Mean squared error

luz_metric_multiclass_auroc()

Computes the multi-class AUROC

luz_metric_rmse()

Root mean squared error
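Metrics are passed to `setup()` as a list of metric objects, and `luz_metric()` builds new ones from `initialize()`, `update()` and `compute()` methods. This is a sketch under those assumptions. The `luz_metric_max_ae` metric ("largest absolute error seen during an epoch") is a hypothetical example, not part of luz. Using `nn_linear` directly as the module is also only for brevity.

```r
library(torch)
library(luz)

# Hypothetical custom metric: the largest absolute error seen in an epoch
luz_metric_max_ae <- luz_metric(
  abbrev = "MaxAE",
  initialize = function() {
    self$max_err <- 0
  },
  update = function(preds, target) {
    # Largest absolute error in this batch, as a plain R number
    batch_max <- (preds - target)$abs()$max()$item()
    self$max_err <- max(self$max_err, batch_max)
  },
  compute = function() {
    self$max_err
  }
)

torch_manual_seed(1)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Built-in and custom metrics are supplied as a list of metric objects
fitted <- nn_linear |>
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_adam,
    metrics = list(luz_metric_mae(), luz_metric_rmse(), luz_metric_max_ae())
  ) |>
  set_hparams(in_features = 10, out_features = 1) |>
  fit(list(x, y), epochs = 2, verbose = FALSE)

metrics <- get_metrics(fitted)
```

Accumulating state in `update()` and reducing it in `compute()` lets the metric cover the whole epoch instead of averaging per-batch values.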

Callbacks

luz_callback()

Create a new callback

luz_callback_csv_logger()

CSV logger callback

luz_callback_early_stopping()

Early stopping callback

luz_callback_interrupt()

Interrupt callback

luz_callback_lr_scheduler()

Learning rate scheduler callback

luz_callback_metrics()

Metrics callback

luz_callback_model_checkpoint()

Checkpoints model weights

luz_callback_profile()

Profile callback

luz_callback_progress()

Progress callback

luz_callback_train_valid()

Train-eval callback
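Callbacks are passed to `fit()` as a list. This sketch assumes luz and torch are installed and shows several built-in callbacks alongside a custom one made with `luz_callback()`. The `record_epochs` callback and the `epoch_log` environment are hypothetical. Values such as `patience = 2` and `step_size = 3` are arbitrary. `lr_step` is torch's step-decay scheduler.

```r
library(torch)
library(luz)

torch_manual_seed(1)
x_train <- torch_randn(100, 10)
y_train <- torch_randn(100, 1)
x_valid <- torch_randn(50, 10)
y_valid <- torch_randn(50, 1)

# Hypothetical custom callback that records each finished epoch number;
# `ctx` is the luz context object available inside callback methods
epoch_log <- new.env()
epoch_log$epochs <- integer()
record_epochs <- luz_callback(
  name = "record_epochs",
  on_epoch_end = function() {
    epoch_log$epochs <- c(epoch_log$epochs, ctx$epoch)
  }
)

csv_path <- tempfile(fileext = ".csv")

fitted <- nn_linear |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  set_hparams(in_features = 10, out_features = 1) |>
  fit(
    list(x_train, y_train),
    epochs = 10,
    valid_data = list(x_valid, y_valid),
    verbose = FALSE,
    callbacks = list(
      # Stop when valid_loss has not improved for 2 epochs
      luz_callback_early_stopping(monitor = "valid_loss", patience = 2),
      # Extra arguments are forwarded to torch::lr_step()
      luz_callback_lr_scheduler(lr_step, step_size = 3),
      # Write per-epoch metrics to a CSV file
      luz_callback_csv_logger(csv_path),
      record_epochs()
    )
  )

history <- read.csv(csv_path)
```

Because of early stopping, `epoch_log$epochs` may contain fewer than 10 entries.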

Accelerator

accelerator()

Create an accelerator
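An accelerator controls device placement during `fit()`. This short sketch assumes luz and torch are installed and uses `accelerator(cpu = TRUE)` to force training on the CPU even when a GPU is present. The data and the `nn_linear` model are placeholders. Checking the device through `fitted$model$weight` assumes the fitted object exposes the trained module as `$model`.

```r
library(torch)
library(luz)

torch_manual_seed(1)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Force CPU training; omit the accelerator argument to let luz choose
fitted <- nn_linear |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  set_hparams(in_features = 10, out_features = 1) |>
  fit(list(x, y), epochs = 2, accelerator = accelerator(cpu = TRUE), verbose = FALSE)
```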

Serialization

luz_save()

Saves luz objects to disk

luz_load()

Load trained model

luz_load_model_weights(), luz_save_model_weights()

Save model weights from, or load them into, a fitted object.
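A sketch contrasting the two serialization styles, assuming luz and torch are installed. `luz_save()`/`luz_load()` round-trip the whole fitted object. `luz_save_model_weights()`/`luz_load_model_weights()` copy only the weights into another fitted object of the same architecture. The helper `train_linear()`, the seeds and the temporary file names are hypothetical.

```r
library(torch)
library(luz)

torch_manual_seed(1)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)
# Fixed-order dataloader so repeated predictions line up row by row
pred_dl <- dataloader(tensor_dataset(x, y), batch_size = 32)

# Hypothetical helper: fit a linear model starting from a given seed
train_linear <- function(seed) {
  torch_manual_seed(seed)
  nn_linear |>
    setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
    set_hparams(in_features = 10, out_features = 1) |>
    fit(list(x, y), epochs = 2, verbose = FALSE)
}
fitted <- train_linear(1)

# Whole fitted object
model_path <- tempfile(fileext = ".luz")
luz_save(fitted, model_path)
reloaded <- luz_load(model_path)

# Weights only: overwrite the weights of a differently trained model
weights_path <- tempfile(fileext = ".pt")
luz_save_model_weights(fitted, weights_path)
other <- train_linear(2)
luz_load_model_weights(other, weights_path)

p_orig <- predict(fitted, pred_dl)
p_reloaded <- predict(reloaded, pred_dl)
p_other <- predict(other, pred_dl)
```

After loading, both `reloaded` and `other` should reproduce the original predictions.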
