From 38bdf231ab8e184109e04da86aad901f0b1008a8 Mon Sep 17 00:00:00 2001 From: Svyatoslav Date: Tue, 12 Dec 2023 11:27:25 +0300 Subject: [PATCH] Fixed bug with converting LabelEncoder to ONNX --- mlopscourse/data/prepare_dataset.py | 6 +- mlopscourse/data/test_split.csv.dvc | 4 +- mlopscourse/data/train_split.csv.dvc | 4 +- mlopscourse/train.py | 8 +- poetry.lock | 201 ++++++++++++++++++++++++++- pyproject.toml | 2 + 6 files changed, 216 insertions(+), 9 deletions(-) diff --git a/mlopscourse/data/prepare_dataset.py b/mlopscourse/data/prepare_dataset.py index 2735ec4..bcfcc62 100644 --- a/mlopscourse/data/prepare_dataset.py +++ b/mlopscourse/data/prepare_dataset.py @@ -13,7 +13,11 @@ def prepare_dataset(print_info: bool = True) -> None: # Because of this rare category, we collapse it into "rain". X["weather"].replace(to_replace="heavy_rain", value="rain", inplace=True) - # We can see that we have data from two years. We use the first year + # Since ONNX LabelEncoder doesn't support booleans, boolean columns must be + # converted to integer columns + X.replace({"False": 0, "True": 1}, inplace=True) + + # The dataset contains data from two years. We use the first year # to train the model and the second year to test the model. mask_training = X["year"] == 0 X = X.drop(columns=["year"]) diff --git a/mlopscourse/data/test_split.csv.dvc b/mlopscourse/data/test_split.csv.dvc index f0c1da7..539593f 100644 --- a/mlopscourse/data/test_split.csv.dvc +++ b/mlopscourse/data/test_split.csv.dvc @@ -1,5 +1,5 @@ outs: -- md5: 8504078666337f0330b6e6aeb2321d7f - size: 587277 +- md5: 337984495de9ee90281613b25bba93df + size: 523620 hash: md5 path: test_split.csv diff --git a/mlopscourse/data/train_split.csv.dvc b/mlopscourse/data/train_split.csv.dvc index 105ad92..bcfe87e 100644 --- a/mlopscourse/data/train_split.csv.dvc +++ b/mlopscourse/data/train_split.csv.dvc @@ -1,5 +1,5 @@ outs: -- md5: 990eb9f38a19ff65662e2082b8ae6221 - size: 573490 +- md5: 37f6c642fdde95926945240a5f573415 + size: 510480 hash: md5 path: train_split.csv diff --git a/mlopscourse/train.py b/mlopscourse/train.py index df20409..ee1356d 100644 --- a/mlopscourse/train.py +++ b/mlopscourse/train.py @@ -4,6 +4,7 @@ import mlflow from hydra import compose from omegaconf import DictConfig, OmegaConf +from skl2onnx import to_onnx from .data.prepare_dataset import load_dataset from .models.models_zoo import prepare_model @@ -64,9 +65,10 @@ def train(self) -> None: ) model.log_fis_and_metrics(exp_id, X_train.columns) else: - mlflow.sklearn.save_model( - model.model, - f"checkpoints/mlflow_{self.cfg.model.name}_ckpt/", + model_onnx = to_onnx(model.model, X=X_train.iloc[:1], verbose=1) + mlflow.onnx.save_model( + model_onnx, + f"checkpoints/mlflow_{self.cfg.model.name}_onnx_ckpt/", signature=signature, ) model.log_fis_and_metrics(exp_id, X_train, y_train) diff --git a/poetry.lock b/poetry.lock index 310d8c4..a16650b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -665,6 +665,23 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coloredlogs" +version = "15.0.1" +description = "Colored terminal output for Python's logging module" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, + {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"}, +] + +[package.dependencies] +humanfriendly = ">=9.1" + +[package.extras] +cron = ["capturer (>=2.4)"] + [[package]] name = "configobj" version = "5.0.8" @@ -1264,6 +1281,17 @@ Werkzeug = ">=3.0.0" async = ["asgiref (>=3.2)"] dotenv = ["python-dotenv"] +[[package]] +name = "flatbuffers" +version = "23.5.26" +description = "The FlatBuffers serialization format for Python" +optional = false +python-versions = "*" +files = [ + {file = "flatbuffers-23.5.26-py2.py3-none-any.whl", hash = "sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1"}, + {file = "flatbuffers-23.5.26.tar.gz", hash = "sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89"}, +] + [[package]] name = "flatten-dict" version = "0.4.2" @@ -1758,6 +1786,20 @@ files = [ [package.dependencies] pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} +[[package]] +name = "humanfriendly" +version = "10.0" +description = "Human friendly output for text interfaces using Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, + {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"}, +] + +[package.dependencies] +pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} + [[package]] name = "hydra-core" version = "1.3.2" @@ -2284,6 +2326,23 @@ gateway = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>= sqlserver = ["mlflow-dbstore"] xethub = ["mlflow-xethub"] +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "multidict" version = "6.0.4" @@ -2489,6 +2548,105 @@ files = [ antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" +[[package]] +name = "onnx" +version = "1.15.0" +description = "Open Neural Network Exchange" +optional = false +python-versions = ">=3.8" +files = [ + {file = "onnx-1.15.0-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:51cacb6aafba308aaf462252ced562111f6991cdc7bc57a6c554c3519453a8ff"}, + {file = "onnx-1.15.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:0aee26b6f7f7da7e840de75ad9195a77a147d0662c94eaa6483be13ba468ffc1"}, + {file = "onnx-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baf6ef6c93b3b843edb97a8d5b3d229a1301984f3f8dee859c29634d2083e6f9"}, + {file = "onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ed899fe6000edc05bb2828863d3841cfddd5a7cf04c1a771f112e94de75d9f"}, + {file = "onnx-1.15.0-cp310-cp310-win32.whl", hash = "sha256:f1ad3d77fc2f4b4296f0ac2c8cadd8c1dcf765fc586b737462d3a0fe8f7c696a"}, + {file = "onnx-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:ca4ebc4f47109bfb12c8c9e83dd99ec5c9f07d2e5f05976356c6ccdce3552010"}, + {file = "onnx-1.15.0-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:233ffdb5ca8cc2d960b10965a763910c0830b64b450376da59207f454701f343"}, + {file = "onnx-1.15.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:51fa79c9ea9af033638ec51f9177b8e76c55fad65bb83ea96ee88fafade18ee7"}, + {file = "onnx-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f277d4861729f5253a51fa41ce91bfec1c4574ee41b5637056b43500917295ce"}, + {file = "onnx-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8a7c94d2ebead8f739fdb70d1ce5a71726f4e17b3e5b8ad64455ea1b2801a85"}, + {file = "onnx-1.15.0-cp311-cp311-win32.whl", hash = "sha256:17dcfb86a8c6bdc3971443c29b023dd9c90ff1d15d8baecee0747a6b7f74e650"}, + {file = "onnx-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:60a3e28747e305cd2e766e6a53a0a6d952cf9e72005ec6023ce5e07666676a4e"}, + {file = "onnx-1.15.0-cp38-cp38-macosx_10_12_universal2.whl", hash = "sha256:6b5c798d9e0907eaf319e3d3e7c89a2ed9a854bcb83da5fefb6d4c12d5e90721"}, + {file = "onnx-1.15.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:a4f774ff50092fe19bd8f46b2c9b27b1d30fbd700c22abde48a478142d464322"}, + {file = "onnx-1.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2b0e7f3938f2d994c34616bfb8b4b1cebbc4a0398483344fe5e9f2fe95175e6"}, + {file = "onnx-1.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49cebebd0020a4b12c1dd0909d426631212ef28606d7e4d49463d36abe7639ad"}, + {file = "onnx-1.15.0-cp38-cp38-win32.whl", hash = "sha256:1fdf8a3ff75abc2b32c83bf27fb7c18d6b976c9c537263fadd82b9560fe186fa"}, + {file = "onnx-1.15.0-cp38-cp38-win_amd64.whl", hash = "sha256:763e55c26e8de3a2dce008d55ae81b27fa8fb4acbb01a29b9f3c01f200c4d676"}, + {file = "onnx-1.15.0-cp39-cp39-macosx_10_12_universal2.whl", hash = "sha256:b2d5e802837629fc9c86f19448d19dd04d206578328bce202aeb3d4bedab43c4"}, + {file = "onnx-1.15.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:9a9cfbb5e5d5d88f89d0dfc9df5fb858899db874e1d5ed21e76c481f3cafc90d"}, + {file = "onnx-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f472bbe5cb670a0a4a4db08f41fde69b187a009d0cb628f964840d3f83524e9"}, + {file = "onnx-1.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf2de9bef64792e5b8080c678023ac7d2b9e05d79a3e17e92cf6a4a624831d2"}, + {file = "onnx-1.15.0-cp39-cp39-win32.whl", hash = "sha256:ef4d9eb44b111e69e4534f3233fc2c13d1e26920d24ae4359d513bd54694bc6d"}, + {file = "onnx-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:95d7a3e2d79d371e272e39ae3f7547e0b116d0c7f774a4004e97febe6c93507f"}, + {file = "onnx-1.15.0.tar.gz", hash = "sha256:b18461a7d38f286618ca2a6e78062a2a9c634ce498e631e708a8041b00094825"}, +] + +[package.dependencies] +numpy = "*" +protobuf = ">=3.20.2" + +[package.extras] +reference = ["Pillow", "google-re2"] + +[[package]] +name = "onnxconverter-common" +version = "1.13.0" +description = "ONNX Converter and Optimization Tools" +optional = false +python-versions = "*" +files = [ + {file = "onnxconverter-common-1.13.0.tar.gz", hash = "sha256:03db8a6033a3d6590f22df3f64234079caa826375d1fcb0b37b8123c06bf598c"}, + {file = "onnxconverter_common-1.13.0-py2.py3-none-any.whl", hash = "sha256:5ee1c025ef6c3b4abaede8425bc6b393248941a6cf8c21563d0d0e3f04634a0a"}, +] + +[package.dependencies] +numpy = "*" +onnx = "*" +packaging = "*" +protobuf = "*" + +[[package]] +name = "onnxruntime" +version = "1.16.3" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +optional = false +python-versions = "*" +files = [ + {file = "onnxruntime-1.16.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:3bc41f323ac77acfed190be8ffdc47a6a75e4beeb3473fbf55eeb075ccca8df2"}, + {file = "onnxruntime-1.16.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:212741b519ee61a4822c79c47147d63a8b0ffde25cd33988d3d7be9fbd51005d"}, + {file = "onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f91f5497fe3df4ceee2f9e66c6148d9bfeb320cd6a71df361c66c5b8bac985a"}, + {file = "onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b1fc269cabd27f129fb9058917d6fdc89b188c49ed8700f300b945c81f889"}, + {file = "onnxruntime-1.16.3-cp310-cp310-win32.whl", hash = "sha256:f36b56a593b49a3c430be008c2aea6658d91a3030115729609ec1d5ffbaab1b6"}, + {file = "onnxruntime-1.16.3-cp310-cp310-win_amd64.whl", hash = "sha256:3c467eaa3d2429c026b10c3d17b78b7f311f718ef9d2a0d6938e5c3c2611b0cf"}, + {file = "onnxruntime-1.16.3-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:a225bb683991001d111f75323d355b3590e75e16b5e0f07a0401e741a0143ea1"}, + {file = "onnxruntime-1.16.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9aded21fe3d898edd86be8aa2eb995aa375e800ad3dfe4be9f618a20b8ee3630"}, + {file = "onnxruntime-1.16.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00cccc37a5195c8fca5011b9690b349db435986bd508eb44c9fce432da9228a4"}, + {file = "onnxruntime-1.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e253e572021563226a86f1c024f8f70cdae28f2fb1cc8c3a9221e8b1ce37db5"}, + {file = "onnxruntime-1.16.3-cp311-cp311-win32.whl", hash = "sha256:a82a8f0b4c978d08f9f5c7a6019ae51151bced9fd91e5aaa0c20a9e4ac7a60b6"}, + {file = "onnxruntime-1.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:78d81d9af457a1dc90db9a7da0d09f3ccb1288ea1236c6ab19f0ca61f3eee2d3"}, + {file = "onnxruntime-1.16.3-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:04ebcd29c20473596a1412e471524b2fb88d55e6301c40b98dd2407b5911595f"}, + {file = "onnxruntime-1.16.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9996bab0f202a6435ab867bc55598f15210d0b72794d5de83712b53d564084ae"}, + {file = "onnxruntime-1.16.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b8f5083f903408238883821dd8c775f8120cb4a604166dbdabe97f4715256d5"}, + {file = "onnxruntime-1.16.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c2dcf1b70f8434abb1116fe0975c00e740722aaf321997195ea3618cc00558e"}, + {file = "onnxruntime-1.16.3-cp38-cp38-win32.whl", hash = "sha256:d4a0151e1accd04da6711f6fd89024509602f82c65a754498e960b032359b02d"}, + {file = "onnxruntime-1.16.3-cp38-cp38-win_amd64.whl", hash = "sha256:e8aa5bba78afbd4d8a2654b14ec7462ff3ce4a6aad312a3c2d2c2b65009f2541"}, + {file = "onnxruntime-1.16.3-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:6829dc2a79d48c911fedaf4c0f01e03c86297d32718a3fdee7a282766dfd282a"}, + {file = "onnxruntime-1.16.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:76f876c53bfa912c6c242fc38213a6f13f47612d4360bc9d599bd23753e53161"}, + {file = "onnxruntime-1.16.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4137e5d443e2dccebe5e156a47f1d6d66f8077b03587c35f11ee0c7eda98b533"}, + {file = "onnxruntime-1.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c56695c1a343c7c008b647fff3df44da63741fbe7b6003ef576758640719be7b"}, + {file = "onnxruntime-1.16.3-cp39-cp39-win32.whl", hash = "sha256:985a029798744ce4743fcf8442240fed35c8e4d4d30ec7d0c2cdf1388cd44408"}, + {file = "onnxruntime-1.16.3-cp39-cp39-win_amd64.whl", hash = "sha256:28ff758b17ce3ca6bcad3d936ec53bd7f5482e7630a13f6dcae518eba8f71d85"}, +] + +[package.dependencies] +coloredlogs = "*" +flatbuffers = "*" +numpy = ">=1.21.6" +packaging = "*" +protobuf = "*" +sympy = "*" + [[package]] name = "orjson" version = "3.9.10" @@ -3185,6 +3343,17 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pyreadline3" +version = "3.4.1" +description = "A python implementation of GNU readline." +optional = false +python-versions = "*" +files = [ + {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, + {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, +] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -3632,6 +3801,22 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "skl2onnx" +version = "1.15.0" +description = "Convert scikit-learn models to ONNX" +optional = false +python-versions = "*" +files = [ + {file = "skl2onnx-1.15.0-py2.py3-none-any.whl", hash = "sha256:13a9ea5d50619ce42381c67001db8c87ce574a459a8f0738b45d2f4b93f465f6"}, + {file = "skl2onnx-1.15.0.tar.gz", hash = "sha256:05b2c2643ad0357ec1ea684d138438a2df657df828e57d07cb78c2e76be20e37"}, +] + +[package.dependencies] +onnx = ">=1.2.1" +onnxconverter-common = ">=1.7.0" +scikit-learn = ">=0.19" + [[package]] name = "smmap" version = "5.0.1" @@ -3766,6 +3951,20 @@ pygtrie = "*" dev = ["mypy (==0.971)", "pyinstaller", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-benchmark", "pytest-cov (==3.0.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)"] tests = ["mypy (==0.971)", "pyinstaller", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-benchmark", "pytest-cov (==3.0.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)"] +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, +] + +[package.dependencies] +mpmath = ">=0.19" + [[package]] name = "tabulate" version = "0.9.0" @@ -4155,4 +4354,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">= 3.9, < 3.13" -content-hash = "66409ad23983c1749fe99b550fc67c77e14ec77798dc3e922a35981a418c96ca" +content-hash = "57b2a352e5631d0ff96681a6d4c1f08403d3b182e9e346c84bb4708a67830f79" diff --git a/pyproject.toml b/pyproject.toml index c78e4d3..1417852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ catboost = "^1.2.2" fire = "^0.5.0" hydra-core = "^1.3.2" mlflow = "^2.8.1" +skl2onnx = "^1.15.0" +onnxruntime = "^1.16.3" [tool.poetry.group.dev.dependencies] pre-commit = "^3.4.0"