Temporal Fusion Transformer

Create and configure the Temporal Fusion Transformer (TFT) model

Usage

temporal_fusion_transformer(spec, ...)

tft_config(
  hidden_state_size = 16,
  num_attention_heads = 4,
  num_lstm_layers = 2,
  dropout = 0.1,
  optimizer = "adam",
  learn_rate = 0.01,
  quantiles = c(0.1, 0.5, 0.9)
)
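
For example, additional arguments given to temporal_fusion_transformer() are forwarded to tft_config(). A minimal sketch (not run), assuming spec was previously built with tft_dataset_spec():

# `spec` assumed to come from tft_dataset_spec(); see its documentation
model <- temporal_fusion_transformer(
  spec,
  hidden_state_size = 32,
  dropout = 0.2
)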

Arguments

spec

A spec created with tft_dataset_spec(). It is required because the model depends on information defined in the dataset specification.

...

Additional parameters passed to tft_config().

hidden_state_size

Hidden state size of the network. This is its main hyperparameter and typically ranges from 8 to 512. It is referred to as d_model in the paper.

num_attention_heads

Number of attention heads in the multi-head attention layer. The paper refers to it as m_H; 4 is a good default.

num_lstm_layers

Number of LSTM layers used in the Locality Enhancement Layer. Two layers are usually sufficient.

dropout

Dropout rate used in many places in the architecture.

optimizer

Optimizer used for training. Can be one of the strings 'adam', 'sgd', or 'adagrad', or an optimizer created with torch::optimizer().
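
For instance, instead of one of the strings above, an optimizer generator built with torch::optimizer() could be supplied. A sketch, assuming a generator such as torch::optim_rmsprop is accepted as-is:

# torch::optim_rmsprop is used purely as an illustration of a
# torch::optimizer()-created generator replacing the string shortcut
config <- tft_config(
  optimizer = torch::optim_rmsprop,
  learn_rate = 0.001
)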

learn_rate

Learning rate used by the optimizer.

quantiles

A numeric vector with 3 quantiles for the quantile loss. The first is treated as the lower bound of the prediction interval, the second as the point prediction, and the third as the upper bound.
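
For example, a wider 90% prediction interval could be requested with the 5th, 50th, and 95th percentiles:

# lower bound, point prediction, upper bound
config <- tft_config(quantiles = c(0.05, 0.5, 0.95))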

Value

A luz_module that has been set up and is ready to be fitted.

Functions

  • temporal_fusion_transformer: Create the tft module

  • tft_config: Configuration for the Temporal Fusion Transformer

See also

See fit.luz_module_generator() for the fit arguments, predict.tft_result() for information on how to generate predictions, and forecast.tft_result() for forecasting.
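
A hedged end-to-end sketch (not run) of how these pieces might fit together; the data passed to fit() and the arguments of forecast() are assumptions made for illustration, not documented on this page:

# `spec` assumed to come from tft_dataset_spec()
model <- temporal_fusion_transformer(spec, hidden_state_size = 32)

# epochs is a standard fit.luz_module_generator() argument; passing the spec
# as the training data is an assumption made for this sketch
fitted <- fit(model, spec, epochs = 5)

# forecasting with the fitted model; see forecast.tft_result() for the
# actual arguments it accepts
forecasts <- forecast(fitted)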