Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mitokic committed Oct 28, 2024
1 parent eacf6ea commit 990ba4c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion R/ensemble_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ ensemble_models <- function(run_info,
avail_arg_list <- list(
"train_data" = prep_ensemble_tbl %>% dplyr::select(-Train_Test_ID),
"model_type" = "ensemble",
"pca" = FALSE,
"pca" = FALSE,
"multistep" = FALSE
)

Expand Down
4 changes: 2 additions & 2 deletions R/feature_selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ run_feature_selection <- function(input_data,

return(fs_list)
}

# check for multiple time series
if (length(unique(input_data$Combo)) > 1) {
global <- TRUE
Expand Down Expand Up @@ -103,7 +103,7 @@ run_feature_selection <- function(input_data,
) %>%
dplyr::select(Feature, Vote, Auto_Accept)

# don't run leave one feature out process for daily, weekly, or global model data
# don't run boruta for daily, weekly, or global model data
boruta_results <- tibble::tibble()
} else {
if (!fast) { # full implementation
Expand Down
2 changes: 1 addition & 1 deletion R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ list_global_models <- function() {
#' @noRd
list_multivariate_models <- function() {
list <- c(
"cubist", "glmnet", "mars", "svm-poly", "svm-rbf", "xgboost",
"cubist", "glmnet", "mars", "svm-poly", "svm-rbf", "xgboost",
"arima-boost", "arimax", "prophet-boost", "prophet-xregs",
"nnetar-xregs"
)
Expand Down
16 changes: 8 additions & 8 deletions R/train_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ train_models <- function(run_info,
dplyr::select(Model_Workflow)

workflow <- workflow$Model_Workflow[[1]]

if (nrow(prep_data) > 500 & model == "xgboost") {
# update xgboost model to use 'hist' tree method to speed up training
workflow <- workflows::update_model(workflow,
Expand All @@ -429,8 +429,10 @@ train_models <- function(run_info,

if (combo_hash == "All-Data") {
# adjust column types to match original data
prep_data <- adjust_column_types(prep_data,
workflows::extract_recipe(workflow, estimated = FALSE))
prep_data <- adjust_column_types(
prep_data,
workflows::extract_recipe(workflow, estimated = FALSE)
)
}

if (feature_selection & model %in% fs_model_list) {
Expand Down Expand Up @@ -1013,16 +1015,16 @@ adjust_column_types <- function(data, recipe) {
expected_types <- recipe$var_info %>%
dplyr::select(variable, type) %>%
dplyr::mutate(type = purrr::map_chr(type, ~ .x[[1]]))

# Identify and coerce mismatched columns
for (i in seq_len(nrow(expected_types))) {
col_name <- expected_types$variable[i]
expected_type <- expected_types$type[i]

# Check if column exists and type mismatch
if (col_name %in% names(data)) {
actual_type <- class(data[[col_name]])[1]

# Convert if types are different
if (expected_type == "string" && actual_type != "character") {
data[[col_name]] <- as.character(data[[col_name]])
Expand All @@ -1035,5 +1037,3 @@ adjust_column_types <- function(data, recipe) {
}
return(data)
}


0 comments on commit 990ba4c

Please sign in to comment.