Skip to content

Commit

Permalink
Fixed bug with converting LabelEncoder to ONNX
Browse files: browse the repository at this point in the history
  • Loading branch information
TopCoder2K committed Dec 12, 2023
1 parent aff1f38 commit 38bdf23
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 9 deletions.
6 changes: 5 additions & 1 deletion mlopscourse/data/prepare_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def prepare_dataset(print_info: bool = True) -> None:
# Because of this rare category, we collapse it into "rain".
X["weather"].replace(to_replace="heavy_rain", value="rain", inplace=True)

# We can see that we have data from two years. We use the first year
# Since ONNX LabelEncoder doesn't support booleans, the "False"/"True" values
# must be converted to the integers 0/1
X.replace({"False": 0, "True": 1}, inplace=True)

# The dataset contains data from two years. We use the first year
# to train the model and the second year to test the model.
mask_training = X["year"] == 0
X = X.drop(columns=["year"])
Expand Down
4 changes: 2 additions & 2 deletions mlopscourse/data/test_split.csv.dvc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
outs:
- md5: 8504078666337f0330b6e6aeb2321d7f
size: 587277
- md5: 337984495de9ee90281613b25bba93df
size: 523620
hash: md5
path: test_split.csv
4 changes: 2 additions & 2 deletions mlopscourse/data/train_split.csv.dvc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
outs:
- md5: 990eb9f38a19ff65662e2082b8ae6221
size: 573490
- md5: 37f6c642fdde95926945240a5f573415
size: 510480
hash: md5
path: train_split.csv
8 changes: 5 additions & 3 deletions mlopscourse/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import mlflow
from hydra import compose
from omegaconf import DictConfig, OmegaConf
from skl2onnx import to_onnx

from .data.prepare_dataset import load_dataset
from .models.models_zoo import prepare_model
Expand Down Expand Up @@ -64,9 +65,10 @@ def train(self) -> None:
)
model.log_fis_and_metrics(exp_id, X_train.columns)
else:
mlflow.sklearn.save_model(
model.model,
f"checkpoints/mlflow_{self.cfg.model.name}_ckpt/",
model_onnx = to_onnx(model.model, X=X_train.iloc[:1], verbose=1)
mlflow.onnx.save_model(
model_onnx,
f"checkpoints/mlflow_{self.cfg.model.name}_onnx_ckpt/",
signature=signature,
)
model.log_fis_and_metrics(exp_id, X_train, y_train)
Expand Down
201 changes: 200 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ catboost = "^1.2.2"
fire = "^0.5.0"
hydra-core = "^1.3.2"
mlflow = "^2.8.1"
skl2onnx = "^1.15.0"
onnxruntime = "^1.16.3"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.4.0"
Expand Down

0 comments on commit 38bdf23

Please sign in to comment.