-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Isolation Forest Implementation (#130)
* note: The max_height min value should be 0 rather than 1; the bst_expect_length should be 1 when n = 2(same as sklearn implementation). This fix makes the demo example result almost same result as sklearn
- Loading branch information
1 parent
b0e6b0d
commit e7648ac
Showing
5 changed files
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Isolation Forest | ||
|
||
::: toyml.ensemble.iforest | ||
options: | ||
members: | ||
- IsolationForest | ||
- IsolationTree |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import math | ||
|
||
import pytest | ||
|
||
from toyml.ensemble.iforest import IsolationForest, IsolationTree | ||
|
||
|
||
class TestIsolationTree: | ||
@pytest.mark.parametrize("max_height", [-1, -2]) | ||
def test_invalid_max_height_raise_error(self, max_height: int) -> None: | ||
with pytest.raises(ValueError, match="The max height of"): | ||
_ = IsolationTree(max_height=max_height) | ||
|
||
@pytest.mark.parametrize("max_height", [0, 1, 2]) | ||
def test_single_sample_build_leaf_node(self, max_height: int) -> None: | ||
sut = IsolationTree(max_height=max_height) | ||
samples = [[0.0, 1.0]] | ||
|
||
sut.fit(samples) | ||
|
||
assert sut.sample_size_ == 1 | ||
assert sut.feature_num_ == 2 | ||
assert sut.is_external_node() is True | ||
|
||
@pytest.mark.parametrize("max_height", [0, 1, 2]) | ||
def test_two_same_samples_build_leaf_node(self, max_height: int) -> None: | ||
sut = IsolationTree(max_height=max_height) | ||
samples = [[0.0, 1.0], [0.0, 1.0]] | ||
|
||
sut.fit(samples) | ||
|
||
assert sut.sample_size_ == 2 | ||
assert sut.feature_num_ == 2 | ||
assert sut.is_external_node() is True | ||
|
||
@pytest.mark.parametrize("max_height", [1, 2, 3]) | ||
def test_two_different_samples_build_three_node_itree(self, max_height: int) -> None: | ||
sut = IsolationTree(max_height=max_height) | ||
samples = [[0.0, 1.0], [0.0, 2.0]] | ||
|
||
sut.fit(samples) | ||
|
||
assert sut.sample_size_ == 2 | ||
assert sut.feature_num_ == 2 | ||
assert sut.is_external_node() is False | ||
assert sut.left_ is not None | ||
assert sut.left_.is_external_node() is True | ||
assert sut.right_ is not None | ||
assert sut.right_.is_external_node() is True | ||
|
||
@pytest.mark.parametrize("max_height", [0, 1, 2]) | ||
def test_leaf_node_path_length(self, max_height: int) -> None: | ||
sut = IsolationTree(max_height=max_height) | ||
samples = [[0.0, 1.0]] | ||
|
||
sut.fit(samples) | ||
|
||
assert sut.get_sample_path_length(samples[0]) == 0 | ||
|
||
@pytest.mark.parametrize("max_height", [1, 2, 3]) | ||
def test_itree_node_path_length(self, max_height: int) -> None: | ||
sut = IsolationTree(max_height=max_height) | ||
samples = [[0.0, 1.0], [0.0, 2.0]] | ||
|
||
sut.fit(samples) | ||
|
||
assert sut.get_sample_path_length(samples[0]) == 1 | ||
assert sut.get_sample_path_length(samples[1]) == 1 | ||
|
||
|
||
class TestIsolationForest: | ||
@pytest.fixture | ||
def simple_dataset(self) -> list[list[float]]: | ||
return [[-1.1], [0.3], [0.5], [100.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"n_itree, max_samples", | ||
[ | ||
(5, 3), | ||
(8, 4), | ||
(10, 6), | ||
], | ||
) | ||
def test_itree_build( | ||
self, | ||
simple_dataset: list[list[float]], | ||
n_itree: int, | ||
max_samples: int, | ||
) -> None: | ||
sut = IsolationForest(n_itree=n_itree, max_samples=max_samples) | ||
|
||
sut.fit(simple_dataset) | ||
|
||
assert len(sut.itrees_) == n_itree | ||
|
||
@pytest.mark.parametrize( | ||
"n_itree, max_samples", | ||
[ | ||
(5, 3), | ||
(8, 4), | ||
(10, 6), | ||
], | ||
) | ||
def test_anomaly_predict( | ||
self, | ||
simple_dataset: list[list[float]], | ||
n_itree: int, | ||
max_samples: int, | ||
) -> None: | ||
sut = IsolationForest(n_itree=n_itree) | ||
|
||
labels = sut.fit_predict(simple_dataset) | ||
|
||
assert all(label == 1 or label == -1 for label in labels) is True | ||
assert labels[-1] == -1 | ||
|
||
@pytest.mark.parametrize( | ||
"n_itree, max_samples", | ||
[ | ||
(5, 4), | ||
(10, 3), | ||
], | ||
) | ||
def test_anomaly_score_property( | ||
self, | ||
simple_dataset: list[list[float]], | ||
n_itree: int, | ||
max_samples: int, | ||
) -> None: | ||
sut = IsolationForest(n_itree=n_itree, max_samples=max_samples) | ||
|
||
sut.fit(simple_dataset) | ||
scores = [sut.score(sample) for sample in simple_dataset] | ||
|
||
assert math.isclose(max(scores), scores[-1]) | ||
assert all([0 <= score <= 1 for score in scores]) is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from .adaboost import AdaBoost | ||
from .iforest import IsolationForest | ||
|
||
__all__ = [ | ||
"AdaBoost", | ||
"IsolationForest", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
from __future__ import annotations | ||
|
||
import math | ||
import random | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Optional | ||
|
||
|
||
def bst_expect_length(n: int) -> float: | ||
if n <= 1: | ||
return 0 | ||
if n == 2: | ||
return 1 | ||
return 2 * (math.log(n - 1) + 0.5772156649) - (2 * (n - 1) / n) | ||
|
||
|
||
@dataclass | ||
class IsolationTree: | ||
""" | ||
The isolation tree. | ||
Note: | ||
The isolation tree is a proper(full) binary tree, which has either 0 or 2 children. | ||
""" | ||
|
||
max_height: int | ||
"""The maximum height of the tree.""" | ||
random_seed: int | None = None | ||
"""The random seed used to initialize the centroids.""" | ||
sample_size_: int | None = None | ||
"""The sample size.""" | ||
feature_num_: int | None = None | ||
"""The number of features at each sample.""" | ||
left_: IsolationTree | None = None | ||
"""The left child of the tree.""" | ||
right_: IsolationTree | None = None | ||
"""The right child of the tree.""" | ||
split_at_: int | None = None | ||
"""The index of feature which is used to split the tree's samples into children.""" | ||
split_value_: float | None = None | ||
"""The value of split_at feature that use to split samples""" | ||
|
||
def __post_init__(self) -> None: | ||
self.random_state = random.Random(self.random_seed) | ||
if self.max_height < 0: | ||
raise ValueError(f"The max height of {self.__class__.__name__} must >= 0, not get {self.max_height}") | ||
|
||
def fit(self, samples: list[list[float]]) -> IsolationTree: | ||
""" | ||
Fit the isolation tree. | ||
""" | ||
self.sample_size_ = len(samples) | ||
self.feature_num_ = len(samples[0]) | ||
# exNode | ||
if self.max_height == 0 or self.sample_size_ == 1: | ||
return self | ||
# inNode | ||
left_itree, right_itree = self._get_left_right_child_itree(samples) | ||
self.left_, self.right_ = left_itree, right_itree | ||
return self | ||
|
||
def get_sample_path_length(self, sample: list[float]) -> float: | ||
""" | ||
Get the sample's path length to the external(leaf) node. | ||
Args: | ||
sample: The data sample. | ||
Returns: | ||
The path length of the sample. | ||
""" | ||
if self.is_external_node(): | ||
assert self.sample_size_ is not None | ||
if self.sample_size_ == 1: | ||
return 0 | ||
else: | ||
return bst_expect_length(self.sample_size_) | ||
|
||
assert self.split_at_ is not None and self.split_value_ is not None | ||
if sample[self.split_at_] < self.split_value_: | ||
assert self.left_ is not None | ||
return 1 + self.left_.get_sample_path_length(sample) | ||
else: | ||
assert self.right_ is not None | ||
return 1 + self.right_.get_sample_path_length(sample) | ||
|
||
def is_external_node(self) -> bool: | ||
""" | ||
The tree node is external(leaf) node or not. | ||
""" | ||
if self.left_ is None and self.right_ is None: | ||
return True | ||
return False | ||
|
||
def _get_left_right_child_itree( | ||
self, samples: list[list[float]] | ||
) -> tuple[Optional[IsolationTree], Optional[IsolationTree]]: | ||
assert self.feature_num_ is not None | ||
split_at_list = list(range(self.feature_num_)) | ||
self.random_state.shuffle(split_at_list) | ||
for split_at in split_at_list: | ||
split_at_feature_values = [sample[split_at] for sample in samples] | ||
split_at_min, split_at_max = min(split_at_feature_values), max(split_at_feature_values) | ||
if math.isclose(split_at_min, split_at_max): | ||
continue | ||
split_value = self.random_state.uniform(split_at_min, split_at_max) | ||
left_samples, right_samples = self._get_sub_samples_by_split(samples, split_at, split_value) | ||
# need to keep proper binary tree property: all internal nodes have exactly two children | ||
if len(left_samples) > 0 and len(right_samples) > 0: | ||
self.split_at_, self.split_value_ = split_at, split_value | ||
left_itree = IsolationTree(max_height=self.max_height - 1).fit(left_samples) | ||
right_itree = IsolationTree(max_height=self.max_height - 1).fit(right_samples) | ||
return left_itree, right_itree | ||
# cannot split the samples by any features | ||
return None, None | ||
|
||
@staticmethod | ||
def _get_sub_samples_by_split( | ||
samples: list[list[float]], | ||
split_at: int, | ||
split_value: float, | ||
) -> tuple[list[list[float]], list[list[float]]]: | ||
left_samples, right_samples = [], [] | ||
for sample in samples: | ||
if sample[split_at] < split_value: | ||
left_samples.append(sample) | ||
else: | ||
right_samples.append(sample) | ||
return left_samples, right_samples | ||
|
||
|
||
@dataclass | ||
class IsolationForest: | ||
""" | ||
Isolation Forest. | ||
Examples: | ||
>>> from toyml.ensemble.iforest import IsolationForest | ||
>>> dataset = [[-1.1], [0.3], [0.5], [100.0]] | ||
>>> IsolationForest(n_itree=100, max_samples=4).fit_predict(dataset) | ||
[1, 1, 1, -1] | ||
References: | ||
Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou. "Isolation forest." 2008 eighth ieee international conference on data mining. IEEE, 2008. | ||
""" | ||
|
||
n_itree: int = 100 | ||
"""The number of isolation tree in the ensemble.""" | ||
max_samples: None | int = None | ||
"""The number of samples to draw from X to train each base estimator.""" | ||
score_threshold: float = 0.5 | ||
"""The score threshold that is used to define outlier: | ||
If sample's anomaly score > score_threshold, then the sample is detected as outlier(predict return -1); | ||
otherwise, the sample is normal(predict return 1)""" | ||
random_seed: int | None = None | ||
"""The random seed used to initialize the centroids.""" | ||
itrees_: list[IsolationTree] = field(default_factory=list) | ||
"""The isolation trees in the forest.""" | ||
|
||
def __post_init__(self) -> None: | ||
self.random_state = random.Random(self.random_seed) | ||
|
||
def fit(self, dataset: list[list[float]]) -> IsolationForest: | ||
""" | ||
Fit the isolation forest model. | ||
""" | ||
if self.max_samples is None or self.max_samples > len(dataset): | ||
self.max_samples = len(dataset) | ||
|
||
self.itrees_ = self._fit_itrees(dataset) | ||
return self | ||
|
||
def score(self, sample: list[float]) -> float: | ||
""" | ||
Predict the sample's anomaly score. | ||
Args: | ||
sample: The data sample. | ||
Returns: | ||
The anomaly score. | ||
""" | ||
assert len(self.itrees_) == self.n_itree, "Please fit the model before score sample!" | ||
assert self.max_samples is not None, "Please fit the model before score sample!" | ||
itree_path_lengths = [itree.get_sample_path_length(sample) for itree in self.itrees_] | ||
expect_path_length = sum(itree_path_lengths) / len(itree_path_lengths) | ||
score = 2 ** (-expect_path_length / bst_expect_length(self.max_samples)) | ||
return score | ||
|
||
def predict(self, sample: list[float]) -> int: | ||
""" | ||
Predict the sample is outlier ot not. | ||
Args: | ||
sample: The data sample. | ||
Returns: | ||
Outlier: -1; Normal: 1. | ||
""" | ||
score = self.score(sample) | ||
# outlier | ||
if score > self.score_threshold: | ||
return -1 | ||
else: | ||
return 1 | ||
|
||
def fit_predict(self, dataset: list[list[float]]) -> list[int]: | ||
self.fit(dataset) | ||
return [self.predict(sample) for sample in dataset] | ||
|
||
def _fit_itrees(self, dataset: list[list[float]]) -> list[IsolationTree]: | ||
itrees = [self._fit_itree(dataset) for _ in range(self.n_itree)] | ||
return itrees | ||
|
||
def _fit_itree(self, dataset: list[list[float]]) -> IsolationTree: | ||
assert self.max_samples is not None, "Please fit the model before score sample!" | ||
samples = self.random_state.sample(dataset, self.max_samples) | ||
itree_max_height = math.ceil(math.log2(len(samples))) | ||
return IsolationTree(max_height=itree_max_height, random_seed=self.random_seed).fit(samples) | ||
|
||
|
||
if __name__ == "__main__": | ||
simple_dataset = [[-1.1], [0.3], [0.5], [100.0]] | ||
result = IsolationForest(random_seed=42).fit_predict(simple_dataset) | ||
print(result) |