forked from bwaldvogel/liblinear-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.java
More file actions
178 lines (150 loc) · 5.08 KB
/
Copy pathModel.java
File metadata and controls
178 lines (150 loc) · 5.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package liblinear;
import static liblinear.Linear.copyOf;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.Serializable;
import java.io.Writer;
import java.util.Arrays;
/**
 * <p>Model stores the model obtained from the training procedure</p>
 *
 * <p>use {@link #load(File)} / {@link #save(File)} (which delegate to {@link Linear#loadModel(File)} and
 * {@link Linear#saveModel(File, Model)}) to load/save it</p>
 */
public final class Model implements Serializable {

    private static final long serialVersionUID = -6456047576741854834L;

    /** bias value; if {@code bias >= 0} the feature vector is extended by it (see {@link #getFeatureWeights()}) */
    double bias;

    /** label of each class */
    int[] label;

    /** number of classes */
    int nr_class;

    /** number of features */
    int nr_feature;

    /** solver type of this model (see {@link #isProbabilityModel()}) */
    SolverType solverType;

    /** feature weight array */
    double[] w;

    /**
     * @return number of classes
     */
    public int getNrClass() {
        return nr_class;
    }

    /**
     * @return number of features
     */
    public int getNrFeature() {
        return nr_feature;
    }

    /**
     * @return a <b>copy of</b> the class labels, of length {@code nr_class}
     */
    public int[] getLabels() {
        return copyOf(label, nr_class);
    }

    /**
     * The nr_feature*nr_class array w gives feature weights. We use one
     * against the rest for multi-class classification, so each feature
     * index corresponds to nr_class weight values. Weights are
     * organized in the following way
     *
     * <pre>
     * +------------------+------------------+------------+
     * | nr_class weights | nr_class weights | ...
     * | for 1st feature | for 2nd feature |
     * +------------------+------------------+------------+
     * </pre>
     *
     * If bias >= 0, x becomes [x; bias]. The number of features is
     * increased by one, so w is a (nr_feature+1)*nr_class array. The
     * value of bias is stored in the variable bias.
     * @see #getBias()
     * @return a <b>copy of</b> the feature weight array as described
     */
    public double[] getFeatureWeights() {
        // same helper as getLabels() for consistency
        return copyOf(w, w.length);
    }

    /**
     * @return true for logistic regression solvers
     */
    public boolean isProbabilityModel() {
        return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
    }

    /**
     * @return the bias value of this model
     * @see #getFeatureWeights()
     */
    public double getBias() {
        return bias;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Model");
        sb.append(" bias=").append(bias);
        sb.append(" nr_class=").append(nr_class);
        sb.append(" nr_feature=").append(nr_feature);
        sb.append(" solverType=").append(solverType);
        return sb.toString();
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: the weight array is hashed with
     * {@link #hashCode(double[])} so that weights differing only in the sign of zero (which
     * {@code equals} treats as equal) produce the same hash.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        long temp;
        temp = Double.doubleToLongBits(bias);
        result = prime * result + (int)(temp ^ (temp >>> 32));
        result = prime * result + Arrays.hashCode(label);
        result = prime * result + nr_class;
        result = prime * result + nr_feature;
        result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
        // NOT Arrays.hashCode(w): it distinguishes 0.0 from -0.0, but equals(double[], double[]) does not
        result = prime * result + hashCode(w);
        return result;
    }

    /**
     * Two models are equal when bias, labels, class/feature counts, solver type and weights match.
     * Weights are compared with {@link #equals(double[], double[])}, i.e. 0.0 and -0.0 are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) return true;
        if (obj == null) return false;
        if (getClass() != obj.getClass()) return false;
        Model other = (Model)obj;
        if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
        if (!Arrays.equals(label, other.label)) return false;
        if (nr_class != other.nr_class) return false;
        if (nr_feature != other.nr_feature) return false;
        if (solverType == null) {
            if (other.solverType != null) return false;
        } else if (!solverType.equals(other.solverType)) return false;
        if (!equals(w, other.w)) return false;
        return true;
    }

    /**
     * Element-wise comparison with {@code ==} semantics. Don't use {@link Arrays#equals(double[], double[])}
     * here, because 0.0 and -0.0 should be handled the same. Note that with {@code ==} a NaN element
     * never equals anything (unless both arguments are the same array instance).
     *
     * @see Linear#saveModel(java.io.Writer, Model)
     */
    protected static boolean equals(double[] a, double[] a2) {
        if (a == a2) return true;
        if (a == null || a2 == null) return false;
        int length = a.length;
        if (a2.length != length) return false;
        for (int i = 0; i < length; i++)
            if (a[i] != a2[i]) return false;
        return true;
    }

    /**
     * Hash counterpart of {@link #equals(double[], double[])}: -0.0 is hashed as +0.0 so that arrays
     * equal under that method also have equal hash codes.
     *
     * @return 0 for {@code null}, otherwise a hash of all elements
     */
    private static int hashCode(double[] a) {
        if (a == null) return 0;
        int result = 1;
        for (double value : a) {
            // -0.0 == 0.0 is true, so both zeros are mapped to +0.0 here
            double normalized = (value == 0.0) ? 0.0 : value;
            long bits = Double.doubleToLongBits(normalized);
            result = 31 * result + (int)(bits ^ (bits >>> 32));
        }
        return result;
    }

    /**
     * Saves this model to a file.
     *
     * @see Linear#saveModel(java.io.File, Model)
     */
    public void save(File file) throws IOException {
        Linear.saveModel(file, this);
    }

    /**
     * Saves this model to a writer.
     *
     * @see Linear#saveModel(Writer, Model)
     */
    public void save(Writer writer) throws IOException {
        Linear.saveModel(writer, this);
    }

    /**
     * Loads a model from a file.
     *
     * @see Linear#loadModel(File)
     */
    public static Model load(File file) throws IOException {
        return Linear.loadModel(file);
    }

    /**
     * Loads a model from a reader.
     *
     * @see Linear#loadModel(Reader)
     */
    public static Model load(Reader inputReader) throws IOException {
        return Linear.loadModel(inputReader);
    }
}