Training

setup()
Sets up a nn_module to use with luz
fit(<luz_module_generator>)
Fit a nn_module
predict(<luz_module_fitted>)
Create predictions for a fitted model
evaluate()
Evaluates a fitted model on a dataset
set_hparams()
Set hyper-parameters of a module
set_opt_hparams()
Set optimizer hyper-parameters
get_metrics()
Get metrics from the object
ctx
Context object
context
Context object
lr_finder()
Learning Rate Finder
as_dataloader()
Creates a dataloader from its input

Metrics

luz_metric()
Creates a new luz metric
luz_metric_accuracy()
Accuracy
luz_metric_binary_accuracy()
Binary accuracy
luz_metric_binary_accuracy_with_logits()
Binary accuracy with logits
luz_metric_binary_auroc()
Computes the area under the ROC curve
luz_metric_mae()
Mean absolute error
luz_metric_mse()
Mean squared error
luz_metric_multiclass_auroc()
Computes the multi-class AUROC
luz_metric_rmse()
Root mean squared error

Misc

nn_mixup_loss()
Loss to be used with luz_callback_mixup().
nnf_mixup()
Mixup logic

Callbacks

luz_callback()
Create a new callback
luz_callback_csv_logger()
CSV logger callback
luz_callback_early_stopping()
Early stopping callback
luz_callback_gradient_clip()
Gradient clipping callback
luz_callback_interrupt()
Interrupt callback
luz_callback_keep_best_model()
Keep the best model
luz_callback_lr_scheduler()
Learning rate scheduler callback
luz_callback_metrics()
Metrics callback
luz_callback_mixup()
Mixup callback
luz_callback_model_checkpoint()
Checkpoints model weights
luz_callback_profile()
Profile callback
luz_callback_progress()
Progress callback
luz_callback_train_valid()
Train-eval callback

Accelerator

accelerator()
Create an accelerator

Serialization

luz_save()
Saves luz objects to disk
luz_load()
Load trained model
luz_load_model_weights() luz_save_model_weights()
Saves model weights to disk, or loads them into a fitted object.