Temporal Fusion Transformer
Configuration for the Temporal Fusion Transformer network
Usage
tft(x, ...)
tft_config(
  lookback,
  horizon,
  subsample = 1,
  hidden_state_size = 16,
  num_attention_heads = 4,
  num_lstm_layers = 2,
  dropout = 0.1,
  batch_size = 256,
  epochs = 5,
  optimizer = "adam",
  learn_rate = 0.01,
  learn_rate_decay = c(0.1, 5),
  gradient_clip_norm = 0.1,
  quantiles = c(0.1, 0.5, 0.9),
  num_workers = 0,
  callbacks = list(),
  verbose = FALSE
)
temporal_fusion_transformer(
  mode = "regression",
  lookback = NULL,
  horizon = NULL,
  hidden_state_size = NULL,
  dropout = NULL,
  learn_rate = NULL,
  batch_size = NULL,
  epochs = NULL
)

Arguments
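As a sketch of how a configuration might be built (assuming the tft package and its torch backend are installed; the lookback and horizon values below are hypothetical and must match your data):

```r
library(tft)

# Hypothetical configuration: use 28 historic timesteps to predict 7 ahead.
config <- tft_config(
  lookback = 28,            # timesteps of history per sample
  horizon = 7,              # timesteps to predict
  hidden_state_size = 16,   # d_model in the paper
  num_attention_heads = 4,  # m_H in the paper
  epochs = 5
)
```

The remaining arguments keep their defaults, which are documented below.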
- x
A recipe containing step_include_roles() as the last step. Can also be a data.frame, but it is then expected to carry a recipe attribute containing the recipe that generated it via recipes::bake() or recipes::juice().
- ...
Additional arguments passed to tft_config().
- lookback
Number of timesteps that are used as historic data for prediction.
- horizon
Number of timesteps ahead that will be predicted by the model.
- subsample
Subsample from all possible slices. Either an integer giving the number of samples or a proportion of all possible slices.
- hidden_state_size
Hidden size of the network; this is its main hyperparameter and can range from 8 to 512. It is referred to as d_model in the paper.
- num_attention_heads
Number of attention heads in the multi-head attention layer. The paper refers to it as m_H; 4 is a good default.
- num_lstm_layers
Number of LSTM layers used in the Locality Enhancement Layer. Usually 2 is good enough.
- dropout
Dropout rate used in many places in the architecture.
- batch_size
How many samples per batch to load.
- epochs
Maximum number of epochs for training the model.
- optimizer
Optimizer used for training. Can be one of the strings 'adam', 'sgd', or 'adagrad', or a torch::optimizer().
- learn_rate
Learning rate used by the optimizer.
- learn_rate_decay
Factor by which the learning rate is decreased each epoch. Can also be a vector of 2 elements: in that case the learning rate is multiplied by x[1] every x[2] epochs (where x is the learn_rate_decay vector). Use FALSE or any negative number to disable decay.
- gradient_clip_norm
Maximum norm of the gradients. Passed on to
luz::luz_callback_gradient_clip(). If <= 0 or FALSE, no gradient clipping is performed.
- quantiles
A numeric vector with 3 quantiles for the quantile loss. The first is treated as the lower bound of the interval, the second as the point prediction, and the third as the upper bound.
- num_workers
Number of parallel workers for preprocessing data.
- callbacks
Additional callbacks passed when fitting the module with luz.
- verbose
Logical value stating if the model should produce status outputs, like a progress bar, during training.
- mode
A single character string for the type of model. The only possible value for this model is "regression".
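Putting the pieces above together, a minimal fitting sketch might look like the following. It assumes a hypothetical data frame sales with an outcome column value; the actual recipe steps will depend on your data, and the only documented requirement here is that step_include_roles() comes last.

```r
library(recipes)
library(tft)

# Hypothetical data frame `sales` with an outcome column `value`.
rec <- recipe(value ~ ., data = sales) %>%
  step_include_roles()  # must be the last step, as required by tft()

# Arguments in `...` are forwarded to tft_config().
model <- tft(rec, lookback = 28, horizon = 7, epochs = 5)
```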
Functions
tft_config(): Configuration options for tft().
temporal_fusion_transformer(): Parsnip wrapper for TFT.
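A sketch of the parsnip wrapper, assuming the parsnip package is available; the tuning values are hypothetical, and arguments left as NULL fall back to the tft_config() defaults:

```r
library(parsnip)
library(tft)

# Hypothetical model specification; "regression" is the only supported mode.
spec <- temporal_fusion_transformer(
  mode = "regression",
  lookback = 28,
  horizon = 7
)
```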
See also
predict.tft() for how to create predictions.