This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit d7d1a66
feat: parallel training to disk (#492)
MartinBernstorff authored Apr 24, 2023
2 parents 4e02289 + a889c55 commit d7d1a66
Showing 6 changed files with 100 additions and 22 deletions.
30 changes: 19 additions & 11 deletions src/psycop_model_training/application_modules/train_model/main.py
@@ -1,5 +1,6 @@
 """Train a single model and evaluate it."""
 from pathlib import Path
+from typing import Optional
 
 import wandb
 from psycop_model_training.application_modules.wandb_handler import WandbHandler
@@ -21,16 +22,19 @@

 def get_eval_dir(cfg: FullConfigSchema) -> Path:
     """Get the directory to save evaluation results to."""
-    if wandb.run is not None and cfg.project.wandb.mode != "offline":
-        eval_dir_path = (
-            SHARED_RESOURCES_PATH
-            / cfg.project.name
-            / "model_eval"
-            / wandb.run.group
-            / wandb.run.name
-        )
-    else:
+    # If online
+    ovartaci_path = (
+        SHARED_RESOURCES_PATH
+        / cfg.project.name
+        / "model_eval"
+        / wandb.run.group  # type: ignore
+        / wandb.run.name  # type: ignore
+    )
+
+    if cfg.project.wandb.group == "integration_testing":
         eval_dir_path = PROJECT_ROOT / "tests" / "test_eval_results"
+    else:
+        eval_dir_path = ovartaci_path
 
     eval_dir_path.mkdir(parents=True, exist_ok=True)
 
@@ -40,6 +44,7 @@ def get_eval_dir(cfg: FullConfigSchema) -> Path:
 @wandb_alert_on_exception_return_terrible_auc
 def post_wandb_setup_train_model(
     cfg: FullConfigSchema,
+    override_output_dir: Optional[Path] = None,
 ) -> float:
     """Train a single model and evaluate it."""
     eval_dir_path = get_eval_dir(cfg)
@@ -58,8 +63,10 @@ def post_wandb_setup_train_model(
         n_splits=cfg.train.n_splits,
     )
 
+    eval_dir = eval_dir_path if override_output_dir is None else override_output_dir
+
     roc_auc = ModelEvaluator(
-        eval_dir_path=eval_dir_path,
+        eval_dir_path=eval_dir,
         cfg=cfg,
         pipe=pipe,
         eval_ds=eval_dataset,
@@ -71,13 +78,14 @@
 
 def train_model(
     cfg: FullConfigSchema,
+    override_output_dir: Optional[Path] = None,
 ) -> float:
     """Main function for training a single model."""
     WandbHandler(cfg=cfg).setup_wandb()
 
     # The try/except block ensures the process doesn't die on an exception,
     # but instead logs to wandb and starts another run with a new combination
     # of hyperparameters.
-    roc_auc = post_wandb_setup_train_model(cfg)
+    roc_auc = post_wandb_setup_train_model(cfg, override_output_dir=override_output_dir)
 
     return roc_auc
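
The new `override_output_dir` parameter threads through both entry points, so parallel jobs can write evaluation artifacts to separate directories instead of the shared wandb-derived path. A minimal sketch of a caller; the config-loader argument and the output path are illustrative, not part of this commit:

```python
from pathlib import Path

from psycop_model_training.application_modules.train_model.main import train_model
from psycop_model_training.config_schemas.conf_utils import load_test_cfg_as_pydantic

# Hypothetical config name; any FullConfigSchema instance works here.
cfg = load_test_cfg_as_pydantic("default_config.yaml")

# Each parallel worker points at its own eval directory.
roc_auc = train_model(cfg, override_output_dir=Path("/tmp/parallel_eval/worker_0"))
```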
8 changes: 8 additions & 0 deletions src/psycop_model_training/application_modules/wandb_handler.py

@@ -6,6 +6,7 @@
 from psycop_model_training.config_schemas.basemodel import BaseModel
 from psycop_model_training.config_schemas.full_config import FullConfigSchema
 from psycop_model_training.utils.utils import create_wandb_folders, flatten_nested_dict
+from random_word import RandomWords


 class WandbHandler:
@@ -51,11 +52,18 @@ def _get_cfg_as_dict(self) -> dict[str, Any]:

     def setup_wandb(self):
         """Setup wandb for the current run."""
+        run_name = (
+            None
+            if self.cfg.project.wandb.mode != "offline"
+            else RandomWords().get_random_word()
+        )
+
         wandb.init(
             project=f"{self.cfg.project.name}-baseline-model-training",
             reinit=True,
             mode=self.cfg.project.wandb.mode,
             group=self.cfg.project.wandb.group,
             config=self._get_cfg_as_dict(),
             entity=self.cfg.project.wandb.entity,
+            name=run_name,
         )
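
Offline wandb runs don't get a server-assigned name, yet `get_eval_dir` builds the output path from `wandb.run.name`, so offline mode needs a locally generated one. A standalone sketch of the same fallback, assuming only the `random-word` package:

```python
from typing import Optional

from random_word import RandomWords


def run_name_for(mode: str) -> Optional[str]:
    # Online runs keep wandb's generated name (None -> wandb assigns one);
    # offline runs get a random word so downstream paths stay readable.
    return None if mode != "offline" else RandomWords().get_random_word()


print(run_name_for("offline"))  # e.g. "cartwheel"
print(run_name_for("online"))   # None
```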
38 changes: 36 additions & 2 deletions src/psycop_model_training/training_output/… (artifact saver)

@@ -5,6 +5,7 @@

 import dill as pkl
 import pandas as pd
+import wandb
 from psycop_model_training.config_schemas.full_config import FullConfigSchema
 from psycop_model_training.training_output.dataclasses import EvalDataset, PipeMetadata
 from psycop_model_training.utils.utils import write_df_to_file
@@ -65,10 +66,41 @@ def eval_dataset_to_disk(eval_dataset: EvalDataset, file_path: Path) -> None:

         write_df_to_file(df=df, file_path=file_path)
 
+    def save_run_performance_to_group_parquet(
+        self,
+        roc_auc: float,
+        cfg: FullConfigSchema,
+    ):
+        # Get run performance row
+        lookahead_days = cfg.preprocessing.pre_split.min_lookahead_days
+
+        row = {
+            "run_name": wandb.run.name,  # type: ignore
+            "roc_auc": roc_auc,
+            "timestamp": pd.Timestamp.now(),
+            "lookahead_days": lookahead_days,
+            "model_name": cfg.model.name,
+        }
+
+        # Append row to parquet file in group dir
+        run_group_path = self.dir_path.parent
+        run_performance_path = (
+            run_group_path / f"{cfg.model.name}_{lookahead_days}.parquet"
+        )
+
+        if run_performance_path.exists():
+            df = pd.read_parquet(run_performance_path)
+            df = df.append(row, ignore_index=True)  # type: ignore
+        else:
+            df = pd.DataFrame([row])
+
+        df.to_parquet(run_performance_path, index=False)
+
     def save(
         self,
-        cfg: Optional[FullConfigSchema],
-        eval_dataset: Optional[EvalDataset],
+        roc_auc: float,
+        cfg: FullConfigSchema,
+        eval_dataset: EvalDataset,
         pipe_metadata: Optional[PipeMetadata],
         pipe: Optional[Pipeline],
     ) -> None:
@@ -93,6 +125,8 @@ def save(
         if pipe is not None:
             dump_to_pickle(pipe, self.dir_path / "pipe.pkl")
 
+        self.save_run_performance_to_group_parquet(roc_auc=roc_auc, cfg=cfg)
+
         log.info(  # pylint: disable=logging-fstring-interpolation
             f"Saved evaluation dataset, cfg and pipe metadata to {self.dir_path}",
         )
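
One caveat for reuse: `DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0, so `save_run_performance_to_group_parquet` as committed only runs on older pandas. A sketch of the same append-or-create step for newer pandas, reusing the `row` and `run_performance_path` names from above:

```python
import pandas as pd

# Same logic as above, for pandas >= 2.0 where DataFrame.append is gone.
new_row = pd.DataFrame([row])

if run_performance_path.exists():
    df = pd.concat([pd.read_parquet(run_performance_path), new_row], ignore_index=True)
else:
    df = new_row

df.to_parquet(run_performance_path, index=False)
```

Note that this read-modify-write on a file shared by the whole run group is not atomic; parallel runs finishing at the same instant could race on it.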
17 changes: 9 additions & 8 deletions src/psycop_model_training/training_output/model_evaluator.py
@@ -29,13 +29,13 @@ class ModelEvaluator:
     def _get_pipeline_metadata(self) -> PipeMetadata:
         pipe_metadata = PipeMetadata()
 
-        if hasattr(self.pipe["model"], "feature_importances_"):
+        if hasattr(self.pipe["model"], "feature_importances_"):  # type: ignore
             pipe_metadata.feature_importances = get_feature_importance_dict(
                 pipe=self.pipe,
             )
 
         if "preprocessing" in self.pipe and hasattr(
-            self.pipe["preprocessing"].named_steps,
+            self.pipe["preprocessing"].named_steps,  # type: ignore
             "feature_selection",
         ):
             pipe_metadata.selected_features = get_selected_features_dict(
@@ -77,16 +77,17 @@ def __init__(
 
     def evaluate_and_save_eval_data(self) -> float:
         """Evaluate the model and save artifacts."""
+        roc_auc: float = roc_auc_score(  # type: ignore
+            self.eval_ds.y,
+            self.eval_ds.y_hat_probs,
+        )
+
         self.disk_saver.save(
             cfg=self.cfg,
             eval_dataset=self.eval_ds,
             pipe=self.pipe,
             pipe_metadata=self.pipeline_metadata,
-        )
-
-        roc_auc: float = roc_auc_score(  # type: ignore
-            self.eval_ds.y,
-            self.eval_ds.y_hat_probs,
+            roc_auc=roc_auc,
         )
 
         wandb.log(
@@ -99,7 +100,7 @@ def evaluate_and_save_eval_data(self) -> float:
             },
         )
 
-        logging.info(  # pylint: disable=logging-not-lazy,logging-fstring-interpolation
+        logging.info(
            f"ROC AUC: {roc_auc}",
        )
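
The reordering above computes the AUC before `disk_saver.save` so the score can be persisted alongside the other artifacts. For reference, `roc_auc_score` on a toy eval set (the standard scikit-learn example values):

```python
from sklearn.metrics import roc_auc_score

y = [0, 0, 1, 1]                      # true labels
y_hat_probs = [0.1, 0.4, 0.35, 0.8]   # predicted probabilities
print(roc_auc_score(y, y_hat_probs))  # 0.75
```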
25 changes: 25 additions & 0 deletions tests/application_modules/test_model_evaluator.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+import pandas as pd
+from psycop_model_training.application_modules.train_model.main import train_model
+from psycop_model_training.config_schemas.full_config import FullConfigSchema
+
+
+def test_saving_results_to_parquet(
+    muteable_test_config: FullConfigSchema,
+    tmp_path: Path,
+):
+    """Test that model performance is saved to a parquet file for querying."""
+    cfg = muteable_test_config
+
+    for _ in [0, 1]:
+        # Run twice to ensure that we can also append to a file
+        train_model(cfg, override_output_dir=tmp_path / "run_eval")
+
+    run_performance_path = list(tmp_path.glob(r"*.parquet"))[0]
+    run_performance_df = pd.read_parquet(run_performance_path)
+
+    for info in ["run_name", "roc_auc", "timestamp", "lookahead_days", "model_name"]:
+        assert info in run_performance_df.columns
+
+    assert len(run_performance_df["run_name"].unique()) == 2
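
Because every run appends one row, the group-level parquet can be queried across runs, which is what the test's docstring promises. A sketch of pulling out the best run; the file name follows the `{model_name}_{lookahead_days}.parquet` pattern above and is illustrative:

```python
import pandas as pd

# Illustrative file name, e.g. model "xgboost" with a 30-day lookahead.
df = pd.read_parquet("xgboost_30.parquet")

best = df.sort_values("roc_auc", ascending=False).iloc[0]
print(f"best run: {best['run_name']} (AUC {best['roc_auc']:.3f})")
```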
4 changes: 3 additions & 1 deletion tests/application_modules/test_train_model.py
@@ -6,6 +6,7 @@
     post_wandb_setup_train_model,
     train_model,
 )
+from psycop_model_training.application_modules.wandb_handler import WandbHandler
 from psycop_model_training.config_schemas.conf_utils import (
     load_test_cfg_as_pydantic,
 )
@@ -76,7 +77,8 @@ def test_self_healing_nan_select_percentile(muteable_test_config: FullConfigSchema):
     cfg.preprocessing.post_split.feature_selection.name = "mutual_info_classif"
 
     # Train without the wrapper
-    with pytest.raises(ValueError, match=r".*Input X contains NaN.*"):
+    with pytest.raises(ValueError, match=r".*Input X contains NaN.*"):  # noqa
+        WandbHandler(cfg=cfg).setup_wandb()
         post_wandb_setup_train_model.__wrapped__(cfg)
 
     # Train with the wrapper
Expand Down
