
A dataset specification is used to create torch::dataset()s for training the model, to take care of target normalization, and to initialize the temporal_fusion_transformer() model, which requires a specification as its first argument.
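How the specification normalizes the target is internal to the package. As a rough illustration only, the base R sketch below assumes a per-series z-score scaling keyed by the series keys, and shows how normalized values (for example, predictions) could be mapped back to the original scale. `normalize_by_key()` and `denormalize()` are hypothetical helpers, not package functions.

```r
# Hypothetical helpers, NOT the package's implementation: per-series
# z-score normalization of a target column, grouped by the key columns.
normalize_by_key <- function(df, target, keys) {
  # One group per unique combination of key values
  grp <- interaction(df[keys], drop = TRUE)
  idx <- as.integer(grp)
  mu <- as.vector(tapply(df[[target]], grp, mean))
  sigma <- as.vector(tapply(df[[target]], grp, sd))
  df[[target]] <- (df[[target]] - mu[idx]) / sigma[idx]
  list(data = df, group = grp, mean = mu, sd = sigma)
}

# Map normalized values back to the original scale of each series.
denormalize <- function(x, grp, mu, sigma) {
  idx <- as.integer(grp)
  x * sigma[idx] + mu[idx]
}

# Toy data: two series on very different scales
toy <- data.frame(
  Dept = c(1, 1, 1, 2, 2, 2),
  y = c(10, 20, 30, 100, 200, 300)
)
norm <- normalize_by_key(toy, "y", "Dept")
norm$data$y
#> [1] -1  0  1 -1  0  1
denormalize(norm$data$y, norm$group, norm$mean, norm$sd)
```

Scaling each series separately puts series with very different magnitudes on a comparable scale, which is the usual motivation for target normalization in multi-series forecasting.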

Usage

tft_dataset_spec(x, ...)

spec_time_splits(spec, lookback, horizon, step = 1L)

spec_covariate_index(spec, index)

spec_covariate_key(spec, ...)

spec_covariate_known(spec, ...)

spec_covariate_unknown(spec, ...)

spec_covariate_static(spec, ...)

Arguments

x

A recipe or data.frame used to estimate the statistics needed to prepare the recipe and the dataset.

...

Column names, selected using tidyselect. See <tidy-select> for more information.

spec

A spec created with tft_dataset_spec().

lookback

Number of past timesteps used as historical input for each prediction.

horizon

Number of timesteps ahead that will be predicted by the model.

step

Number of timesteps between the starts of consecutive slices.

index

A column name that indexes the data. Usually a date column.
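The exact slicing rule is internal to the package. The base R sketch below only illustrates how lookback, horizon and step interact, assuming each slice covers lookback + horizon consecutive timesteps and consecutive slices start step timesteps apart. `window_starts()` and `window_rows()` are hypothetical helpers, not package functions.

```r
# Hypothetical helper: start positions of all slices that fit entirely
# inside a series of n observations (assumed slicing rule).
window_starts <- function(n, lookback, horizon, step = 1L) {
  size <- lookback + horizon
  if (n < size) return(integer(0))
  seq(1L, n - size + 1L, by = step)
}

# Hypothetical helper: split one slice into encoder (past) and
# decoder (future) row indices.
window_rows <- function(start, lookback, horizon) {
  list(
    encoder = seq(start, length.out = lookback),
    decoder = seq(start + lookback, length.out = horizon)
  )
}

starts <- window_starts(n = 10, lookback = 4, horizon = 2, step = 2L)
starts
#> [1] 1 3 5
window_rows(starts[2], lookback = 4, horizon = 2)
```

With lookback = 52 and horizon = 4, as in the Examples, each slice would span 56 timesteps: 52 encoder rows and 4 decoder rows, which matches the [1:52, ...] and [1:4, ...] tensor shapes in the Examples output.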

Value

A tft_dataset_spec. Add spec_ functions to it using the pipe (|>), call prep() when done, and then call transform() to obtain torch::dataset()s.

Functions

  • spec_time_splits: Sets the lookback, horizon and step parameters.

  • spec_covariate_index: Sets the index column.

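  • spec_covariate_key: Sets the keys, i.e. the variables that identify each individual time series.

  • spec_covariate_known: Sets known time-varying covariates, whose future values are available at prediction time.

  • spec_covariate_unknown: Sets unknown time-varying covariates, which are only observed in the past. Columns not assigned to any other role are treated as unknown.

  • spec_covariate_static: Sets static covariates, which do not change over time within a series.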
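The printed spec in the Examples notes that covariates not listed under another role are considered unknown. The base R sketch below mimics that rule with setdiff(). `assign_roles()` is a hypothetical helper, the column names are an abbreviated, partly hypothetical set loosely based on the Walmart example, and keeping the outcome out of the unknown set is an assumption of the sketch.

```r
# Hypothetical helper, not part of the package: every column that is
# neither the outcome nor listed under another role becomes `unknown`.
assign_roles <- function(cols, outcome, index, keys, static, known) {
  listed <- c(outcome, index, keys, static, known)
  list(
    index = index,
    keys = keys,
    static = static,
    known = known,
    unknown = setdiff(cols, listed)
  )
}

# Abbreviated, partly hypothetical column set
cols <- c("Date", "Store", "Dept", "Type", "Size", "Weekly_Sales",
          "MarkDown1", "MarkDown2", "Temperature")

roles <- assign_roles(
  cols,
  outcome = "Weekly_Sales",
  index = "Date",
  keys = c("Store", "Dept"),
  static = c("Type", "Size"),
  # regex stand-in for the tidyselect helper starts_with("MarkDown")
  known = grep("^MarkDown", cols, value = TRUE)
)
roles$unknown
#> [1] "Temperature"
```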

Examples

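library(tft)

if (torch::torch_is_installed()) {
  # Weekly sales for departments 1 and 2 of store 1
  sales <- timetk::walmart_sales_weekly %>%
    dplyr::select(-id) %>%
    dplyr::filter(Store == 1, Dept %in% c(1, 2))

  # The recipe defines the outcome (Weekly_Sales) and the predictors
  rec <- recipes::recipe(Weekly_Sales ~ ., sales)

  # Use 52 weeks of history to predict 4 weeks ahead, with one series
  # per Store/Dept combination indexed by Date
  spec <- tft_dataset_spec(rec, sales) %>%
    spec_time_splits(lookback = 52, horizon = 4) %>%
    spec_covariate_index(Date) %>%
    spec_covariate_key(Store, Dept) %>%
    spec_covariate_static(Type, Size) %>%
    spec_covariate_known(starts_with("MarkDown"))

  print(spec)

  # Prepare the specification, then create the torch dataset
  spec <- prep(spec)
  dataset <- transform(spec) # this is a torch dataset.
  str(dataset[1])
}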
#> A <tft_dataset_spec> with:
#> 
#>  lookback = 52 and horizon = 4.
#> 
#> ── Covariates: 
#>  `index`: Date
#>  `keys`: <list: Store, Dept>
#>  `static`: <list: Type, Size>
#>  `known`: <list: starts_with("MarkDown")>
#> ! `unknown` is not set. Covariates that are not listed as other types are considered `unknown`.
#> 
#>  Call `prep()` to prepare the specification.
#> List of 2
#>  $ :List of 2
#>   ..$ encoder:List of 2
#>   .. ..$ past  :List of 2
#>   .. .. ..$ num:Float [1:52, 1:11]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>   .. ..$ static:List of 2
#>   .. .. ..$ num:Float [1:3]
#>   .. .. ..$ cat:Long [1:1]
#>   ..$ decoder:List of 2
#>   .. ..$ known :List of 2
#>   .. .. ..$ num:Float [1:4, 1:5]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>   .. ..$ target:List of 2
#>   .. .. ..$ num:Float [1:4, 1:1]
#>   .. .. ..$ cat:Float [1:0, 1:0]
#>  $ :Float [1:4, 1:1]