# Hourly electricity consumption forecasting with tft
# Quick look at the columns and types of the raw hourly electricity data.
electricity::electricity_hourly %>%
  dplyr::glimpse()
# Preprocessing recipe passed to `tft()`.
# - `consumption` is the outcome.
# - `date_hour` gets the "index" role and `client` the "key" role.
# - Features derived from `date_hour` (month, day of week, hour of day) are
#   given the "known" role.
# - A normalized elapsed-time feature and a day-of-year factor are added with
#   the default role.
# NOTE(review): the exact meaning of these roles and of step_include_roles()
# is defined by the tft package; confirm against its documentation.
rec <- recipe(consumption ~ ., data = electricity::electricity_hourly) %>%
  update_role(date_hour, new_role = "index") %>%
  update_role(client, new_role = "key") %>%
  step_date(date_hour, role = "known", features = c("month", "dow")) %>%
  step_mutate(
    date_hour_hour = as.factor(lubridate::hour(date_hour)),
    role = "known"
  ) %>%
  step_mutate(
    # Hours elapsed since 2011-01-01. The (misspelled) column name is kept
    # as-is so the baked column name does not change.
    time_since_begining = as.numeric(difftime(
      time1 = date_hour,
      time2 = lubridate::ymd("2011-01-01"),
      # "hours" is the documented unit; "hour" only worked through
      # match.arg() partial matching.
      units = "hours"
    )),
    # Day of year as a categorical feature.
    doy = as.factor(lubridate::yday(date_hour))
  ) %>%
  step_normalize(time_since_begining) %>%
  step_include_roles()
elec <- electricity::electricity_hourly

# Time-based split relative to the latest `date_hour` in the data:
# - train: everything up to 14 days before the end,
# - valid: the week between 14 and 7 days before the end,
# - test:  the final 7 days.
# NOTE(review): `elec_valid` is not used anywhere in this script.
elec_train <- filter(elec, date_hour <= max(date_hour) - lubridate::days(14))

elec_valid <- filter(
  elec,
  date_hour > max(date_hour) - lubridate::days(14),
  date_hour <= max(date_hour) - lubridate::days(7)
)

elec_test <- filter(elec, date_hour > max(date_hour) - lubridate::days(7))
# Fit the tft model on the training split using the recipe defined above.
model <- tft(
  rec,
  elec_train,
  # Window settings: 168 = 7 * 24, i.e. one week when steps are hourly
  # (presumably lookback/horizon are counted in time steps; confirm).
  lookback = 168,
  horizon = 24,
  subsample = 45000,
  # Model size.
  hidden_state_size = 160,
  # Optimizer and data loading.
  learn_rate = 0.001,
  learn_rate_decay = FALSE,
  gradient_clip_norm = 0.01,
  batch_size = 64,
  num_workers = 8,
  # Training loop: at most 100 epochs. Stop early when train_loss fails to
  # improve by at least 0.001 for 5 epochs, and keep the best-scoring model.
  epochs = 100,
  callbacks = list(
    luz::luz_callback_keep_best_model(monitor = "train_loss"),
    luz::luz_callback_early_stopping(
      monitor = "train_loss",
      patience = 5,
      min_delta = 0.001
    )
  ),
  verbose = TRUE
)
# Generate predictions on the held-out test week.
# Uses `model`, the object fitted by tft() above. The previous reference to
# `model2` pointed at an object that is not defined in this script.
predictions <- predict(
  model,
  elec_test,
  mode = "full"
)

# Keep only rows whose `.pred_at` is the earliest `.pred_at` or exactly
# 1-6 whole days after it. Also convert `client` to character, presumably so
# it binds cleanly with `elec` when plotting below.
preds <- predictions %>%
  filter(
    .pred_at %in%
      (min(predictions$.pred_at, na.rm = TRUE) + lubridate::days(0:6))
  ) %>%
  mutate(client = as.character(client))

# Which `.pred_at` value to plot, by position among the distinct values in
# `preds` (1 = earliest).
id <- 1

# Sorting makes `id` chronological; unique() alone follows row order.
plot_origin <- sort(unique(preds$.pred_at))[id]
plot_clients <- sprintf("MT_%03d", 1:3)
# History window starts 7 days before the earliest `.pred_at`. na.rm matches
# how `preds` was built, so a stray NA cannot empty the history.
history_start <- min(predictions$.pred_at, na.rm = TRUE) - lubridate::days(7)

# Stack the predictions for the chosen `.pred_at` with the observed data
# preceding it. Then plot, per client, observed consumption (default colour),
# the predictions (red), and the .pred_lower/.pred_upper ribbon.
bind_rows(
  preds %>%
    filter(.pred_at == plot_origin, client %in% plot_clients),
  elec %>%
    filter(
      date_hour >= history_start,
      date_hour < plot_origin,
      client %in% plot_clients
    )
) %>%
  ggplot(aes(x = date_hour)) +
  geom_line(aes(y = consumption)) +
  geom_point(aes(y = consumption)) +
  geom_line(aes(y = .pred), color = "red") +
  geom_point(aes(y = .pred), color = "red") +
  geom_ribbon(aes(ymin = .pred_lower, ymax = .pred_upper), alpha = 0.2) +
  facet_wrap(~client, scales = "free_y", ncol = 1)