Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: onehot op #472

Merged
merged 14 commits into from
Sep 12, 2023
9 changes: 3 additions & 6 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,7 @@ jobs:

prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs:
[
build,
library-tests,
python-tests,
]
needs: [build, library-tests, python-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
Expand Down Expand Up @@ -460,6 +455,8 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: SVM
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_18_expects
- name: LightGBM
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_17_expects
- name: XGBoost
Expand Down
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", "
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "561614519e6cb49eea4d88dcee3b880f127813cb", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "2ea76c09678f092d00713ebbe6fdb046c0a9ad0f", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }


Expand Down
10 changes: 5 additions & 5 deletions examples/notebooks/decision_tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"![image-2.png](attachment:image-2.png)\n",
"\n",
"\n",
"This notebook showcases how to do that using the `sk2torch` python package ! "
"This notebook showcases how to do that using the `hummingbird-ml` python package ! "
]
},
{
Expand All @@ -46,7 +46,7 @@
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sk2torch\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
Expand All @@ -61,7 +61,7 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier as De\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
Expand All @@ -75,7 +75,7 @@
"clr = De()\n",
"clr.fit(X_train, y_train)\n",
"\n",
"circuit = sk2torch.wrap(clr)\n",
"circuit = convert(clr, \"torch\", X_test[:1]).model\n",
"\n",
"\n",
"\n"
Expand Down Expand Up @@ -282,7 +282,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,
Expand Down
Loading