Skip to content

Commit

Permalink
reorg: test
Browse files Browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang committed Nov 11, 2024
1 parent 3a06f24 commit 7f80854
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 120 deletions.
49 changes: 0 additions & 49 deletions tests/classification/naive_bayes/test_categorical_naive_bayes.py

This file was deleted.

49 changes: 0 additions & 49 deletions tests/classification/naive_bayes/test_multinomial_naive_bayes.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import math

import numpy as np
import pytest

from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import CategoricalNB, GaussianNB, MultinomialNB

from toyml.classification.naive_bayes import GaussianNaiveBayes
from toyml.classification.naive_bayes import (
CategoricalNaiveBayes,
GaussianNaiveBayes,
MultinomialNaiveBayes,
)


@pytest.fixture
Expand Down Expand Up @@ -36,6 +41,17 @@ def wikipedia_person_classification_sample() -> list[float]:
return [6, 130, 8]


@pytest.fixture
def sklearn_example_random_dataset_label() -> tuple[list[list[int]], list[int]]:
"""
References: https://scikit-learn.org/1.5/modules/generated/sklearn.naive_bayes.MultinomialNB.html#multinomialnb
"""
rng = np.random.RandomState(1)
dataset = rng.randint(5, size=(6, 100)).tolist()
label = np.array([1, 2, 3, 4, 5, 6]).tolist()
return dataset, label


