Creates a TFT data specification
A specification is used to create torch::dataset()s for training the model,
takes care of target normalization, and is needed to initialize the
temporal_fusion_transformer() model, which takes a specification
as its first argument.
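This page does not document how tft normalizes the target. Purely as an illustrative assumption, the base-R sketch below standardizes the target separately for each series (identified by its keys). The mean and standard deviation are estimated once and then reused, and the transformation is inverted to map predictions back to the original scale. fit_target_scaler(), normalize_target() and denormalize_target() are hypothetical helpers, not part of tft.

```r
# Hypothetical sketch, NOT tft's implementation: per-series z-score
# normalization of the target. Statistics are estimated once, then
# reused to normalize and to invert the normalization.
fit_target_scaler <- function(y, key) {
  groups <- split(y, key)
  list(
    mean = vapply(groups, mean, numeric(1)),
    sd   = vapply(groups, stats::sd, numeric(1))
  )
}

# Scale each value with the statistics of its own series.
normalize_target <- function(scaler, y, key) {
  k <- as.character(key)
  as.numeric((y - scaler$mean[k]) / scaler$sd[k])
}

# Map normalized values (e.g. predictions) back to the original scale.
denormalize_target <- function(scaler, z, key) {
  k <- as.character(key)
  as.numeric(z * scaler$sd[k] + scaler$mean[k])
}

y   <- c(10, 12, 14, 100, 120, 140)
key <- c("a", "a", "a", "b", "b", "b")
scaler <- fit_target_scaler(y, key)
z <- normalize_target(scaler, y, key)
z                                   # -1 0 1 -1 0 1
denormalize_target(scaler, z, key)  # 10 12 14 100 120 140
```

Scaling per series keeps series with very different levels, such as the two departments in the Examples section, on a comparable scale for the model.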
Usage
tft_dataset_spec(x, ...)
spec_time_splits(spec, lookback, horizon, step = 1L)
spec_covariate_index(spec, index)
spec_covariate_key(spec, ...)
spec_covariate_known(spec, ...)
spec_covariate_unknown(spec, ...)
spec_covariate_static(spec, ...)
Arguments
- x
A recipe or data.frame used to compute the statistics needed to prepare the recipe and the dataset.
- ...
Column names, selected using tidyselect. See <tidy-select> for more information.
- spec
A spec created with tft_dataset_spec().
- lookback
Number of past timesteps used as historical input for each prediction.
- horizon
Number of timesteps ahead that will be predicted by the model.
- step
Number of steps between slices.
- index
A column name that indexes the data. Usually a date column.
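The base-R sketch below makes lookback, horizon and step concrete by enumerating the windows of a single series. It assumes that each slice holds lookback past timesteps immediately followed by horizon future timesteps, and that consecutive slices start step timesteps apart. time_slices() is a hypothetical helper. tft's own slicing may differ in details such as how incomplete windows at the end of a series are handled.

```r
# Hypothetical helper, NOT part of tft: enumerate the (past, future)
# index windows of one series of length n, assuming each slice covers
# `lookback` past steps followed by `horizon` future steps and that
# consecutive slices start `step` timesteps apart.
time_slices <- function(n, lookback, horizon, step = 1L) {
  last_start <- n - lookback - horizon + 1L
  if (last_start < 1L) return(list())  # series too short for one slice
  starts <- seq.int(1L, last_start, by = step)
  lapply(starts, function(s) {
    list(
      past   = seq.int(s, length.out = lookback),
      future = seq.int(s + lookback, length.out = horizon)
    )
  })
}

slices <- time_slices(n = 10, lookback = 3, horizon = 2, step = 2L)
length(slices)      # 3 slices, starting at t = 1, 3 and 5
slices[[1]]$past    # 1 2 3
slices[[1]]$future  # 4 5
```

Under this assumption, lookback = 52 and horizon = 4 (as in the Examples section) means every slice spans 56 consecutive weeks, so a series needs at least 56 observations to yield any slice.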
Value
A tft_dataset_spec to which you can add spec_ functions using the |> (pipe).
Call prep() when done, then transform() to obtain torch::dataset()s.
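The workflow above follows a builder pattern. Each spec_ function receives a specification and returns an updated copy, so calls can be chained with the pipe before prep(). The toy sketch below mimics only that pattern with hypothetical toy_* functions. It is not how tft stores specifications, and it needs R >= 4.1 for the native |> pipe.

```r
# Toy builder, NOT part of tft: each step takes a spec and returns a
# modified copy, which is what makes pipe chaining work.
toy_spec <- function(data) {
  structure(
    list(data = data, lookback = NULL, horizon = NULL, prepared = FALSE),
    class = "toy_spec"
  )
}

toy_time_splits <- function(spec, lookback, horizon) {
  spec$lookback <- lookback
  spec$horizon <- horizon
  spec
}

# "Preparing" checks that required settings exist before marking the spec ready.
toy_prep <- function(spec) {
  if (is.null(spec$lookback) || is.null(spec$horizon)) {
    stop("call toy_time_splits() before toy_prep()")
  }
  spec$prepared <- TRUE
  spec
}

spec <- toy_spec(data.frame(y = 1:10)) |>
  toy_time_splits(lookback = 3, horizon = 2) |>
  toy_prep()
spec$prepared  # TRUE
```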
Functions
- spec_time_splits: Sets the lookback, horizon and step parameters.
- spec_covariate_index: Sets the index column.
- spec_covariate_key: Sets the keys, the variables that identify each individual time series.
- spec_covariate_known: Sets the known time-varying covariates.
- spec_covariate_unknown: Sets the unknown time-varying covariates.
- spec_covariate_static: Sets the static covariates.
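The printed specification in the Examples section notes that covariates not listed as another type are considered unknown. The hypothetical base-R helper assign_roles() sketches that defaulting rule on plain column names. Unlike the spec_covariate_*() functions, it takes character vectors instead of tidyselect expressions, and it is not part of tft.

```r
# Hypothetical helper, NOT part of tft: give every column of `data` a
# role, treating any column not claimed by another role (and not the
# outcome) as an `unknown` time-varying covariate.
assign_roles <- function(data, outcome, index, keys = character(),
                         static = character(), known = character()) {
  roles <- list(index = index, keys = keys, static = static, known = known)
  claimed <- c(outcome, unlist(roles, use.names = FALSE))
  roles$unknown <- setdiff(names(data), claimed)
  roles
}

toy <- data.frame(
  date = as.Date("2024-01-01") + 0:2,
  store = 1L, size = 100, promo = c(0, 1, 0),
  temperature = c(5.1, 4.8, 6.0), sales = c(10, 12, 11)
)
roles <- assign_roles(toy, outcome = "sales", index = "date",
                      keys = "store", static = "size", known = "promo")
roles$unknown  # "temperature"
```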
Examples
if (torch::torch_is_installed()) {
  # Weekly sales of departments 1 and 2 in store 1
  sales <- timetk::walmart_sales_weekly %>%
    dplyr::select(-id) %>%
    dplyr::filter(Store == 1, Dept %in% c(1, 2))
  rec <- recipes::recipe(Weekly_Sales ~ ., sales)
  # Use 52 weeks of history to predict the next 4 weeks; columns not
  # assigned to another role are treated as unknown covariates
  spec <- tft_dataset_spec(rec, sales) %>%
    spec_time_splits(lookback = 52, horizon = 4) %>%
    spec_covariate_index(Date) %>%
    spec_covariate_key(Store, Dept) %>%
    spec_covariate_static(Type, Size) %>%
    spec_covariate_known(starts_with("MarkDown"))
  print(spec)
  spec <- prep(spec)
  dataset <- transform(spec) # this is a torch dataset
  str(dataset[1])
}
#> A <tft_dataset_spec> with:
#>
#> ✔ lookback = 52 and horizon = 4.
#>
#> ── Covariates:
#> ✔ `index`: Date
#> ✔ `keys`: <list: Store, Dept>
#> ✔ `static`: <list: Type, Size>
#> ✔ `known`: <list: starts_with("MarkDown")>
#> ! `unknown` is not set. Covariates that are not listed as other types are considered `unknown`.
#>
#> ℹ Call `prep()` to prepare the specification.
#> List of 2
#> $ :List of 2
#> ..$ encoder:List of 2
#> .. ..$ past :List of 2
#> .. .. ..$ num:Float [1:52, 1:11]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> .. ..$ static:List of 2
#> .. .. ..$ num:Float [1:3]
#> .. .. ..$ cat:Long [1:1]
#> ..$ decoder:List of 2
#> .. ..$ known :List of 2
#> .. .. ..$ num:Float [1:4, 1:5]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> .. ..$ target:List of 2
#> .. .. ..$ num:Float [1:4, 1:1]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> $ :Float [1:4, 1:1]