All functions

autoplot.tabnet_explain()

Plot tabnet_explain mask importance heatmap

autoplot.tabnet_fit() autoplot.tabnet_pretrain()

Plot tabnet_fit or tabnet_pretrain model loss along epochs

check_compliant_node()

Check that Node object names are compliant

nn_prune_head.tabnet_fit() nn_prune_head.tabnet_pretrain()

Prune top layer(s) of a TabNet network

node_to_df()

Turn a Node object into predictor and outcome

tabnet()

Parsnip-compatible TabNet model

tabnet_config()

Configuration for TabNet models

tabnet_explain()

Interpretation metrics from a TabNet model

tabnet_fit()

TabNet model

tabnet_nn()

TabNet Model Architecture

cat_emb_dim() checkpoint_epochs() drop_last() encoder_activation() lr_scheduler() mlp_activation() mlp_hidden_multiplier() num_independent_decoder() num_shared_decoder() optimizer() penalty() verbose() virtual_batch_size()

Non-tunable parameters for the TabNet model

attention_width() decision_width() feature_reusage() momentum() mask_type() num_independent() num_shared() num_steps()

Tunable parameters for the TabNet model

tabnet_pretrain()

TabNet model pretraining (unsupervised)