Skip to content

Commit

Permalink
init impl
Browse files Browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang committed Oct 25, 2024
1 parent e973902 commit 73e86c8
Showing 1 changed file with 162 additions and 0 deletions.
162 changes: 162 additions & 0 deletions toyml/ensemble/iforest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

import math
import random
import sys

from dataclasses import dataclass, field
from typing import Optional


def bst_expect_length(n: int) -> float:
return 2 * (math.log(n - 1) + 0.5772156649) - (2 * (n - 1) / n)

Check warning on line 12 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L12

Added line #L12 was not covered by tests


@dataclass
class IsolationTree:
max_height: int
"""The maximum height of the tree."""
sample_size_: int | None = None
"""The sample size."""
feature_num_: int | None = None
"""The number of features at each sample."""
left_: IsolationTree | None = None
"""The left child of the tree."""
right_: IsolationTree | None = None
"""The right child of the tree."""
split_at_: int | None = None
"""The index of feature which is used to split the tree's samples into children."""
split_value_: float | None = None
"""The value of split_at feature that use to split samples"""

def __post_init__(self) -> None:
if self.max_height < 1:
raise ValueError(f"The max height of {self.__class__.__name__} must >= 1, " f"not get {self.max_height}")

Check warning on line 34 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L33-L34

Added lines #L33 - L34 were not covered by tests

def fit(self, samples: list[list[float]]) -> IsolationTree:
self.sample_size_ = len(samples)
self.feature_num_ = len(samples[0])

Check warning on line 38 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L37-L38

Added lines #L37 - L38 were not covered by tests
# exNode
# TODO: if all samples are same
if self.max_height == 1 or self.sample_size_ == 1:
return self

Check warning on line 42 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L41-L42

Added lines #L41 - L42 were not covered by tests
# inNode
left_itree, right_itree = self._get_left_right_child_itree(samples)
self.left_, self.right_ = left_itree, right_itree
return self

Check warning on line 46 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L44-L46

Added lines #L44 - L46 were not covered by tests

def get_sample_path_length(self, sample: list[float]) -> float:
if self.is_external_node():
assert self.sample_size_ is not None

Check warning on line 50 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L49-L50

Added lines #L49 - L50 were not covered by tests
# sklearn: https://github.com/scikit-learn/scikit-learn/blob/6e9039160f0dfc3153643143af4cfdca941d2045/sklearn/ensemble/_iforest.py#L517-L518
# For a single training sample, denominator and depth are 0.
# Therefore, we set the score manually to 1.
if self.sample_size_ == 1:
return 1

Check warning on line 55 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L54-L55

Added lines #L54 - L55 were not covered by tests
else:
return bst_expect_length(self.sample_size_)

Check warning on line 57 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L57

Added line #L57 was not covered by tests

assert self.split_at_ is not None and self.split_value_ is not None
if sample[self.split_at_] < self.split_value_:
assert self.left_ is not None
return 1 + self.left_.get_sample_path_length(sample)

Check warning on line 62 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L59-L62

Added lines #L59 - L62 were not covered by tests
else:
assert self.right_ is not None
return 1 + self.right_.get_sample_path_length(sample)

Check warning on line 65 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L64-L65

Added lines #L64 - L65 were not covered by tests

def is_external_node(self) -> bool:
if self.left_ is None and self.right_ is None:
return True
return False

Check warning on line 70 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L68-L70

Added lines #L68 - L70 were not covered by tests

def _get_left_right_child_itree(
self, samples: list[list[float]]
) -> tuple[Optional[IsolationTree], Optional[IsolationTree]]:
assert self.feature_num_ is not None
split_at_list = list(range(self.feature_num_))
random.shuffle(split_at_list)
for split_at in split_at_list:
split_at_feature_values = [sample[split_at] for sample in samples]
split_value = random.uniform(

Check warning on line 80 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L75-L80

Added lines #L75 - L80 were not covered by tests
min(split_at_feature_values) + sys.float_info.epsilon, max(split_at_feature_values)
)
left_samples, right_samples = self._get_sub_samples_by_split(samples, split_at, split_value)

