
All functions

autoplot.tabnet_explain()
Plot tabnet_explain mask importance heatmap
autoplot.tabnet_fit() autoplot.tabnet_pretrain()
Plot tabnet_fit and tabnet_pretrain model loss over epochs
check_compliant_node()
Check that Node object names are compliant
entmax() entmax15()
Alpha-entmax
get_tau()
Optimal threshold (tau) computation for 1.5-entmax
nn_aum_loss()
AUM (area under the minimum of FPR and FNR) loss
nn_prune_head(<tabnet_fit>) nn_prune_head(<tabnet_pretrain>)
Prune the top layer(s) of a TabNet network
node_to_df()
Turn a Node object into predictors and an outcome
sparsemax() sparsemax15()
Sparsemax
tabnet()
Parsnip-compatible TabNet model
tabnet_config()
Configuration for TabNet models
tabnet_explain()
Interpretation metrics from a TabNet model
tabnet_fit()
TabNet model
tabnet_nn()
TabNet model architecture
cat_emb_dim() checkpoint_epochs() drop_last() encoder_activation() lr_scheduler() mlp_activation() mlp_hidden_multiplier() num_independent_decoder() num_shared_decoder() optimizer() penalty() verbose() virtual_batch_size()
Non-tunable parameters for the TabNet model
attention_width() decision_width() feature_reusage() momentum() mask_type() num_independent() num_shared() num_steps()
Tunable parameters for the TabNet model
tabnet_pretrain()
Unsupervised pretraining of a TabNet model