Skip to content

Commit

Permalink
fix: random state
Browse files Browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang committed Oct 26, 2024
1 parent e153df2 commit 0af6604
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions toyml/ensemble/iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import math
import random
import sys

from dataclasses import dataclass, field
from typing import Optional
Expand All @@ -18,6 +17,8 @@ def bst_expect_length(n: int) -> float:
class IsolationTree:
max_height: int
"""The maximum height of the tree."""
random_seed: int | None = None
"""The random seed used to initialize the centroids."""
sample_size_: int | None = None
"""The sample size."""
feature_num_: int | None = None
Expand All @@ -32,14 +33,14 @@ class IsolationTree:
"""The value of split_at feature that use to split samples"""

def __post_init__(self) -> None:
self.random_state = random.Random(self.random_seed)
if self.max_height < 1:
raise ValueError(f"The max height of {self.__class__.__name__} must >= 1, " f"not get {self.max_height}")

def fit(self, samples: list[list[float]]) -> IsolationTree:
self.sample_size_ = len(samples)
self.feature_num_ = len(samples[0])
# exNode
# TODO: if all samples are same
if self.max_height == 1 or self.sample_size_ == 1:
return self
# inNode
Expand Down Expand Up @@ -79,17 +80,18 @@ def _get_left_right_child_itree(
random.shuffle(split_at_list)
for split_at in split_at_list:
split_at_feature_values = [sample[split_at] for sample in samples]
split_value = random.uniform(
min(split_at_feature_values) + sys.float_info.epsilon, max(split_at_feature_values)
)
split_at_min, split_at_max = min(split_at_feature_values), max(split_at_feature_values)
if math.isclose(split_at_min, split_at_max):
continue
split_value = self.random_state.uniform(split_at_min, split_at_max)
left_samples, right_samples = self._get_sub_samples_by_split(samples, split_at, split_value)
# need to keep proper binary tree property: all internal nodes have exactly two children
if len(left_samples) > 0 and len(right_samples) > 0:
self.split_at_, self.split_value_ = split_at, split_value
left_itree = IsolationTree(max_height=self.max_height - 1).fit(left_samples)
right_itree = IsolationTree(max_height=self.max_height - 1).fit(right_samples)
return left_itree, right_itree
# can not split the samples by any features
# cannot split the samples by any features
return None, None

@staticmethod
Expand Down Expand Up @@ -123,7 +125,7 @@ class IsolationForest:
"""The isolation trees in the forest."""

def __post_init__(self) -> None:
random.seed(self.random_seed)
self.random_state = random.Random(self.random_seed)

def fit(self, dataset: list[list[float]]) -> IsolationForest:
if self.max_samples > len(dataset):
Expand Down Expand Up @@ -152,9 +154,9 @@ def _fit_itrees(self, dataset: list[list[float]]) -> list[IsolationTree]:
return itrees

def _fit_itree(self, dataset: list[list[float]]) -> IsolationTree:
samples = random.sample(dataset, self.max_samples)
samples = self.random_state.sample(dataset, self.max_samples)
itree_max_height = math.ceil(math.log2(len(samples)))
return IsolationTree(itree_max_height).fit(samples)
return IsolationTree(max_height=itree_max_height, random_seed=self.random_seed).fit(samples)


if __name__ == "__main__":
Expand Down

0 comments on commit 0af6604

Please sign in to comment.