Temporal Fusion Transformer
Creates a Temporal Fusion Transformer module from a dataset spec and defines the configuration for the tft model via tft_config().
Usage
temporal_fusion_transformer(spec, ...)

tft_config(
  hidden_state_size = 16,
  num_attention_heads = 4,
  num_lstm_layers = 2,
  dropout = 0.1,
  optimizer = "adam",
  learn_rate = 0.01,
  quantiles = c(0.1, 0.5, 0.9)
)
Arguments
- spec
A spec created with tft_dataset_spec(). This is required because the model depends on information that is created/defined in the dataset.

- ...
Additional parameters passed on to tft_config().

- hidden_state_size
Hidden size of the network; this is its main hyperparameter and can range from 8 to 512. It is also known as d_model throughout the paper.

- num_attention_heads
Number of attention heads in the multi-head attention layer. The paper refers to it as m_H. 4 is a good default.

- num_lstm_layers
Number of LSTM layers used in the Locality Enhancement Layer. Usually 2 is good enough.

- dropout
Dropout rate used in many places in the architecture.

- optimizer
Optimizer used for training. Can be a string: 'adam', 'sgd', or 'adagrad'. Can also be a torch::optimizer().

- learn_rate
Learning rate used by the optimizer.

- quantiles
A numeric vector with 3 quantiles for the quantile loss. The first is treated as the lower bound of the interval, the second as the point prediction, and the third as the upper bound.
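As a rough sketch of how these arguments fit together (the call to tft_dataset_spec() is left as a placeholder since it is documented separately, and the values shown are only illustrative, not recommendations):

library(tft)

# Assumes `spec` was built beforehand; see tft_dataset_spec() for its arguments.
# spec <- tft_dataset_spec(...)

# Extra arguments in `...` are forwarded to tft_config()
model <- temporal_fusion_transformer(
  spec,
  hidden_state_size = 32,      # d_model in the paper, between 8 and 512
  num_attention_heads = 4,     # m_H in the paper
  num_lstm_layers = 2,         # depth of the Locality Enhancement Layer
  dropout = 0.1,
  optimizer = "adam",          # or a torch::optimizer()
  learn_rate = 0.005,
  quantiles = c(0.1, 0.5, 0.9) # lower bound, point prediction, upper bound
)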
Functions
temporal_fusion_transformer: Create the tft module.

tft_config: Configuration for the Temporal Fusion Transformer.
See also
See fit.luz_module_generator() for fit arguments, predict.tft_result() for information on how to generate predictions, and forecast.tft_result() for forecasting.
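A minimal end-to-end sketch of the workflow referenced above, assuming `spec` comes from tft_dataset_spec() and that fit() accepts luz-style arguments such as epochs:

# Assumed workflow sketch; fit() arguments follow luz conventions.
model <- temporal_fusion_transformer(spec)   # spec from tft_dataset_spec()
fitted <- fit(model, epochs = 10)            # see fit.luz_module_generator()
fcst <- forecast(fitted)                     # see forecast.tft_result()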