From ecbe8a0fcab933b7eca803769222b1cb250deb75 Mon Sep 17 00:00:00 2001 From: Svyatoslav Date: Sun, 8 Oct 2023 20:41:29 +0300 Subject: [PATCH] Added the only one entrypoint --- commands.py. Corrected imports. --- README.md | 4 ++-- commands.py | 37 +++++++++++++++++++++++++++++++++++++ mlopscourse/infer.py | 2 +- mlopscourse/train.py | 6 +++--- 4 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 commands.py diff --git a/README.md b/README.md index 891c39b..8ff8f59 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ poetry install If you want to train the chosen model and save it afterwards, run: ``` -poetry run python3 mlopscourse/train.py --model_type [chosen_model] +poetry run python3 commands.py train --model_type [chosen_model] ``` The available models are Random Forest (from the scikit-learn library) and CatBoost. @@ -58,5 +58,5 @@ If you want to infer a previously trained model, make sure you've placed the che `checkpoints/` and then run ``` -poetry run python3 mlopscourse/infer.py --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension] +poetry run python3 commands.py infer --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension] ``` diff --git a/commands.py b/commands.py new file mode 100644 index 0000000..27ce8c2 --- /dev/null +++ b/commands.py @@ -0,0 +1,37 @@ +import fire + +from mlopscourse.infer import infer +from mlopscourse.train import train + + +def outer_train(model_type: str) -> None: + """ + Trains the chosen model on the train split of the dataset and saves the checkpoint. + + Parameters + ---------- + model_type : str + The type of model for training. Should be "rf" for RandomForest and "cb" + for CatBoost. + """ + train(model_type) + + +def outer_infer(model_type: str, ckpt: str) -> None: + """ + Runs the chosen model on the test set of the dataset and calculates the R^2 metric. + + Parameters + ---------- + model_type : str + The type of model that was used for training. 
Should be "rf" for RandomForest + and "cb" for CatBoost. + ckpt : str + The filename inside 'checkpoints/' to load the model from. Should also contain the + filename extension. + """ + infer(model_type, ckpt) + + +if __name__ == "__main__": + fire.Fire() diff --git a/mlopscourse/infer.py b/mlopscourse/infer.py index df6f4a3..183680b 100644 --- a/mlopscourse/infer.py +++ b/mlopscourse/infer.py @@ -3,7 +3,7 @@ import fire -from data.prepare_dataset import prepare_dataset +from .data.prepare_dataset import prepare_dataset def infer(model_type: str, ckpt: str) -> None: diff --git a/mlopscourse/train.py b/mlopscourse/train.py index b93c8e2..fa3d2d1 100644 --- a/mlopscourse/train.py +++ b/mlopscourse/train.py @@ -2,8 +2,8 @@ import fire -from data.prepare_dataset import prepare_dataset -from models.models_zoo import prepare_model +from .data.prepare_dataset import prepare_dataset +from .models.models_zoo import prepare_model def train(model_type: str) -> None: @@ -23,7 +23,7 @@ def train(model_type: str) -> None: _, numerical_features, categorical_features, - ) = prepare_dataset() + ) = prepare_dataset(print_info=True) model = prepare_model(model_type, numerical_features, categorical_features) print(f"Training the {model_type} model...")