All functions

autoplot.tabnet_explain()
Plot tabnet_explain mask importance heatmap
autoplot.tabnet_fit() autoplot.tabnet_pretrain()
Plot tabnet_fit model loss along epochs
build_ancestor_matrix_from_outcomes()
Build ancestor matrix aligned with observed outcome classes
check_compliant_node()
Check that Node object names are compliant
entmax() entmax15()
Alpha-entmax
get_constr_output()
Apply hierarchy constraints via max-pooling over descendants (MCM)
get_tau()
Optimal threshold (tau) computation for 1.5-entmax
nn_aum_loss()
AUM loss
nn_mc_loss()
Max-Constraint Margin Loss (module)
nn_prune_head(<tabnet_fit>) nn_prune_head(<tabnet_pretrain>)
Prune top layer(s) of a TabNet network
nnf_mc_loss()
Max-Constraint Margin Loss (functional)
nnf_multilabel_one_hot()
Convert a class_id tensor to a binary one-hot tensor
node_to_df()
Turn a Node object into predictors and outcomes
predict(<tabnet_fit>) augment(<tabnet_fit>)
Predict using tabnet
sparsemax() sparsemax15()
Sparsemax
tabnet()
Parsnip-compatible TabNet model
tabnet_config()
Configuration for TabNet models
tabnet_explain()
Interpretation metrics from a TabNet model
tabnet_fit()
Fit a TabNet model
tabnet_nn()
TabNet Model Architecture
cat_emb_dim() checkpoint_epochs() drop_last() encoder_activation() lr_scheduler() mlp_activation() mlp_hidden_multiplier() num_independent_decoder() num_shared_decoder() optimizer() penalty() verbose() virtual_batch_size()
Non-tunable parameters for the tabnet model
attention_width() decision_width() feature_reusage() momentum() mask_type() num_independent() num_shared() num_steps()
Parameters for the tabnet model
tabnet_pretrain()
Pretrain a TabNet model