Temporal Fusion Transformer

Configuration for the Temporal Fusion Transformer network

Usage

tft(x, ...)

tft_config(
  lookback,
  horizon,
  subsample = 1,
  hidden_state_size = 16,
  num_attention_heads = 4,
  num_lstm_layers = 2,
  dropout = 0.1,
  batch_size = 256,
  epochs = 5,
  optimizer = "adam",
  learn_rate = 0.01,
  learn_rate_decay = c(0.1, 5),
  gradient_clip_norm = 0.1,
  quantiles = c(0.1, 0.5, 0.9),
  num_workers = 0,
  callbacks = list(),
  verbose = FALSE
)

temporal_fusion_transformer(
  mode = "regression",
  lookback = NULL,
  horizon = NULL,
  hidden_state_size = NULL,
  dropout = NULL,
  learn_rate = NULL,
  batch_size = NULL,
  epochs = NULL
)
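As a sketch of the configuration interface, `tft_config()` can be called with only the two required arguments; everything else falls back to the defaults shown above (the values here are illustrative, not recommendations):

```r
library(tft)

# Configure a model that uses the last 28 timesteps of history
# to predict the next 7. All other options keep the documented defaults.
config <- tft_config(
  lookback = 28,
  horizon  = 7
)

# tft_config() returns a plain list of configuration parameters.
str(config)
```

Equivalently, `tft(x, lookback = 28, horizon = 7)` forwards these arguments to `tft_config()` through `...`.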

Arguments

x

A recipe containing step_include_roles() as the last step. Can also be a data.frame, in which case it is expected to have a recipe attribute containing the recipe that generated it via recipes::bake() or recipes::juice().

...

Additional arguments passed to tft_config().

lookback

Number of timesteps that are used as historic data for prediction.

horizon

Number of timesteps ahead that will be predicted by the model.

subsample

Subsample from all possible slices. Either an integer giving the number of samples or a proportion.

hidden_state_size

Hidden state size of the network, which is its main hyperparameter; typical values range from 8 to 512. It is referred to as d_model in the paper.

num_attention_heads

Number of attention heads in the multi-head attention layer. The paper refers to it as m_H. 4 is a good default.

num_lstm_layers

Number of LSTM layers used in the Locality Enhancement Layer. Usually 2 is good enough.

dropout

Dropout rate used in many places in the architecture.

batch_size

How many samples per batch to load.

epochs

Maximum number of epochs for training the model.

optimizer

Optimizer used for training. Can be a string with 'adam', 'sgd', or 'adagrad'. Can also be a torch::optimizer().

learn_rate

Learning rate used by the optimizer.

learn_rate_decay

Factor by which the learning rate is decreased each epoch. Can also be a vector with 2 elements: in that case the learning rate is decreased by x[1] every x[2] epochs (where x is the learn_rate_decay vector). Use FALSE or any negative number to disable.
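As an illustration of the default learn_rate_decay = c(0.1, 5), the sketch below computes the resulting schedule in plain R, assuming the decay is applied multiplicatively every x[2] epochs (this is an illustration of the documented behavior, not code from the package):

```r
# Default settings from tft_config()
learn_rate <- 0.01
decay      <- c(0.1, 5)  # multiply by 0.1 every 5 epochs

# Learning rate in effect for a given (1-based) epoch
lr_for_epoch <- function(epoch) {
  learn_rate * decay[1]^(floor((epoch - 1) / decay[2]))
}

sapply(1:12, lr_for_epoch)
# epochs 1-5: 0.01, epochs 6-10: 0.001, epochs 11-12: 1e-04
```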

gradient_clip_norm

Maximum norm of the gradients. Passed on to luz::luz_callback_gradient_clip(). If <= 0 or FALSE then no gradient clipping is performed.

quantiles

A numeric vector with 3 quantiles for the quantile loss. The first is treated as the lower bound of the interval, the second as the point prediction, and the third as the upper bound.

num_workers

Number of parallel workers for preprocessing data.

callbacks

Additional callbacks passed when fitting the module with luz.

verbose

Logical value stating if the model should produce status outputs, like a progress bar, during training.

mode

A single character string for the type of model. The only possible value for this model is "regression".

Value

A list with the configuration parameters.

Functions

  • tft_config: Configuration options for tft.

  • temporal_fusion_transformer: Parsnip wrappers for TFT.
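A minimal sketch of the parsnip wrapper, assuming the tft and parsnip packages are installed and loaded; the hyperparameter values shown are illustrative:

```r
library(parsnip)
library(tft)

# Specify a TFT regression model through the parsnip interface.
# Arguments left as NULL fall back to the tft_config() defaults.
spec <- temporal_fusion_transformer(
  mode     = "regression",
  lookback = 28,
  horizon  = 7,
  epochs   = 5
)

# The specification can then be fitted with parsnip::fit(), using
# a recipe whose last step is step_include_roles() as described above.
```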

See also

predict.tft() for how to create predictions.