Commit

adjust validation type for compete
pplonski committed Nov 30, 2020
1 parent 08fb410 commit f7105cd
Showing 4 changed files with 100 additions and 36 deletions.
supervised/base_automl.py (64 additions, 22 deletions)
@@ -7,6 +7,7 @@
import pandas as pd
import logging
import traceback
+import shutil
from tabulate import tabulate
from abc import ABC
from copy import deepcopy
@@ -461,8 +462,8 @@ def _handle_drastic_imbalance(self, X, y):
        classes, cnts = np.unique(y, return_counts=True)
        min_samples_per_class = 20
        if self._validation_strategy is not None:
-            min_samples_per_class = self._validation_strategy.get(
-                "k_folds", min_samples_per_class
+            min_samples_per_class = max(
+                min_samples_per_class, self._validation_strategy.get("k_folds", 0)
            )
        for i in range(len(classes)):
            if cnts[i] < min_samples_per_class:
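
The old lookup is subtly wrong: dict.get("k_folds", default) treats 20 only as a fallback, so a 5-fold setup would lower the per-class minimum from 20 down to 5. The max(...) form keeps whichever bound is larger. A standalone sketch of the difference, using a made-up strategy dict:

strategy = {"validation_type": "kfold", "k_folds": 5, "shuffle": True}

old = strategy.get("k_folds", 20)           # 5: k_folds overrides the floor of 20
new = max(20, strategy.get("k_folds", 0))   # 20: the floor of 20 samples is kept

print(old, new)  # 5 20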
@@ -625,31 +626,62 @@ def _apply_constraints(self):
            if a in self._algorithms:
                self._algorithms.remove(a)

-        # Change the validation type based on number of cells in the data
-        # cells = rows * cols
+        # Adjust the validation type based on the speed of Decision Tree training
        if (
            self._get_mode() == "Compete"
            and self._total_time_limit is not None
            and self.validation_strategy == "auto"
-            and self._validation_strategy["validation_type"]
-            != "split"  # split is the fastest validation type, no need to change
        ):
-            cells = self.n_rows_in_ * self.n_features_in_
-            if cells > 100e6:
-                self._validation_strategy = {
-                    "validation_type": "split",
-                    "train_ratio": 0.9,
-                    "shuffle": True,
-                }
-                if self._get_ml_task() != REGRESSION:
-                    self._validation_strategy["stratify"] = True
-            elif cells > 50e6:
-                self._validation_strategy = {
-                    "validation_type": "kfold",
-                    "k_folds": 5,
-                    "shuffle": True,
-                }
-                if self._get_ml_task() != REGRESSION:
-                    self._validation_strategy["stratify"] = True
+            # the validation will be adjusted after the first Decision Tree is
+            # trained on a train/test split (a single fold)
+            self._adjust_validation = True
+            self._validation_strategy = self._fastest_validation()


+    def _fastest_validation(self):
+        strategy = {"validation_type": "split", "train_ratio": 0.9, "shuffle": True}
+        if self._get_ml_task() != REGRESSION:
+            strategy["stratify"] = True
+        return strategy

+    def _set_adjusted_validation(self):
+        if self._validation_strategy["validation_type"] != "split":
+            return
+        train_time = self._models[-1].get_train_time()
+        # multiply the Decision Tree training time by 5.0 to get a rough
+        # estimate of how much time the other algorithms will need
+        one_fold_time = train_time * 5.0
+        # we want to train at least 10 models
+        min_model_cnt = 10.0
+        # the number of folds we can afford during the training
+        folds_cnt = np.round(self._total_time_limit / one_fold_time / min_model_cnt)

+        # adjust the validation if possible
+        if folds_cnt >= 5.0:
+            self.verbose_print(f"Adjust validation. Remove: {self._model_paths[0]}")
+            k_folds = 5
+            if folds_cnt >= 15:
+                k_folds = 10
+            self._validation_strategy["validation_type"] = "kfold"
+            del self._validation_strategy["train_ratio"]
+            self._validation_strategy["k_folds"] = k_folds
+            self.tuner._validation_strategy = self._validation_strategy
+            shutil.rmtree(self._model_paths[0], ignore_errors=True)
+            del self._models[0]
+            del self._model_paths[0]
+            del self.tuner._unique_params_keys[0]
+            self._adjust_validation = False
+            cv = []
+            if self._validation_strategy.get("shuffle", False):
+                cv += ["Shuffle"]
+            if self._validation_strategy.get("stratify", False):
+                cv += ["Stratify"]

+            self.verbose_print(f"Validation strategy: {k_folds}-fold CV {','.join(cv)}")


    def _fit(self, X, y):
        """Fits the AutoML model with data"""
@@ -685,6 +717,7 @@ def _fit(self, X, y):
        self._top_models_to_improve = self._get_top_models_to_improve()
        self._random_state = self._get_random_state()

+        self._adjust_validation = False
        self._apply_constraints()

        try:
@@ -733,6 +766,7 @@ def _fit(self, X, y):
                self._features_selection,
                self._train_ensemble,
                self._stack_models,
+                self._adjust_validation,
                self._random_state,
            )
            self.tuner = tuner
@@ -811,6 +845,10 @@ def _fit(self, X, y):
params["status"] = "trained" if trained else "skipped"
params["final_loss"] = self._models[-1].get_final_loss()
params["train_time"] = self._models[-1].get_train_time()

if self._adjust_validation and len(self._models) == 1:
self._set_adjusted_validation()

except Exception as e:
self._update_errors_report(
params.get("name"), str(e) + "\n" + traceback.format_exc()
@@ -1351,7 +1389,11 @@ def _validate_eval_metric(self):
                Use 'logloss'"
            )

-        elif self._get_ml_task() == REGRESSION and self.eval_metric not in ["rmse", "mse", "mae"]:
+        elif self._get_ml_task() == REGRESSION and self.eval_metric not in [
+            "rmse",
+            "mse",
+            "mae",
+        ]:
            raise ValueError(
                f"Metric {self.eval_metric} is not allowed in ML task: {self._get_ml_task()}. \
                Use 'rmse'"
supervised/tuner/mljar_tuner.py (33 additions, 12 deletions)
@@ -34,6 +34,7 @@ def __init__(
        features_selection,
        train_ensemble,
        stack_models,
+        adjust_validation,
        seed,
    ):
        logger.debug("MljarTuner.__init__")
@@ -49,23 +50,18 @@
        self._features_selection = features_selection
        self._train_ensemble = train_ensemble
        self._stack_models = stack_models
+        self._adjust_validation = adjust_validation
        self._seed = seed

        self._unique_params_keys = []

    def steps(self):

-        all_steps = [
-            "simple_algorithms",
-            "default_algorithms",
-            # "not_so_random",
-            # "golden_features",
-            # "features_selection",
-            # "hill_climbing",
-            # "ensemble",
-            # "stack",
-            # "ensemble_stack",
-        ]
+        all_steps = []
+        if self._adjust_validation:
+            all_steps += ["adjust_validation"]

+        all_steps += ["simple_algorithms", "default_algorithms"]
        if self._start_random_models > 1:
            all_steps += ["not_so_random"]
        if self._golden_features:
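
With the new flag set, the tuner prepends a one-off step that trains the single Decision Tree used for the timing estimate. A hypothetical steps() result for a run with adjust_validation=True and start_random_models > 1 (the tail depends on the remaining flags):

steps = ["adjust_validation", "simple_algorithms", "default_algorithms", "not_so_random"]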
@@ -89,7 +85,9 @@ def get_model_name(self, model_type, models_cnt, special=""):
    def generate_params(self, step, models, results_path, stacked_models):

        models_cnt = len(models)
-        if step == "simple_algorithms":
+        if step == "adjust_validation":
+            return self.adjust_validation_params()
+        elif step == "simple_algorithms":
            return self.simple_algorithms_params()
        elif step == "default_algorithms":
            return self.default_params(models_cnt)
@@ -188,6 +186,29 @@ def get_params_stack_models(self, stacked_models):
            generated_params += [params]
        return generated_params

+    def adjust_validation_params(self):
+        models_cnt = 0
+        generated_params = []
+        for model_type in ["Decision Tree"]:
+            models_to_check = 1

+            logger.info(f"Generate parameters for {model_type} (#{models_cnt + 1})")
+            params = self._get_model_params(model_type, seed=1)
+            if params is None:
+                continue

+            params["name"] = self.get_model_name(model_type, models_cnt + 1)
+            params["status"] = "initialized"
+            params["final_loss"] = None
+            params["train_time"] = None

+            unique_params_key = MljarTuner.get_params_key(params)
+            if unique_params_key not in self._unique_params_keys:
+                generated_params += [params]
+                self._unique_params_keys += [unique_params_key]
+                models_cnt += 1
+        return generated_params

    def simple_algorithms_params(self):
        models_cnt = 0
        generated_params = []
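The step generates exactly one Decision Tree configuration. An illustrative sketch of the single dict returned by adjust_validation_params() follows; the name format and the model-specific fields are assumptions, produced by get_model_name and _get_model_params:

params = {
    "name": "1_DecisionTree",  # assumed format, built by get_model_name(...)
    "status": "initialized",
    "final_loss": None,
    "train_time": None,
    # ... plus the Decision Tree hyperparameters from _get_model_params(...)
}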
supervised/utils/metric.py (1 addition, 1 deletion)
@@ -45,7 +45,7 @@ def __init__(self, params):
            raise MetricException("Metric name not defined")
        self.minimize_direction = self.name in [
            "logloss",
-            "auc", # negative auc
+            "auc",  # negative auc
            "rmse",
            "mae",
            "mse",
tests/tests_tuner/test_hill_climbing.py (2 additions, 1 deletion)
@@ -57,11 +57,12 @@ def test_hill_climbing(self):
            validation_strategy={},
            explain_level=2,
            data_info={"columns_info": [], "target_info": []},
-            seed=12,
            golden_features=False,
            features_selection=False,
            train_ensemble=False,
            stack_models=False,
+            adjust_validation=False,
+            seed=12,
        )
        ind = 121
        score = 0.1
