Fits the TabNet: Attentive Interpretable Tabular Learning model

```
tabnet_fit(x, ...)
# S3 method for default
tabnet_fit(x, ...)
# S3 method for data.frame
tabnet_fit(
x,
y,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
# S3 method for formula
tabnet_fit(
formula,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
# S3 method for recipe
tabnet_fit(
x,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
# S3 method for Node
tabnet_fit(
x,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
```

- x
Depending on the context:

  - A **data frame** of predictors.
  - A **matrix** of predictors.
  - A **recipe** specifying a set of preprocessing steps created from `recipes::recipe()`.
  - A `data.tree` **Node**, for hierarchical classification.

The predictor data should be standardized (e.g., centered or scaled). Categorical predictors are handled internally by the model, so they need no special encoding.

- ...
Model hyperparameters. Any hyperparameters set here override those set by the `config` argument. See `tabnet_config()` for a list of all possible hyperparameters.

- y
When `x` is a **data frame** or **matrix**, `y` is the outcome, specified as:

  - A **data frame** with one or more numeric columns (regression) or one or more categorical columns (classification).
  - A **matrix** with one column.
  - A **vector**, either numeric or categorical.

- tabnet_model
A previously fitted TabNet model object to continue fitting from. If `NULL` (the default), a brand-new model is initialized.

- config
A set of hyperparameters created using the `tabnet_config()` function. If no argument is supplied, the default values of `tabnet_config()` are used.

- from_epoch
When a `tabnet_model` is provided, restore the network weights from a specific epoch. The default is the last available checkpoint for a restored model, or the last epoch for an in-memory model.

- weights
Unused.

- formula
A formula specifying the outcome terms on the left-hand side and the predictor terms on the right-hand side.

- data
When a **recipe** or **formula** is used, `data` is a **data frame** containing both the predictors and the outcome.
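As a sketch of how these arguments combine, the snippet below standardizes the numeric predictors of the `attrition` data frame with base R `scale()`, leaves factor predictors untouched, and lets a hyperparameter passed through `...` override the value stored in `config`. It assumes the `tabnet`, `torch` (with its backend installed) and `modeldata` packages are available; the `batch_size` and `epochs` values are illustrative choices only.

```r
library(tabnet)

# Classification data from modeldata (assumed installed)
data("attrition", package = "modeldata")

# Predictors as a data frame, outcome as a factor vector
attrition_x <- attrition[, setdiff(names(attrition), "Attrition")]
attrition_y <- attrition$Attrition

# Standardize numeric predictors (center and scale); factor predictors
# are left as-is because TabNet handles categorical columns internally
num_cols <- vapply(attrition_x, is.numeric, logical(1))
attrition_x[num_cols] <- lapply(attrition_x[num_cols],
                                function(col) as.numeric(scale(col)))

# Base hyperparameters come from tabnet_config() ...
cfg <- tabnet_config(epochs = 5, batch_size = 256)

# ... and hyperparameters passed through `...` override them:
# this fit is requested with 1 epoch instead of the 5 in `cfg`
fit <- tabnet_fit(attrition_x, attrition_y, config = cfg, epochs = 1)
```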

`tabnet_fit()` returns a TabNet model object that can be used for serialization, predictions, or further fitting.
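A minimal sketch of using the returned object for prediction and serialization. It assumes that `predict()` takes a `new_data` data frame and, for a regression fit, returns a tibble with a `.pred` column (the hardhat convention), and that the fitted object survives a base R `saveRDS()`/`readRDS()` round trip. The 500-row subsample and single epoch only keep the sketch fast.

```r
library(tabnet)

data("ames", package = "modeldata")

# Small single-outcome regression fit
set.seed(1)
rows <- sample(nrow(ames), 500)
fit <- tabnet_fit(Sale_Price ~ ., data = ames[rows, ], epochs = 1)

# Predictions on new rows (assumed to return a tibble with `.pred`)
preds <- predict(fit, ames[1:5, ])

# Serialization with base R: save to a temporary file and reload
path <- tempfile(fileext = ".rds")
saveRDS(fit, path)
restored <- readRDS(path)

# The restored model is expected to give the same predictions
preds_restored <- predict(restored, ames[1:5, ])
```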

When a parent `tabnet_model` is provided, model fitting resumes from that model's weights at the following epoch:

- the last fitted epoch, for a model already in the torch context;
- the last model checkpoint epoch, for a model loaded from file;
- the epoch of the checkpoint matching or preceding `from_epoch`, if that value is provided.

The fitting metrics of the new run are appended to the parent's metrics in the returned TabNet model.
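The resumption rules can be sketched as follows. This assumes `checkpoint_epochs` is accepted as a `tabnet_config()` hyperparameter (passed here through `...`) so that checkpoints exist at epochs 2 and 4; the data subset and epoch counts are illustrative.

```r
library(tabnet)

data("ames", package = "modeldata")
set.seed(1)
rows <- sample(nrow(ames), 500)

# Initial fit: 4 epochs, keeping a checkpoint every 2 epochs
fit <- tabnet_fit(Sale_Price ~ ., data = ames[rows, ],
                  epochs = 4, checkpoint_epochs = 2)

# Continue from the last fitted epoch of the in-memory model
fit_more <- tabnet_fit(Sale_Price ~ ., data = ames[rows, ],
                       tabnet_model = fit, epochs = 2)

# Branch off instead from the epoch-2 checkpoint via `from_epoch`
fit_branch <- tabnet_fit(Sale_Price ~ ., data = ames[rows, ],
                         tabnet_model = fit, from_epoch = 2, epochs = 2)
```

Both continued fits use the same formula and data as the parent, since the restored network expects the same predictors.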

TabNet supports multi-outcome prediction, usually called multi-label or multi-output classification when the outcomes are categorical. Multi-outcome fitting currently expects the outcomes to be either all numeric or all categorical.
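Because mixed outcome types are not supported, a quick pre-check can catch them before fitting. The helper below, `outcomes_are_homogeneous()`, is hypothetical and not part of tabnet; treating factor and character columns as categorical is an assumption about what tabnet counts as categorical.

```r
# Hypothetical helper (not part of tabnet): TRUE when all outcome
# columns are numeric, or all are categorical (factor or character)
outcomes_are_homogeneous <- function(y) {
  y <- as.data.frame(y)
  is_num <- vapply(y, is.numeric, logical(1))
  is_cat <- vapply(y, function(col) is.factor(col) || is.character(col),
                   logical(1))
  all(is_num) || all(is_cat)
}

outcomes_are_homogeneous(data.frame(a = 1:3, b = c(0.5, 1, 2)))       # TRUE
outcomes_are_homogeneous(data.frame(a = factor(c("x", "y", "x")),
                                    b = c("u", "v", "u")))            # TRUE
outcomes_are_homogeneous(data.frame(a = 1:3,
                                    b = factor(c("x", "y", "x"))))    # FALSE
```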

TabNet uses `torch` as its backend for computation, and `torch` uses all available threads by default.

You can control the number of threads used by `torch` with `torch::torch_set_num_threads()` and `torch::torch_set_num_interop_threads()`.
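A minimal sketch of limiting `torch` to two threads; the value 2 is an arbitrary example. Call these early in the session, before fitting any model: in the underlying libtorch, the inter-op thread count can only be set once, before inter-op parallel work has started.

```r
library(torch)

# Threads used to parallelize work inside a single operation
torch_set_num_threads(2)

# Threads used to run independent operations in parallel
torch_set_num_interop_threads(2)

# Check the intra-op setting
torch_get_num_threads()
```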

Examples covering each interface:

```
library(tabnet)
data("ames", package = "modeldata")
data("attrition", package = "modeldata")
# Row indices drawn from `attrition`; also valid for the larger `ames`
ids <- sample(nrow(attrition), 256)
## Single-outcome regression using formula specification
fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1)
## Single-outcome classification using data-frame specification
attrition_x <- attrition[, -which(names(attrition) == "Attrition")]
fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 1, verbose = TRUE)
#> [Epoch 001] Loss: 0.995773
## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset using formula
ames_fit <- tabnet_fit(Sale_Price + Pool_Area ~ ., data = ames[ids, ], epochs = 2, valid_split = 0.2)
## Multi-label classification on `Attrition` and `JobSatisfaction` in
## `attrition` dataset using recipe
library(recipes)
rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids, ]) %>%
  step_normalize(all_numeric(), -all_outcomes())
attrition_fit <- tabnet_fit(rec, data = attrition[ids, ], epochs = 2, valid_split = 0.2)
## Hierarchical classification on `acme`
data(acme, package = "data.tree")
acme_fit <- tabnet_fit(acme, epochs = 2, verbose = TRUE)
#> [Epoch 001] Loss: 2.754614
#> [Epoch 002] Loss: 1.907745
# Note: the number of dataset rows and model epochs should be increased
# for publication-level results.
```