Pretrain the TabNet: Attentive Interpretable Tabular Learning model on the predictor data exclusively (unsupervised training).
tabnet_pretrain(x, ...)
# Default S3 method
tabnet_pretrain(x, ...)
# S3 method for class 'data.frame'
tabnet_pretrain(
x,
y,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
# S3 method for class 'formula'
tabnet_pretrain(
formula,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
# S3 method for class 'recipe'
tabnet_pretrain(
x,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
# S3 method for class 'Node'
tabnet_pretrain(
x,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
Depending on the context:
A data frame of predictors.
A matrix of predictors.
A recipe specifying a set of preprocessing steps created from recipes::recipe().
A Node where tree leaves will be left out, and attributes will be used as predictors.
The predictor data should be standardized (e.g. centered or scaled). The model handles categorical predictors and missing values internally, so neither needs any preprocessing.
Model hyperparameters, passed through the ... argument.
Any hyperparameters set here override those set by the config argument.
See tabnet_config() for a list of all possible hyperparameters.
(optional) When x is a data frame or matrix, y is the outcome.
A pretrained tabnet_model object to continue the fitting on.
If NULL (the default), a brand new model is initialized.
A set of hyperparameters created using the tabnet_config function.
If no argument is supplied, this will use the default values in tabnet_config().
When a tabnet_model is provided, restore the network weights from a specific epoch.
Default is the last available checkpoint for a restored model, or the last epoch for an in-memory model.
A formula specifying the outcome terms on the left-hand side, and the predictor terms on the right-hand side.
When a recipe or formula is used, data is specified as a data frame containing both the predictors and the outcome.
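As an illustration of how config and ... interact, the following minimal sketch builds a configuration with tabnet_config() and then overrides one of its values in the call itself. It assumes the tabnet and modeldata packages are installed. The values chosen for epochs and batch_size are arbitrary and serve only as a demonstration.

```r
library(tabnet)

data("ames", package = "modeldata")

# Baseline hyperparameters; the values are arbitrary, for illustration only.
config <- tabnet_config(epochs = 5, batch_size = 512)

# epochs = 1 is passed through `...`, so it updates the epochs value held in
# `config`; batch_size is taken from `config` unchanged.
pretrained <- tabnet_pretrain(
  Sale_Price ~ .,
  data = ames,
  config = config,
  epochs = 1
)
```

Keeping shared settings in one config object and passing only the per-run differences through ... is one convenient way to compare runs.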
A TabNet model object. It can be used for serialization, predictions, or further fitting.
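As a hedged sketch of the "further fitting" use, the pretrained object can be handed to tabnet_fit() through its tabnet_model argument so that supervised training starts from the pretrained weights. The example assumes the tabnet and modeldata packages are installed and uses a single epoch only to keep the run short. The object names are illustrative.

```r
library(tabnet)

data("ames", package = "modeldata")

# Unsupervised step: Sale_Price appears in the formula but is ignored.
pretrained <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1)

# Supervised step: tabnet_fit() starts from the pretrained model.
fit <- tabnet_fit(
  Sale_Price ~ .,
  data = ames,
  tabnet_model = pretrained,
  epochs = 1
)

# Predict with the fine-tuned model on a few rows.
predict(fit, ames[1:5, ])
```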
Outcome values are accepted here only for syntax consistency with tabnet_fit, but by design the outcome, if present, is ignored during pretraining.
When a parent tabnet_model is provided, pretraining resumes from that model's weights at the following epoch:
the last pretrained epoch, for a model already in the torch context
the last model checkpoint epoch, for a model loaded from file
the epoch of the checkpoint matching or preceding the from_epoch value, if provided
The pretraining metrics are appended to the parent model's metrics in the returned TabNet model.
TabNet uses torch as its backend for computation, and torch uses all available threads by default.
You can control the number of threads used by torch with torch::torch_set_num_threads() and torch::torch_set_num_interop_threads().
data("ames", package = "modeldata")
pretrained <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1)