forked from ptnplanet/Java-Naive-Bayes-Classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BayesClassifier.java
119 lines (107 loc) · 4.12 KB
/
BayesClassifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package de.daslaboratorium.machinelearning.classifier;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
/**
 * A concrete implementation of the abstract {@code Classifier} class. The
 * Bayes classifier implements a naive Bayes approach to classifying a given
 * set of features:
 * {@code classify(feat1,...,featN) = argmax(P(cat) * PROD(P(featI|cat)))}
 *
 * @author Philipp Nolte
 *
 * @see <a href="http://en.wikipedia.org/wiki/Naive_Bayes_classifier">Naive
 *      Bayes classifier (Wikipedia)</a>
 *
 * @param <T> The feature class.
 * @param <K> The category class.
 */
public class BayesClassifier<T, K> extends Classifier<T, K> {

    /**
     * Orders classifications by ascending probability. Classifications with
     * the same probability are ordered by the rank of their category, where
     * a category receives its rank the first time this comparator sees it:
     * the category seen first sorts highest among the tied entries.
     *
     * <p>The rank-based tie-break keeps the ordering a lawful total order
     * ({@code sgn(compare(a, b)) == -sgn(compare(b, a))}), which a
     * {@link TreeSet} requires. Returning a constant for ties would break
     * that rule. Two classifications compare as equal only when both their
     * probability and their category are equal.
     *
     * @param <T> The feature class.
     * @param <K> The category class.
     */
    private static final class ProbabilityOrder<T, K>
            implements Comparator<Classification<T, K>> {

        /** Rank of every category seen so far, in first-seen order. */
        private final Map<Object, Integer> ranks =
                new HashMap<Object, Integer>();

        /**
         * Returns the rank of the given category, assigning the next free
         * rank if the category has not been seen before. Ranks never change
         * once assigned, so the resulting order is stable.
         *
         * @param category The category to look up; may be {@code null}.
         * @return The zero-based first-seen rank of the category.
         */
        private int rank(Object category) {
            Integer known = this.ranks.get(category);
            if (known == null) {
                known = Integer.valueOf(this.ranks.size());
                this.ranks.put(category, known);
            }
            return known.intValue();
        }

        @Override
        public int compare(Classification<T, K> o1,
                Classification<T, K> o2) {
            // Rank the already-stored element (o2) before the incoming one
            // so that lazily assigned ranks follow insertion order.
            int rank2 = this.rank(o2.getCategory());
            int rank1 = this.rank(o1.getCategory());

            int byProbability = Float.compare(
                    o1.getProbability(), o2.getProbability());
            if (byProbability != 0) {
                return byProbability;
            }
            // Tie: the later-seen category (higher rank) sorts lower.
            if (rank1 == rank2) {
                return 0;
            }
            return (rank1 > rank2) ? -1 : 1;
        }
    }

    /**
     * Calculates the product of all feature probabilities:
     * {@code PROD(P(featI|cat))}. Each factor is obtained from
     * {@code featureWeighedAverage(feature, category)}.
     *
     * @param features The set of features to use.
     * @param category The category to test for.
     * @return The product of all feature probabilities; {@code 1.0f} if
     *         {@code features} is empty.
     */
    private float featuresProbabilityProduct(Collection<T> features,
            K category) {
        float product = 1.0f;
        for (T feature : features) {
            product *= this.featureWeighedAverage(feature, category);
        }
        return product;
    }

    /**
     * Calculates the probability that the features can be classified as the
     * category given: the share of the category among all categories
     * ({@code categoryCount / getCategoriesTotal}) multiplied by the product
     * of the feature probabilities.
     *
     * @param features The set of features to use.
     * @param category The category to test for.
     * @return The probability that the features can be classified as the
     *         category.
     */
    private float categoryProbability(Collection<T> features, K category) {
        float categoryShare = (float) this.categoryCount(category)
                / (float) this.getCategoriesTotal();
        return categoryShare
                * this.featuresProbabilityProduct(features, category);
    }

    /**
     * Retrieves a sorted <code>Set</code> of probabilities that the given set
     * of features is classified as the available categories.
     *
     * @param features The set of features to use.
     * @return A <code>Set</code> of category-probability-entries sorted by
     *         ascending probability, so the most probable entry is last.
     */
    private SortedSet<Classification<T, K>> categoryProbabilities(
            Collection<T> features) {
        /*
         * The entries have to be sorted by the mapped value (probability)
         * and not by the key (category), so a TreeMap keyed by category can
         * not be used. A set of entries with a custom comparator is used
         * instead.
         */
        ProbabilityOrder<T, K> order = new ProbabilityOrder<T, K>();
        SortedSet<Classification<T, K>> probabilities =
                new TreeSet<Classification<T, K>>(order);

        for (K category : this.getCategories()) {
            // Register the category first so that its tie-break rank
            // reflects the iteration order of the categories.
            order.rank(category);
            probabilities.add(new Classification<T, K>(
                    features, category,
                    this.categoryProbability(features, category)));
        }
        return probabilities;
    }

    /**
     * Classifies the given set of features.
     *
     * @param features The set of features to classify.
     * @return The classification with the highest probability, or
     *         {@code null} if there are no categories to choose from.
     */
    @Override
    public Classification<T, K> classify(Collection<T> features) {
        SortedSet<Classification<T, K>> probabilities =
                this.categoryProbabilities(features);
        if (probabilities.isEmpty()) {
            return null;
        }
        // The set is sorted ascending, so the best match is the last entry.
        return probabilities.last();
    }

    /**
     * Classifies the given set of features and returns the full details of
     * the classification.
     *
     * @param features The set of features to classify.
     * @return One classification per available category, sorted by
     *         ascending probability.
     */
    public Collection<Classification<T, K>> classifyDetailed(
            Collection<T> features) {
        return this.categoryProbabilities(features);
    }
}