Skip to content

Commit

Permalink
Merge pull request #80 from microsoft/users/mitokic/pca-feature
Browse files Browse the repository at this point in the history
merging into Main
  • Loading branch information
mitokic authored Nov 12, 2021
2 parents f682185 + 8d1d050 commit cba6266
Show file tree
Hide file tree
Showing 17 changed files with 148 additions and 80 deletions.
6 changes: 2 additions & 4 deletions R/configure_forecast_run.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ get_fourier_periods <- function(fourier_periods,
#'
#' @param lag_periods lag_periods override
#' @param date_type year, quarter, month, week, day
#' @param forecast_horizon horion input from user
#' @param forecast_horizon horizon input from user
#'
#' @return Returns lag_periods
#' @noRd
Expand All @@ -49,9 +49,7 @@ get_lag_periods <- function(lag_periods,
"quarter" = c(1,2,3,4),
"month" = c(1, 2, 3, 6, 9, 12),
"week" = c(1, 2, 3, 4, 8, 12, 24, 48, 52),
"day" = c(1, 2, 3, 4, 5, 6, 7, 14,
21, 28, 28*2, 28*3, 28*6,
28*9, 28*12, 365)
"day" = c(7, 14, 21, 28, 60, 90, 180, 365)
)

oplist <- c(oplist,forecast_horizon)
Expand Down
74 changes: 45 additions & 29 deletions R/forecast_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ invoke_forecast_function <- function(fn_to_invoke,
date_rm_regex,
back_test_spacing,
fiscal_year_start,
model_type){
model_type,
pca){

exp_arg_list <- formalArgs(fn_to_invoke)

Expand All @@ -177,7 +178,8 @@ invoke_forecast_function <- function(fn_to_invoke,
'date_rm_regex' = date_rm_regex,
'back_test_spacing' = back_test_spacing,
'fiscal_year_start' = fiscal_year_start,
'model_type' = model_type)
'model_type' = model_type,
"pca" = pca)

avail_names <- names(avail_arg_list)

Expand Down Expand Up @@ -248,9 +250,10 @@ construct_forecast_models <- function(full_data_tbl,
back_test_scenarios,
date_regex,
fiscal_year_start,
seasonal_periods
){

seasonal_periods,
pca
){

forecast_models <- function(combo_value) {

cli::cli_h2("Running Combo: {combo_value}")
Expand Down Expand Up @@ -335,6 +338,8 @@ construct_forecast_models <- function(full_data_tbl,

cli::cli_h3("Individual Model Training")


# models to run
model_list <- get_model_functions(models_to_run,
models_not_to_run,
run_deep_learning)
Expand All @@ -347,7 +352,13 @@ construct_forecast_models <- function(full_data_tbl,

models_to_go_over <- names(model_list)


# PCA
if(sum(pca == TRUE) == 1 | (combo_value == "All-Data" & is.null(pca)) | (is.null(pca) & date_type %in% c("day", "week"))) {
run_pca <- TRUE
} else {
run_pca <- FALSE
}

for(model_name in models_to_go_over){

model_fn <- as.character(model_list[model_name])
Expand All @@ -372,7 +383,9 @@ construct_forecast_models <- function(full_data_tbl,
fiscal_year_start = fiscal_year_start,
tscv_inital = hist_periods_80,
date_rm_regex = date_regex,
model_type = "single"))
model_type = "single",
pca = run_pca))


try(combined_models_recipe_1 <- modeltime::add_modeltime_model(combined_models_recipe_1,
mdl_called,
Expand All @@ -392,25 +405,26 @@ construct_forecast_models <- function(full_data_tbl,
freq_val <- gluon_ts_frequency
add_name <- paste0(model_name,model_name_suffix)
}


try(mdl_called <- invoke_forecast_function(fn_to_invoke = model_fn,
train_data = train_data_recipe_1,
frequency = freq_val,
parallel = run_model_parallel,
horizon = forecast_horizon,
seasonal_period =seasonal_periods,
back_test_spacing = back_test_spacing,
fiscal_year_start = fiscal_year_start,
tscv_inital = hist_periods_80,
date_rm_regex = date_regex,
model_type = "single"))

try(combined_models_recipe_1 <- modeltime::add_modeltime_model(combined_models_recipe_1,
mdl_called,
location = "top") %>%
update_model_description(1, add_name),
silent = TRUE)

try(mdl_called <- invoke_forecast_function(fn_to_invoke = model_fn,
train_data = train_data_recipe_1,
frequency = freq_val,
parallel = run_model_parallel,
horizon = forecast_horizon,
seasonal_period =seasonal_periods,
back_test_spacing = back_test_spacing,
fiscal_year_start = fiscal_year_start,
tscv_inital = hist_periods_80,
date_rm_regex = date_regex,
model_type = "single",
pca = run_pca))

try(combined_models_recipe_1 <- modeltime::add_modeltime_model(combined_models_recipe_1,
mdl_called,
location = "top") %>%
update_model_description(1, add_name),
silent = TRUE)

}

if(model_name %in% r2_models & ("R2" %in% recipes_to_run | sum(recipes_to_run == "all") == 1 | (is.null(recipes_to_run) & date_type %in% c("month", "quarter", "year")))){
Expand All @@ -426,8 +440,9 @@ construct_forecast_models <- function(full_data_tbl,
fiscal_year_start = fiscal_year_start,
tscv_inital = hist_periods_80,
date_rm_regex = date_regex,
model_type = "single"))

model_type = "single",
pca = run_pca))

try(combined_models_recipe_2 <- modeltime::add_modeltime_model(combined_models_recipe_2,
mdl_called,
location = "top") %>%
Expand Down Expand Up @@ -630,7 +645,8 @@ construct_forecast_models <- function(full_data_tbl,
fiscal_year_start = fiscal_year_start,
tscv_inital = "1 year",
date_rm_regex = date_regex,
model_type = "ensemble"))
model_type = "ensemble",
pca = FALSE))

