// Naive_Bayes_Classifier.java
// Forked from sPredictorX1708/Ultimate-Java-Resources.
/* Naive Bayes Classifier using Java */
/**
* Naive Bayes classifier. A naive Bayes classifier is a simple probabilistic
* classifier based on applying Bayes' theorem with strong (naive) independence
* assumptions. Depending on the precise nature of the probability model, naive
* Bayes classifiers can be trained very efficiently in a supervised learning
* setting. **/
package smile.classification;
import smile.math.MathEx;
import smile.stat.distribution.Distribution;
import smile.util.IntSet;
/**
 * General-purpose naive Bayes classifier for real-valued feature vectors.
 *
 * <p>The model consists of one priori probability per class and one
 * conditional distribution per (class, variable) pair. To classify an
 * instance, the log priori probability of each class is added to the
 * log-likelihood of every variable under that class, and the resulting
 * scores are normalized into posteriori probabilities.
 */
public class NaiveBayes implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 2L;

    /**
     * The number of classes.
     */
    private final int k;
    /**
     * The number of independent variables.
     */
    private final int p;
    /**
     * The priori probability of each class (private copy of the constructor argument).
     */
    private final double[] priori;
    /**
     * The conditional distribution for general purpose naive Bayes classifier.
     * prob[i][j] is the distribution of variable j in class i.
     */
    private final Distribution[][] prob;
    /**
     * The class label encoder, used to translate a class index into a class label.
     */
    private final IntSet labels;

    /**
     * Constructor of general naive Bayes classifier. The class labels are
     * those produced by {@code IntSet.of(priori.length)}.
     *
     * @param priori the priori probability of each class.
     * @param condprob the conditional distribution of each variable in
     * each class. In particular, condprob[i][j] is the conditional
     * distribution P(x<sub>j</sub> | class i).
     * @throws IllegalArgumentException if the arguments are inconsistent; see
     * {@link #NaiveBayes(double[], Distribution[][], IntSet)}.
     */
    public NaiveBayes(double[] priori, Distribution[][] condprob) {
        this(priori, condprob, IntSet.of(priori.length));
    }

    /**
     * Constructor of general naive Bayes classifier.
     *
     * @param priori the priori probability of each class. Every value must be
     * strictly between 0 and 1 and the values must sum to one (within 1E-5).
     * The array is copied.
     * @param condprob the conditional distribution of each variable in
     * each class. In particular, condprob[i][j] is the conditional
     * distribution P(x<sub>j</sub> | class i). Every class must provide the
     * same number of distributions and none of them may be null.
     * @param labels class labels
     * @throws IllegalArgumentException if the number of priori probabilities
     * differs from the number of classes in condprob, if a priori probability
     * is NaN or outside (0, 1), if the priori probabilities do not sum to one,
     * or if condprob is ragged or contains null entries.
     * @throws NullPointerException if any argument is null.
     */
    public NaiveBayes(double[] priori, Distribution[][] condprob, IntSet labels) {
        if (priori.length != condprob.length) {
            throw new IllegalArgumentException("The number of priori probabilities and that of the classes are not same.");
        }

        // Fail at construction time rather than on the first prediction.
        if (labels == null) {
            throw new NullPointerException("labels");
        }

        double sum = 0.0;
        for (double pr : priori) {
            // Negated range test on purpose: every comparison with NaN is
            // false, so "pr <= 0.0 || pr >= 1.0" would let NaN through.
            if (!(pr > 0.0 && pr < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }

        if (Math.abs(sum - 1.0) > 1E-5) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }

        // An empty priori array sums to 0 and is rejected above, so condprob
        // has at least one row from here on.
        this.k = priori.length;
        this.p = checkedVariableCount(condprob);
        // Keep a private copy so that later changes to the caller's array
        // cannot invalidate the checks performed above.
        this.priori = priori.clone();
        this.prob = condprob;
        this.labels = labels;
    }

    /**
     * Verifies that every class has a non-null row of non-null distributions
     * and that all rows have the same length.
     *
     * @param condprob the conditional distributions, one row per class; must
     * contain at least one row.
     * @return the number of variables, i.e. the common row length.
     * @throws IllegalArgumentException if a row or an entry is null, or if the
     * rows differ in length.
     */
    private static int checkedVariableCount(Distribution[][] condprob) {
        int width = -1;
        for (int i = 0; i < condprob.length; i++) {
            Distribution[] row = condprob[i];
            if (row == null) {
                throw new IllegalArgumentException("Missing conditional distributions for class " + i);
            }

            if (i == 0) {
                width = row.length;
            } else if (row.length != width) {
                throw new IllegalArgumentException(String.format(
                        "Class %d has %d conditional distributions, expected %d", i, row.length, width));
            }

            for (int j = 0; j < row.length; j++) {
                if (row[j] == null) {
                    throw new IllegalArgumentException(String.format(
                            "Missing conditional distribution of variable %d in class %d", j, i));
                }
            }
        }
        return width;
    }

    /**
     * Returns a priori probabilities.
     *
     * @return a copy of the priori probability of each class; modifying the
     * returned array does not affect this classifier.
     */
    public double[] priori() {
        return priori.clone();
    }

    /**
     * Predict the class of an instance.
     *
     * @param x the instance to be classified.
     * @return the predicted class label.
     * @throws IllegalArgumentException if the length of x is not the number
     * of variables of this model.
     */
    @Override
    public int predict(double[] x) {
        return predict(x, new double[k]);
    }

    /**
     * Predict the class of an instance and report the posteriori
     * probability of every class.
     *
     * @param x the instance to be classified.
     * @param posteriori the array to store a posteriori probabilities on output.
     * Its length must be the number of classes.
     * @return the predicted class label.
     * @throws IllegalArgumentException if the length of x is not the number
     * of variables, or the length of posteriori is not the number of classes.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d", x.length));
        }

        // The maximum and arg-max below scan the whole buffer, so a buffer
        // longer than k would mix stale entries into the result.
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        // Unnormalized log posteriori: log prior plus the sum of the
        // per-variable log-likelihoods.
        for (int i = 0; i < k; i++) {
            double logprob = Math.log(priori[i]);
            for (int j = 0; j < p; j++) {
                logprob += prob[i][j].logp(x[j]);
            }
            posteriori[i] = logprob;
        }

        // Shift by the largest score before exponentiating so that the best
        // class maps to exp(0) = 1 and the normalizer Z cannot underflow to 0.
        // NOTE(review): if every class scores -Infinity the shift yields NaN
        // posteriors; confirm whether callers need a fallback for that case.
        double Z = 0.0;
        double max = MathEx.max(posteriori);
        for (int i = 0; i < k; i++) {
            posteriori[i] = Math.exp(posteriori[i] - max);
            Z += posteriori[i];
        }

        for (int i = 0; i < k; i++) {
            posteriori[i] /= Z;
        }

        // Translate the index of the most probable class into its label.
        return labels.valueOf(MathEx.whichMax(posteriori));
    }
}