class TestGaussianNaiveBayesIntegration:
def test_same_result_with_wikipedia(
self,
Expand Down Expand Up @@ -79,3 +95,63 @@ def test_same_result_with_sklearn(
sklearn_prob = sklearn_clf.predict_proba([wikipedia_person_classification_sample])
assert math.isclose(sut_prob[0], sklearn_prob[0][0])
assert math.isclose(sut_prob[1], sklearn_prob[0][1])


class TestMultinomialNaiveBayesIntegration:
def test_same_result_with_sklearn(
self,
sklearn_example_random_dataset_label: tuple[list[list[float]], list[int]],
) -> None:
dataset, label = sklearn_example_random_dataset_label
sklearn_clf = MultinomialNB()
sklearn_clf.fit(dataset, label)
# use the same variance calculation config with sklearn
sut = MultinomialNaiveBayes(alpha=1).fit(dataset, label)
# test same labels
test_sample = dataset[2]
sklearn_label = sklearn_clf.predict([test_sample])
sut_label = sut.predict(test_sample)

assert sut_label == sklearn_label[0]

# test same log probs
sut_log_prob = sut.predict_log_proba(test_sample)
sklearn_log_prob = sklearn_clf.predict_log_proba([test_sample])
for i in range(6):
assert math.isclose(sut_log_prob[i + 1], sklearn_log_prob[0][i])

# # test same probs
sut_prob = sut.predict_proba(test_sample)
sklearn_prob = sklearn_clf.predict_proba([test_sample])
for i in range(6):
assert math.isclose(sut_prob[i + 1], sklearn_prob[0][i])


class TestCategoricalNaiveBayesIntegration:
def test_same_result_with_sklearn(
self,
sklearn_example_random_dataset_label: tuple[list[list[float]], list[int]],
) -> None:
dataset, label = sklearn_example_random_dataset_label
sklearn_clf = CategoricalNB(alpha=1)
sklearn_clf.fit(dataset, label)
# use the same variance calculation config with sklearn
sut = CategoricalNaiveBayes(alpha=1).fit(dataset, label)
# test same labels
test_sample = dataset[2]
sklearn_label = sklearn_clf.predict([test_sample])
sut_label = sut.predict(test_sample)

assert sut_label == sklearn_label[0]

# test same log probs
sut_log_prob = sut.predict_log_proba(test_sample)
sklearn_log_prob = sklearn_clf.predict_log_proba([test_sample])
for i in range(6):
assert math.isclose(sut_log_prob[i + 1], sklearn_log_prob[0][i])

# # test same probs
sut_prob = sut.predict_proba(test_sample)
sklearn_prob = sklearn_clf.predict_proba([test_sample])
for i in range(6):
assert math.isclose(sut_prob[i + 1], sklearn_prob[0][i])
40 changes: 20 additions & 20 deletions toyml/classification/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

Class = int
Dimension = int
FeatureValue = float | int
FeatureValue = float


@dataclass
Expand Down Expand Up @@ -110,7 +110,7 @@ class GaussianNaiveBayes(BaseNaiveBayes):
epsilon_: float = 0
"""The absolute additive value to variances."""

def fit(self, dataset: list[list[float]], labels: list[Class]) -> GaussianNaiveBayes:
def fit(self, dataset: list[list[FeatureValue]], labels: list[Class]) -> GaussianNaiveBayes:
"""Fit the Gaussian Naive Bayes classifier.
Args:
Expand All @@ -127,11 +127,11 @@ def fit(self, dataset: list[list[float]], labels: list[Class]) -> GaussianNaiveB
self.means_, self.variances_ = self._get_classes_means_variances(dataset, labels)
return self

def _log_likelihood(self, sample: list[float]) -> dict[Class, float]:
def _log_likelihood(self, sample: list[FeatureValue]) -> dict[Class, float]:
"""
Calculate the likelihood of each sample in each class
"""
label_likelihoods: dict[int, float] = {}
label_likelihoods: dict[Class, float] = {}
for label in self.labels_:
label_means = self.means_[label]
label_vars = self.variances_[label]
Expand All @@ -146,8 +146,8 @@ def _log_likelihood(self, sample: list[float]) -> dict[Class, float]:

def _get_classes_means_variances(
self,
dataset: list[list[float]],
labels: list[int],
dataset: list[list[FeatureValue]],
labels: list[Class],
) -> tuple[dict[Class, list[float]], dict[Class, list[float]]]:
means, variances = {}, {}
for label in self.labels_:
Expand All @@ -157,19 +157,19 @@ def _get_classes_means_variances(
return means, variances

@staticmethod
def _dataset_column_means(dataset: list[list[float]]) -> list[float]:
def _dataset_column_means(dataset: list[list[FeatureValue]]) -> list[float]:
"""
Calculate vectors mean
"""
return [statistics.mean(column) for column in zip(*dataset, strict=True)]

def _dataset_column_variances(self, dataset: list[list[float]]) -> list[float]:
def _dataset_column_variances(self, dataset: list[list[FeatureValue]]) -> list[float]:
"""
Calculate vectors(every column) standard variance
"""
return [self._variance(column) + self.epsilon_ for column in zip(*dataset, strict=True)]

def _variance(self, xs: list[float] | tuple[float, ...]) -> float:
def _variance(self, xs: list[FeatureValue] | tuple[FeatureValue, ...]) -> float:
mean = statistics.mean(xs)
ss = sum((x - mean) ** 2 for x in xs)
if self.unbiased_variance is True:
Expand Down Expand Up @@ -208,12 +208,12 @@ class MultinomialNaiveBayes(BaseNaiveBayes):
class_feature_log_prob_: dict[Class, list[float]] = field(default_factory=dict)
"""The feature value probability of each class in training dataset"""

def fit(self, dataset: list[list[int]], labels: list[Class]) -> MultinomialNaiveBayes: # type: ignore[override]
def fit(self, dataset: list[list[FeatureValue]], labels: list[Class]) -> MultinomialNaiveBayes:
"""Fit the Multinomial Naive Bayes classifier.
Args:
dataset: Training data, where each row is a sample and each column is a feature.
Features should be represented as counts (non-negative integers).
Features should be represented as counts (non-negative integers).
labels: Target labels for training data.
Returns:
Expand All @@ -226,7 +226,7 @@ def fit(self, dataset: list[list[int]], labels: list[Class]) -> MultinomialNaive
self.class_feature_count_, self.class_feature_log_prob_ = self._get_classes_feature_count_prob(dataset, labels)
return self

def _log_likelihood(self, sample: list[int]) -> dict[Class, float]: # type: ignore[override]
def _log_likelihood(self, sample: list[FeatureValue]) -> dict[Class, float]:
"""
Calculate the likelihood of each sample in each class
"""
Expand All @@ -241,8 +241,8 @@ def _log_likelihood(self, sample: list[int]) -> dict[Class, float]: # type: ign

def _get_classes_feature_count_prob(
self,
dataset: list[list[int]],
labels: list[int],
dataset: list[list[FeatureValue]],
labels: list[Class],
) -> tuple[dict[Class, list[int]], dict[Class, list[float]]]:
feature_count, feature_prob = {}, {}
for label in self.labels_:
Expand All @@ -253,7 +253,7 @@ def _get_classes_feature_count_prob(

return feature_count, feature_prob

def _dataset_feature_counts(self, dataset: list[list[int]]) -> list[int]:
def _dataset_feature_counts(self, dataset: list[list[FeatureValue]]) -> list[int]:
"""
Calculate feature value counts
"""
Expand Down Expand Up @@ -288,7 +288,7 @@ class CategoricalNaiveBayes(BaseNaiveBayes):
class_feature_log_prob_: dict[Class, dict[Dimension, dict[FeatureValue, float]]] = field(default_factory=dict)
"""The feature value probability of each class in training dataset"""

def fit(self, dataset: list[list[int]], labels: list[Class]) -> CategoricalNaiveBayes: # type: ignore[override]
def fit(self, dataset: list[list[FeatureValue]], labels: list[Class]) -> CategoricalNaiveBayes:
"""Fit the Categorical Naive Bayes classifier.
Args:
Expand All @@ -305,7 +305,7 @@ def fit(self, dataset: list[list[int]], labels: list[Class]) -> CategoricalNaive
self.class_feature_count_, self.class_feature_log_prob_ = self._get_classes_feature_count_prob(dataset, labels)
return self

def _log_likelihood(self, sample: list[int]) -> dict[Class, float]: # type: ignore[override]
def _log_likelihood(self, sample: list[FeatureValue]) -> dict[Class, float]:
"""
Calculate the likelihood of each sample in each class
"""
Expand All @@ -320,8 +320,8 @@ def _log_likelihood(self, sample: list[int]) -> dict[Class, float]: # type: ign

def _get_classes_feature_count_prob(
self,
dataset: list[list[int]],
labels: list[int],
dataset: list[list[FeatureValue]],
labels: list[Class],
) -> tuple: # type: ignore[type-arg]
feature_smooth_count: dict[Dimension, dict[FeatureValue, float]] = {}
for dim, column in enumerate(zip(*dataset)):
Expand All @@ -331,7 +331,7 @@ def _get_classes_feature_count_prob(
feature_prob: dict[Class, dict[Dimension, dict[FeatureValue, float]]] = {}
for label in self.labels_:
label_samples = [sample for (sample, sample_label) in zip(dataset, labels) if sample_label == label]
counts = self._dataset_feature_counts(label_samples, feature_smooth_count) # type: ignore[arg-type]
counts = self._dataset_feature_counts(label_samples, feature_smooth_count)
feature_count[label] = counts
feature_prob[label] = {}
for dim, feature_value_count in counts.items():
Expand Down

0 comments on commit 7f80854

Please sign in to comment.