try(combined_ensemble_models <- modeltime::add_modeltime_model(combined_ensemble_models,
mdl_ensemble,
Expand Down
8 changes: 6 additions & 2 deletions R/forecast_time_series.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
#' @param lag_periods List of values to use in creating lag features. Default of NULL automatically chooses these values
#' based on date_type.
#' @param rolling_window_periods List of values to use in creating rolling window features. Default of NULL automatically
#' chooses these values based on date_type.
#' chooses these values based on date type.
#' @param recipes_to_run List of recipes to run on multivariate models that can run different recipes. A value of NULL runs
#' all recipes, but only runs the R1 recipe for weekly and daily date types. A value of "all" runs all recipes, regardless
#' of date type. A list like c("R1") or c("R2") would only run models with the R1 or R2 recipe.
#' @param pca Run principle component analysis on any lagged features to speed up model run time. Default of NULL runs
#' PCA on day and week date types across all local multivariate models, and also for global models across all date types.
#' @param reticulate_environment File path to python environment to use when training gluonts deep learning models.
#' Only important when parallel_processing is not set to 'azure_batch'. Azure Batch should use its own docker image
#' that has python environment already installed.
Expand Down Expand Up @@ -116,6 +118,7 @@ forecast_time_series <- function(input_data,
lag_periods = NULL,
rolling_window_periods = NULL,
recipes_to_run = NULL,
pca = NULL,
reticulate_environment = NULL,
models_to_run = NULL,
models_not_to_run = NULL,
Expand Down Expand Up @@ -278,7 +281,8 @@ forecast_time_series <- function(input_data,
back_test_scenarios,
date_regex,
fiscal_year_start,
seasonal_periods)
seasonal_periods,
pca)

# * Run Forecast ----
if(forecast_approach == "bottoms_up" & length(unique(full_data_tbl$Combo)) > 1 & (sum(run_global_models == TRUE) == 1 | (is.null(run_global_models) & date_type %in% c("month", "quarter", "year"))) & run_local_models) {
Expand Down
Loading

0 comments on commit cba6266

Please sign in to comment.