Skip to content

Commit

Permalink
Test GTIL with IsolationForest (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 authored Feb 23, 2022
1 parent b65bedc commit 4cc4f7e
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tests/python/test_gtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import scipy
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, \
ExtraTreesClassifier, RandomForestRegressor, GradientBoostingRegressor, \
ExtraTreesRegressor
ExtraTreesRegressor, IsolationForest
from sklearn.datasets import load_iris, load_breast_cancer, load_boston, load_svmlight_file
from sklearn.model_selection import train_test_split

Expand Down Expand Up @@ -80,6 +80,18 @@ def test_skl_multiclass_classifier(clazz):
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)


def test_skl_converter_iforest():
"""Scikit-learn isolation forest"""
X, _ = load_boston(return_X_y=True)
clf = IsolationForest(max_samples=64, random_state=0, n_estimators=10)
clf.fit(X)
expected_pred = clf._compute_chunked_score_samples(X) # pylint: disable=W0212

tl_model = treelite.sklearn.import_model(clf)
out_pred = treelite.gtil.predict(tl_model, X)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=2)


@pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror', 'reg:squaredlogerror',
'reg:pseudohubererror'])
@pytest.mark.parametrize('model_format', ['binary', 'json'])
Expand Down

0 comments on commit 4cc4f7e

Please sign in to comment.