diff --git a/README.md b/README.md index 891c39b..8ff8f59 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ poetry install If you want to train the chosen model and save it afterwards, run: ``` -poetry run python3 mlopscourse/train.py --model_type [chosen_model] +poetry run python3 commands.py train --model_type [chosen_model] ``` The available models are Random Forest (from the scikit-learn library) and CatBoost. @@ -58,5 +58,5 @@ If you want to infer a previously trained model, make sure you've placed the che `checkpoints/` and then run ``` -poetry run python3 mlopscourse/infer.py --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension] +poetry run python3 commands.py infer --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension] ``` diff --git a/commands.py b/commands.py new file mode 100644 index 0000000..27ce8c2 --- /dev/null +++ b/commands.py @@ -0,0 +1,37 @@ +import fire + +from mlopscourse.infer import infer +from mlopscourse.train import train + + +def outer_train(model_type: str) -> None: + """ + Trains the chosen model on the train split of the dataset and saves the checkpoint. + + Parameters + ---------- + model_type : str + The type of model for training. Should be "rf" for RandomForest and "cb" + for CatBoost. + """ + train(model_type) + + +def outer_infer(model_type: str, ckpt: str) -> None: + """ + Runs the chosen model on the test set of the dataset and calculates the R^2 metric. + + Parameters + ---------- + model_type : str + The type of model that was used for training. Should be "rf" for RandomForest + and "cb" for CatBoost. + ckpt : str + The filename inside 'checkpoints/' to load the model from. Should also contain the + filename extension. 
+ """ + infer(model_type, ckpt) + + +if __name__ == "__main__": + fire.Fire() diff --git a/mlopscourse/infer.py b/mlopscourse/infer.py index df6f4a3..183680b 100644 --- a/mlopscourse/infer.py +++ b/mlopscourse/infer.py @@ -3,7 +3,7 @@ import fire -from data.prepare_dataset import prepare_dataset +from .data.prepare_dataset import prepare_dataset def infer(model_type: str, ckpt: str) -> None: diff --git a/mlopscourse/train.py b/mlopscourse/train.py index b93c8e2..fa3d2d1 100644 --- a/mlopscourse/train.py +++ b/mlopscourse/train.py @@ -2,8 +2,8 @@ import fire -from data.prepare_dataset import prepare_dataset -from models.models_zoo import prepare_model +from .data.prepare_dataset import prepare_dataset +from .models.models_zoo import prepare_model def train(model_type: str) -> None: @@ -23,7 +23,7 @@ def train(model_type: str) -> None: _, numerical_features, categorical_features, - ) = prepare_dataset() + ) = prepare_dataset(print_info=True) model = prepare_model(model_type, numerical_features, categorical_features) print(f"Training the {model_type} model...")