Skip to content

Commit

Permalink
Added the only one entrypoint --- commands.py. Corrected imports.
Browse files Browse the repository at this point in the history
  • Loading branch information
TopCoder2K committed Oct 8, 2023
1 parent ec4148b commit ecbe8a0
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ poetry install
If you want to train the chosen model and save it afterwards, run:

```
poetry run python3 mlopscourse/train.py --model_type [chosen_model]
poetry run python3 commands.py train --model_type [chosen_model]
```

The available models are Random Forest (from the scikit-learn library) and CatBoost.
Expand All @@ -58,5 +58,5 @@ If you want to infer a previously trained model, make sure you've placed the che
`checkpoints/` and then run

```
poetry run python3 mlopscourse/infer.py --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension]
poetry run python3 commands.py infer --model_type [chosen_model] --ckpt [checkpoint_filename_with_extension]
```
37 changes: 37 additions & 0 deletions commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import fire

from mlopscourse.infer import infer
from mlopscourse.train import train


def outer_train(model_type: str) -> None:
"""
Trains the chosen model on the train split of the dataset and saves the checkpoint.
Parameters
----------
model_type : str
The type of model for training. Should be "rf" for RandomForest and "cb"
for CatBoost.
"""
train(model_type)


def outer_infer(model_type: str, ckpt: str) -> None:
"""
Runs the chosen model on the test set of the dataset and calculates the R^2 metric.
Parameters
----------
model_type : str
The type of model that was used for training. Should be "rf" for RandomForest
and "cb" for CatBoost.
ckpt : str
The filename inside 'checkpoint/' to load the model from. Should also contain the
the filename extension.
"""
infer(model_type, ckpt)


if __name__ == "__main__":
fire.Fire()
2 changes: 1 addition & 1 deletion mlopscourse/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import fire

from data.prepare_dataset import prepare_dataset
from .data.prepare_dataset import prepare_dataset


def infer(model_type: str, ckpt: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions mlopscourse/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import fire

from data.prepare_dataset import prepare_dataset
from models.models_zoo import prepare_model
from .data.prepare_dataset import prepare_dataset
from .models.models_zoo import prepare_model


def train(model_type: str) -> None:
Expand All @@ -23,7 +23,7 @@ def train(model_type: str) -> None:
_,
numerical_features,
categorical_features,
) = prepare_dataset()
) = prepare_dataset(print_info=True)
model = prepare_model(model_type, numerical_features, categorical_features)

print(f"Training the {model_type} model...")
Expand Down

0 comments on commit ecbe8a0

Please sign in to comment.