This specification is used to create torch::dataset()s for training the model. It also takes care of target normalization and is required to initialize the temporal_fusion_transformer() model, which takes a specification as its first argument.
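The description above does not say how the target is normalized. As an illustration only, the base-R sketch below assumes one common scheme: the target is standardized within each time series (key), and the per-key mean and standard deviation are kept so that predictions can be mapped back to the original scale. `normalize_by_key()` and `denormalize_by_key()` are hypothetical helpers, not part of the package, and the package's actual scheme may differ.

```r
# Hypothetical helpers (not part of the package). Assumption: the target is
# standardized within each series, where `key` identifies the series.
normalize_by_key <- function(y, key) {
  k <- as.character(key)
  mu <- tapply(y, k, mean)      # per-series mean, named by key
  sigma <- tapply(y, k, sd)     # per-series standard deviation, named by key
  list(
    scaled = (y - as.numeric(mu[k])) / as.numeric(sigma[k]),
    mean = mu,
    sd = sigma
  )
}

# Map scaled values (e.g. model predictions) back to the original scale
denormalize_by_key <- function(scaled, key, norm) {
  k <- as.character(key)
  scaled * as.numeric(norm$sd[k]) + as.numeric(norm$mean[k])
}

y <- c(1, 2, 3, 10, 20, 30)
key <- c("a", "a", "a", "b", "b", "b")
norm <- normalize_by_key(y, key)
norm$scaled  # -1 0 1 -1 0 1
```

Keeping the statistics per key lets series with very different magnitudes share one model. A real implementation would also need to guard against series with zero variance.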


tft_dataset_spec(x, ...)

spec_time_splits(spec, lookback, horizon, step = 1L)

spec_covariate_index(spec, index)

spec_covariate_key(spec, ...)

spec_covariate_known(spec, ...)

spec_covariate_unknown(spec, ...)

spec_covariate_static(spec, ...)



x: A recipe or data.frame used to compute the statistics needed to prepare the recipe and the dataset.

...: Column names, selected using tidyselect. See <tidy-select> for more information.

spec: A spec created with tft_dataset_spec().

lookback: Number of timesteps used as historical data for prediction.

horizon: Number of timesteps ahead that the model will predict.

step: Number of timesteps between the start of one slice and the start of the next.

index: A column name that indexes the data, usually a date column.
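To make the roles of lookback, horizon and step concrete, the sketch below enumerates the encoder and decoder row indices for a single series. It assumes slices are contiguous windows of lookback + horizon timesteps whose starts are step timesteps apart; `make_slices()` is a hypothetical base-R helper, not the package's internal code.

```r
# Hypothetical helper (not part of the package): split a series of length n
# into slices of `lookback` past rows (encoder) followed by `horizon`
# future rows (decoder). Consecutive slices start `step` rows apart.
make_slices <- function(n, lookback, horizon, step = 1L) {
  window <- lookback + horizon
  if (n < window) return(list())  # series too short for a single slice
  starts <- seq(1L, n - window + 1L, by = step)
  lapply(starts, function(s) {
    list(
      encoder = s:(s + lookback - 1L),
      decoder = (s + lookback):(s + window - 1L)
    )
  })
}

slices <- make_slices(10, lookback = 3, horizon = 2, step = 2)
length(slices)  # 3
```

Under this assumption, lookback = 52 and horizon = 4 give every slice 52 encoder rows and 4 decoder rows, consistent with the `[1:52, ...]` and `[1:4, ...]` tensor shapes in the example output.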


A tft_dataset_spec object. Add spec_ functions to it with the |> (pipe), call prep() when done, and then call transform() to obtain torch::dataset()s.


  • spec_time_splits: Sets the lookback, horizon and step parameters.

  • spec_covariate_index: Sets the index column.

  • spec_covariate_key: Sets the keys, i.e. the variables that identify each time series.

  • spec_covariate_known: Sets known time varying covariates.

  • spec_covariate_unknown: Sets unknown time varying covariates.

  • spec_covariate_static: Sets static covariates.
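The printed spec in the example notes that covariates not listed as another type are considered `unknown`. The base-R sketch below mimics that rule on a plain vector of column names. `assign_roles()` is a hypothetical helper, the column names are illustrative rather than the full dataset, excluding the outcome from the unknown set is an assumption, and `grep()` stands in for the tidyselect helper starts_with().

```r
# Hypothetical helper (not part of the package): partition column names
# into covariate roles. Assumption: the outcome is excluded, and every
# remaining column without an explicit role is treated as `unknown`.
assign_roles <- function(columns, outcome, index, keys, static, known) {
  assigned <- c(outcome, index, keys, static, known)
  list(
    index = index,
    keys = keys,
    static = static,
    known = known,
    unknown = setdiff(columns, assigned)
  )
}

# Illustrative column names, not the full dataset
cols <- c("Date", "Store", "Dept", "Type", "Size",
          "MarkDown1", "MarkDown2", "Temperature", "Weekly_Sales")

roles <- assign_roles(
  cols,
  outcome = "Weekly_Sales",
  index = "Date",
  keys = c("Store", "Dept"),
  static = c("Type", "Size"),
  known = grep("^MarkDown", cols, value = TRUE)  # like starts_with("MarkDown")
)
roles$unknown  # "Temperature"
```

Treating unassigned columns as unknown is the safe default: such values are only available up to the forecast time, so they are fed to the encoder but not the decoder.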


if (torch::torch_is_installed()) {
# Weekly sales for store 1, departments 1 and 2
sales <- timetk::walmart_sales_weekly %>%
  dplyr::select(-id) %>%
  dplyr::filter(Store == 1, Dept %in% c(1,2))

# Weekly_Sales is the target; all other columns are candidate predictors
rec <- recipes::recipe(Weekly_Sales ~ ., sales)

spec <- tft_dataset_spec(rec, sales) %>%
  spec_time_splits(lookback = 52, horizon = 4) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(Store, Dept) %>%
  spec_covariate_static(Type, Size) %>%
  spec_covariate_known(starts_with("MarkDown"))

print(spec)

spec <- prep(spec)
dataset <- transform(spec) # this is a torch dataset.
str(dataset[1])
}
#> A <tft_dataset_spec> with:
#>  lookback = 52 and horizon = 4.
#> ── Covariates: 
#>  `index`: Date
#>  `keys`: <list: Store, Dept>
#>  `static`: <list: Type, Size>
#>  `known`: <list: starts_with("MarkDown")>
#> ! `unknown` is not set. Covariates that are not listed as other types are considered `unknown`.
#>  Call `prep()` to prepare the specification.
#> List of 2
#>  $ :List of 2
#>   ..$ encoder:List of 2
#>   .. ..$ past  :List of 2
#>   .. .. ..$ num:Float [1:52, 1:11]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>   .. ..$ static:List of 2
#>   .. .. ..$ num:Float [1:3]
#>   .. .. ..$ cat:Long [1:1]
#>   ..$ decoder:List of 2
#>   .. ..$ known :List of 2
#>   .. .. ..$ num:Float [1:4, 1:5]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>   .. ..$ target:List of 2
#>   .. .. ..$ num:Float [1:4, 1:1]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>  $ :Float [1:4, 1:1]