Creates a TFT data specification
tft_dataset_spec.Rd
This is used to create torch::dataset()s for training the model, to take care of target normalization, and to allow initializing the temporal_fusion_transformer() model, which requires a specification as its first argument.
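The description above does not say how the target is normalized. The R sketch below shows one common scheme, assumed here purely for illustration and not necessarily what tft does internally. It standardizes the target separately for each time series, using a mean and standard deviation learned once from the training data, and then inverts the scaling to map predictions back. The helpers fit_target_scaler(), series_id(), normalize_target() and denormalize() are hypothetical and use base R only.

```r
# Hypothetical sketch, not tft's implementation: per-series standardization
# of the target, with statistics learned once and reused afterwards.

# Learn the mean and standard deviation of `target` for every key combination.
fit_target_scaler <- function(data, target, keys) {
  groups <- interaction(data[keys], drop = TRUE)
  list(
    target = target,
    keys   = keys,
    mean   = tapply(data[[target]], groups, mean),
    sd     = tapply(data[[target]], groups, stats::sd)
  )
}

# Label each row with its series (key combination) as a character string.
series_id <- function(scaler, data) {
  as.character(interaction(data[scaler$keys], drop = TRUE))
}

# Replace the target with (y - mean) / sd of its own series.
normalize_target <- function(scaler, data) {
  g  <- series_id(scaler, data)
  mu <- as.numeric(scaler$mean[g])
  s  <- as.numeric(scaler$sd[g])
  data[[scaler$target]] <- (data[[scaler$target]] - mu) / s
  data
}

# Map normalized values (e.g. predictions) back to the original scale.
denormalize <- function(scaler, data, values) {
  g <- series_id(scaler, data)
  values * as.numeric(scaler$sd[g]) + as.numeric(scaler$mean[g])
}

toy <- data.frame(
  dept = c("a", "a", "a", "b", "b", "b"),
  y    = c(1, 2, 3, 10, 20, 30)
)
scaler <- fit_target_scaler(toy, target = "y", keys = "dept")
scaled <- normalize_target(scaler, toy)
scaled$y                             # -1 0 1 -1 0 1
denormalize(scaler, toy, scaled$y)   # 1 2 3 10 20 30
```

Storing the statistics at fit time lets the same scaling be reapplied to later data from the same series. A series with a single observation or a constant target would give a missing or zero standard deviation, which a real implementation would need to guard against.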
Usage
tft_dataset_spec(x, ...)
spec_time_splits(spec, lookback, horizon, step = 1L)
spec_covariate_index(spec, index)
spec_covariate_key(spec, ...)
spec_covariate_known(spec, ...)
spec_covariate_unknown(spec, ...)
spec_covariate_static(spec, ...)
Arguments
- x
A recipe or data.frame that will be used to obtain statistics for preparing the recipe and preparing the dataset.
- ...
Column names, selected using tidyselect. See <tidy-select> for more information.
- spec
A spec created with tft_dataset_spec().
- lookback
Number of timesteps that are used as historic data for prediction.
- horizon
Number of timesteps ahead that will be predicted by the model.
- step
Number of timesteps between the starting points of consecutive slices.
- index
A column name that indexes the data. Usually a date column.
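To make lookback, horizon and step concrete, the R sketch below enumerates the row indices each slice would cover for a single series of length n, assuming slices are consecutive windows of lookback + horizon rows whose starting points are step rows apart. time_slices() is a hypothetical helper written for illustration, not a tft function.

```r
# Hypothetical helper, not part of tft: enumerate the row indices that each
# training slice would use for a single series of length n.
time_slices <- function(n, lookback, horizon, step = 1L) {
  window <- lookback + horizon          # rows covered by one slice
  if (n < window) return(list())        # series too short for any slice
  starts <- seq(1, n - window + 1, by = step)
  lapply(starts, function(s) list(
    past   = s + seq_len(lookback) - 1,            # historic (encoder) rows
    future = s + lookback + seq_len(horizon) - 1   # predicted (decoder) rows
  ))
}

slices <- time_slices(n = 10, lookback = 4, horizon = 2, step = 2)
length(slices)       # 3 slices, starting at rows 1, 3 and 5
slices[[3]]$past     # 5 6 7 8
slices[[3]]$future   # 9 10
```

Under this assumption a series of length n yields floor((n - lookback - horizon) / step) + 1 slices, and each slice contributes lookback historic rows and horizon predicted rows, which is the shape of the 52-row past block and 4-row decoder blocks in the example output below (lookback = 52, horizon = 4).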
Value
A tft_dataset_spec. Add spec_ functions to it using the |> (pipe), call prep() when done, and then call transform() to obtain torch::dataset()s.
Functions
- spec_time_splits: Sets the lookback, horizon and step parameters.
- spec_covariate_index: Sets the index column.
- spec_covariate_key: Sets the keys, the variables that define each time series.
- spec_covariate_known: Sets known time-varying covariates.
- spec_covariate_unknown: Sets unknown time-varying covariates.
- spec_covariate_static: Sets static covariates.
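The printed spec in the example notes that covariates not listed as other types are considered unknown. The R sketch below illustrates that rule with plain character vectors. assign_roles() is a hypothetical helper, not a tft function; it also assumes the outcome column is excluded from the covariates, and the column names are assumed to match timetk::walmart_sales_weekly with id dropped.

```r
# Hypothetical helper, not part of tft: assign column names to covariate
# roles, treating any column not explicitly assigned as `unknown`.
# Assumption: the outcome column is excluded from the covariates.
assign_roles <- function(columns, outcome, index, keys = character(),
                         static = character(), known = character()) {
  assigned <- c(outcome, index, keys, static, known)
  list(
    index   = index,
    keys    = keys,
    static  = static,
    known   = known,
    unknown = setdiff(columns, assigned)  # everything left over
  )
}

# Column names assumed to match timetk::walmart_sales_weekly without `id`.
cols <- c("Store", "Dept", "Date", "Weekly_Sales", "IsHoliday", "Type",
          "Size", "Temperature", "Fuel_Price", paste0("MarkDown", 1:5),
          "CPI", "Unemployment")

roles <- assign_roles(
  cols,
  outcome = "Weekly_Sales",
  index   = "Date",
  keys    = c("Store", "Dept"),
  static  = c("Type", "Size"),
  known   = grep("^MarkDown", cols, value = TRUE)  # like starts_with("MarkDown")
)
roles$unknown
# "IsHoliday" "Temperature" "Fuel_Price" "CPI" "Unemployment"
```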
Examples
if (torch::torch_is_installed()) {
  # Two weekly series: store 1, departments 1 and 2.
  sales <- timetk::walmart_sales_weekly %>%
    dplyr::select(-id) %>%
    dplyr::filter(Store == 1, Dept %in% c(1, 2))

  # Weekly_Sales is the outcome; the other columns are candidate covariates.
  rec <- recipes::recipe(Weekly_Sales ~ ., sales)

  # Columns not assigned to another role are treated as `unknown`.
  spec <- tft_dataset_spec(rec, sales) %>%
    spec_time_splits(lookback = 52, horizon = 4) %>%
    spec_covariate_index(Date) %>%
    spec_covariate_key(Store, Dept) %>%
    spec_covariate_static(Type, Size) %>%
    spec_covariate_known(starts_with("MarkDown"))

  print(spec)

  spec <- prep(spec)
  dataset <- transform(spec) # this is a torch dataset.
  str(dataset[1])
}
#> A <tft_dataset_spec> with:
#>
#> ✔ lookback = 52 and horizon = 4.
#>
#> ── Covariates:
#> ✔ `index`: Date
#> ✔ `keys`: <list: Store, Dept>
#> ✔ `static`: <list: Type, Size>
#> ✔ `known`: <list: starts_with("MarkDown")>
#> ! `unknown` is not set. Covariates that are not listed as other types are considered `unknown`.
#>
#> ℹ Call `prep()` to prepare the specification.
#> List of 2
#> $ :List of 2
#> ..$ encoder:List of 2
#> .. ..$ past :List of 2
#> .. .. ..$ num:Float [1:52, 1:11]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> .. ..$ static:List of 2
#> .. .. ..$ num:Float [1:3]
#> .. .. ..$ cat:Long [1:1]
#> ..$ decoder:List of 2
#> .. ..$ known :List of 2
#> .. .. ..$ num:Float [1:4, 1:5]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> .. ..$ target:List of 2
#> .. .. ..$ num:Float [1:4, 1:1]
#> .. .. ..$ cat:Float [1:0, 1:0]
#> $ :Float [1:4, 1:1]