Skip to content

Commit

Permalink
Multinomial naive Bayes: produce exactly the same results as sklearn
Browse files: browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang committed Nov 6, 2024
1 parent db3dd03 commit ed8b482
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions toyml/classification/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,26 @@ def fit(self, dataset: list[list[int]], labels: list[int]) -> MultinomialNaiveBa
return self

Check warning on line 137 in toyml/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

toyml/classification/naive_bayes.py#L135-L137

Added lines #L135 - L137 were not covered by tests

def predict(self, sample: list[int]) -> int:
    """Return the label with the largest log-posterior for ``sample``.

    Args:
        sample: Feature vector forwarded unchanged to ``predict_log_prob``.

    Returns:
        The key of the ``predict_log_prob`` mapping whose value is largest;
        on ties, the first such key in the mapping's iteration order.
    """
    log_posteriors = self.predict_log_prob(sample)
    # dict.get as the key function ranks labels by their log-posterior.
    return max(log_posteriors, key=log_posteriors.get)

Check warning on line 142 in toyml/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

toyml/classification/naive_bayes.py#L140-L142

Added lines #L140 - L142 were not covered by tests

def predict_prob(self, sample: list[int]) -> dict[int, float]:
    """Return per-label posterior probabilities for ``sample``.

    Args:
        sample: Feature vector forwarded unchanged to ``predict_log_prob``.

    Returns:
        Mapping from each label reported by ``predict_log_prob`` to the
        exponential of its log-probability.
    """
    probabilities: dict[int, float] = {}
    for cls, log_p in self.predict_log_prob(sample).items():
        # Undo the log to move from log space back to probability space.
        probabilities[cls] = math.exp(log_p)
    return probabilities

Check warning on line 146 in toyml/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

toyml/classification/naive_bayes.py#L145-L146

Added lines #L145 - L146 were not covered by tests

def predict_log_prob(self, sample: list[int]) -> dict[int, float]:
    """Return the normalized log-posterior of every label for ``sample``.

    For each label, the value from ``self._likelihood(sample)`` (treated as
    a log-likelihood, since it is added to a log-prior) is combined with
    ``log(self.class_prior_[label])``. The results are then normalized by
    subtracting their log-sum-exp, so the exponentials of the returned
    values sum to 1.

    Args:
        sample: Feature vector forwarded unchanged to ``self._likelihood``.

    Returns:
        Mapping from label to normalized log-posterior.

    Raises:
        ValueError: If a class prior is not positive (from ``math.log``) or
            if ``self._likelihood`` yields no labels (from ``max``).
    """
    raw_label_posteriors: dict[int, float] = {
        label: likelihood + math.log(self.class_prior_[label])
        for label, likelihood in self._likelihood(sample).items()
    }

    # ref: https://github.com/scikit-learn/scikit-learn/blob/2beed55847ee70d363bdbfe14ee4401438fba057/sklearn/naive_bayes.py#L97
    # Shift by the maximum before exponentiating: the largest term becomes
    # exp(0) == 1, so the sum can never underflow to 0 and math.log stays
    # defined even when every raw log-posterior is very negative.
    max_log_posterior = max(raw_label_posteriors.values())
    logsumexp_prob = max_log_posterior + math.log(
        sum(math.exp(log_prob - max_log_posterior) for log_prob in raw_label_posteriors.values())
    )
    return {label: raw_posterior - logsumexp_prob for label, raw_posterior in raw_label_posteriors.items()}

Check warning on line 159 in toyml/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

toyml/classification/naive_bayes.py#L159

Added line #L159 was not covered by tests

def _likelihood(self, sample: list[int]) -> dict[int, float]:
"""
Expand Down Expand Up @@ -198,9 +203,12 @@ def _dataset_feature_counts(self, dataset: list[list[int]]) -> list[int]:
clf = MultinomialNB()
clf.fit(X, y)
print(clf.predict(X[2:3]))
print(clf.predict_proba(X[2:3]))
print(clf.predict_log_proba(X[2:3]))

clf1 = MultinomialNaiveBayes(alpha=1)
clf1.fit([[int(v) for v in s] for s in X], [int(v) for v in y])
sample = [int(x) for x in X[2:3][0]]
print(clf1.predict(sample))
print(clf1.predict_prob(sample))
print(clf1.predict_log_prob(sample))

0 comments on commit ed8b482

Please sign in to comment.