Check warning on line 83 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L83

Added line #L83 was not covered by tests
# need to keep proper binary tree property: all internal nodes have exactly two children
if len(left_samples) > 0 and len(right_samples) > 0:
self.split_at_, self.split_value_ = split_at, split_value
left_itree = IsolationTree(max_height=self.max_height - 1).fit(left_samples)
right_itree = IsolationTree(max_height=self.max_height - 1).fit(right_samples)
return left_itree, right_itree

Check warning on line 89 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L85-L89

Added lines #L85 - L89 were not covered by tests
# can not split the samples by any features
return None, None

Check warning on line 91 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L91

Added line #L91 was not covered by tests

@staticmethod
def _get_sub_samples_by_split(
samples: list[list[float]],
split_at: int,
split_value: float,
) -> tuple[list[list[float]], list[list[float]]]:
left_samples, right_samples = [], []
for sample in samples:
if sample[split_at] < split_value:
left_samples.append(sample)

Check warning on line 102 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L99-L102

Added lines #L99 - L102 were not covered by tests
else:
right_samples.append(sample)
return left_samples, right_samples

Check warning on line 105 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L104-L105

Added lines #L104 - L105 were not covered by tests


@dataclass
class IsolationForest:
"""
Isolation Forest.
"""

n_itree: int = 100
"""The number of isolation tree in the ensemble."""
max_samples: int = 256
"""The number of samples to draw from X to train each base estimator."""
random_seed: int | None = None
"""The random seed used to initialize the centroids."""
itrees_: list[IsolationTree] = field(default_factory=list)
"""The isolation trees in the forest."""

def __post_init__(self) -> None:
random.seed(self.random_seed)

Check warning on line 124 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L124

Added line #L124 was not covered by tests

def fit(self, dataset: list[list[float]]) -> IsolationForest:
if self.max_samples > len(dataset):
self.max_samples = len(dataset)

Check warning on line 128 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L127-L128

Added lines #L127 - L128 were not covered by tests

self.itrees_ = self._fit_itrees(dataset)
return self

Check warning on line 131 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L130-L131

Added lines #L130 - L131 were not covered by tests

def predict(self, sample: list[float]) -> float:
"""
Predict the sample's anomaly score.
Args:
sample: The data sample.
Returns:
The anomaly score.
"""
itree_path_lengths = [itree.get_sample_path_length(sample) for itree in self.itrees_]
expect_path_length = sum(itree_path_lengths) / len(itree_path_lengths)
score = 2 ** (-expect_path_length / bst_expect_length(self.max_samples))
return score

Check warning on line 146 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L143-L146

Added lines #L143 - L146 were not covered by tests

def _fit_itrees(self, dataset: list[list[float]]) -> list[IsolationTree]:
itrees = [self._fit_itree(dataset) for _ in range(self.n_itree)]
return itrees

Check warning on line 150 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L149-L150

Added lines #L149 - L150 were not covered by tests

def _fit_itree(self, dataset: list[list[float]]) -> IsolationTree:
samples = random.sample(dataset, self.max_samples)
itree_max_height = math.ceil(len(samples))
return IsolationTree(itree_max_height).fit(samples)

Check warning on line 155 in toyml/ensemble/iforest.py

View check run for this annotation

Codecov / codecov/patch

toyml/ensemble/iforest.py#L153-L155

Added lines #L153 - L155 were not covered by tests


if __name__ == "__main__":
simple_dataset = [[1.0, 1.0], [0.0, 0.0], [1.1, 1.0], [10.0, 10.0]]
clf = IsolationForest(n_itree=10, max_samples=4).fit(simple_dataset)
print(clf.predict([1, 1]))
print(clf.predict([10, 10]))

0 comments on commit 73e86c8

Please sign in to comment.