# Quick look at the raw hourly electricity data.
dplyr::glimpse(electricity::electricity_hourly)

# Preprocessing recipe: `consumption` is the outcome and all remaining
# columns enter through the formula before roles are re-assigned below.
rec <- recipe(consumption ~ ., data = electricity::electricity_hourly) %>%
  # `date_hour` gets the "index" role and `client` the "key" role.
  update_role(date_hour, new_role = "index") %>%
  update_role(client, new_role = "key") %>%
  # Calendar features derived from the timestamp, tagged with role "known".
  step_date(date_hour, role = "known", features = c("month", "dow")) %>%
  step_mutate(
    date_hour_hour = as.factor(lubridate::hour(date_hour)),
    role = "known"
  ) %>%
  # NOTE(review): unlike the steps above, no `role` is passed here, so these
  # two features keep step_mutate()'s default role -- confirm this is intended.
  step_mutate(
    # Elapsed time since 2011-01-01, in hours. The unit is spelled out in
    # full ("hours") instead of relying on partial matching of "hour".
    time_since_begining = as.numeric(difftime(
      time1 = date_hour,
      time2 = lubridate::ymd("2011-01-01"),
      units = "hours"
    )),
    doy = as.factor(lubridate::yday(date_hour))
  ) %>%
  step_normalize(time_since_begining) %>%
  step_include_roles()
# Time-based split of the hourly data. Cut-offs are expressed relative to
# `max(date_hour)`, evaluated inside each filter() call:
#   train : up to and including (max - 14 days)
#   valid : after (max - 14 days), up to and including (max - 7 days)
#   test  : after (max - 7 days)
elec <- electricity::electricity_hourly

elec_train <- filter(
  elec,
  date_hour <= max(date_hour) - lubridate::days(14)
)

elec_valid <- filter(
  elec,
  date_hour > max(date_hour) - lubridate::days(14) &
    date_hour <= max(date_hour) - lubridate::days(7)
)

elec_test <- filter(
  elec,
  date_hour > max(date_hour) - lubridate::days(7)
)
# Fit the tft model on the training split, using the recipe `rec` for
# preprocessing and role information.
model <- tft(
  rec,
  elec_train,

  # Windowing -- presumably 168 past and 24 future hourly steps; confirm
  # against the tft documentation.
  horizon = 24,
  lookback = 168,
  subsample = 45000,

  # Network size.
  hidden_state_size = 160,

  # Optimisation settings.
  epochs = 100,
  batch_size = 64,
  learn_rate = 0.001,
  learn_rate_decay = FALSE,
  gradient_clip_norm = 0.01,

  # Runtime / logging.
  num_workers = 8,
  verbose = TRUE,

  # Both callbacks monitor the *training* loss.
  # NOTE(review): `elec_valid` is not passed to this call -- confirm that
  # monitoring "train_loss" rather than a validation metric is intended.
  callbacks = list(
    luz::luz_callback_keep_best_model(monitor = "train_loss"),
    luz::luz_callback_early_stopping(
      monitor = "train_loss",
      patience = 5,
      min_delta = 0.001
    )
  )
)
# Predict on the test split with the fitted `model`.
# (The original referenced `model2`, which this script never defines.)
predictions <- predict(
  model,
  elec_test,
  mode = "full"
)

# Keep only the rows whose `.pred_at` equals the earliest `.pred_at` plus
# 0 to 6 days, and store `client` as character.
preds <- predictions %>%
  filter(
    .pred_at %in%
      (min(predictions$.pred_at, na.rm = TRUE) + lubridate::days(0:6))
  ) %>%
  mutate(client = as.character(client))
# Plot one forecast (selected by `id`) for the first three clients, together
# with the observed consumption from the preceding days.
id <- 1

# Clients to display and the `.pred_at` value of the selected forecast.
plot_clients <- sprintf("MT_%03d", 1:3)
plot_pred_at <- unique(preds$.pred_at)[id]

# History window starts 7 days before the earliest `.pred_at`. `na.rm = TRUE`
# matches how the same minimum is computed when `preds` is built; without it
# a single NA would make the bound NA and drop every history row.
history_start <- min(predictions$.pred_at, na.rm = TRUE) - lubridate::days(7)

bind_rows(
  # Forecast rows for the selected `.pred_at`.
  preds %>%
    filter(.pred_at == plot_pred_at) %>%
    filter(client %in% plot_clients),
  # Observed rows from the history window up to (excluding) `plot_pred_at`.
  elec %>%
    filter(
      date_hour >= history_start,
      date_hour < plot_pred_at
    ) %>%
    filter(client %in% plot_clients)
) %>%
  ggplot(aes(x = date_hour)) +
  # Observed consumption in the default colour.
  geom_line(aes(y = consumption)) +
  geom_point(aes(y = consumption)) +
  # Point forecast in red, with the lower/upper band as a shaded ribbon.
  geom_line(aes(y = .pred), color = "red") +
  geom_point(aes(y = .pred), color = "red") +
  geom_ribbon(aes(ymin = .pred_lower, ymax = .pred_upper), alpha = 0.2) +
  facet_wrap(~client, scales = "free_y", ncol = 1)