Skip to content

Commit 2c95dc9

Browse files
committed
Add overfitting unit tests
1 parent 08a559e commit 2c95dc9

File tree

2 files changed

+195
-2
lines changed

2 files changed

+195
-2
lines changed

lbjava/src/main/java/edu/illinois/cs/cogcomp/lbjava/learn/AdaGrad.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ public double realValue(int[] exampleFeatures, double[] exampleValues) {
284284
for(int i = 0; i < exampleFeatures.length; i++) {
285285
weightDotProductX += weightVector[i] * exampleValues[i];
286286
}
287+
weightDotProductX += weightVector[weightVector.length-1];
287288
return weightDotProductX;
288289
}
289290

@@ -375,4 +376,4 @@ public Parameters() {
375376
lossFunctionP = defaultLossFunction;
376377
}
377378
}
378-
}
379+
}

lbjava/src/test/java/edu/illinois/cs/cogcomp/lbjava/AdaGradTest.java

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
import edu.illinois.cs.cogcomp.lbjava.learn.AdaGrad;
44
import org.junit.Before;
55
import org.junit.Test;
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.Random;
9+
610
import static org.junit.Assert.*;
711

812
/**
9-
* Unit Test Class for <code>AdaGrad</code> class
13+
* Unit tests for <code>AdaGrad</code> class
1014
*
1115
* @author Yiming Jiang
1216
*/
@@ -130,4 +134,192 @@ public void testHingeLossLearn() {
130134

131135
assertArrayEquals(exp_w3, w3, 0.000001);
132136
}
137+
138+
/**
139+
* This is a simple test for overfitting
140+
*
141+
* The <code>AdaGrad</code> learner is given a simple data set with 2 features and a label.
142+
* Train the data set for 30 iterations and
143+
* see if the algorithm can classify the same data set correctly.
144+
*/
145+
@Test
146+
public void overfittingSimpleTest() {
147+
/**
148+
* static data set;
149+
* the first 2 numbers are 2 features and the last one is the label;
150+
* this data set is linearly separable
151+
*/
152+
double [][] dataSet = new double[][]{
153+
{-2, -4, 1},
154+
{-2, 0, 1},
155+
{0, 2, 1},
156+
{-2, 2, 1},
157+
{0, 4, 1},
158+
{2, 2, -1},
159+
{2, -2, -1},
160+
{0, -4, -1},
161+
{2, -4, -1},
162+
{4, -2, -1}
163+
};
164+
165+
int[] exampleFeatures = {0, 1};
166+
int[] exampleLabels = {0};
167+
168+
double[] exampleValues = {0, 0};
169+
double[] labelValues = {0};
170+
171+
/* train <code>AdaGrad</code> for 30 iterations */
172+
for (int i = 0; i < 30; i++) {
173+
exampleValues[0] = dataSet[i%10][0];
174+
exampleValues[1] = dataSet[i%10][1];
175+
labelValues[0] = dataSet[i%10][2];
176+
177+
learner.learn(exampleFeatures, exampleValues, exampleLabels, labelValues);
178+
}
179+
180+
/* test against the same data set */
181+
int correctNumber = 0;
182+
for (int i = 0; i < 10; i++) {
183+
exampleValues[0] = dataSet[i][0];
184+
exampleValues[1] = dataSet[i][1];
185+
186+
double result = learner.realValue(exampleFeatures, exampleValues);
187+
188+
if (result * dataSet[i][2] > 0) {
189+
correctNumber ++;
190+
}
191+
}
192+
assertEquals(10, correctNumber);
193+
}
194+
195+
/**
196+
* This is a comprehensive test of overfitting in <code>AdaGrad</code>
197+
*
198+
* Data set consists of 10 examples, each with 2 features;
199+
* Each feature value is randomly generated from range [0, 10];
200+
*
201+
* A "correct" weight vector is randomly generated;
202+
* Each value is from range [0, 10];
203+
*
204+
* The hyperplane is set by taking the median of w*x;
205+
* About half of the examples are labeled +1; the rest are labeled -1;
206+
*
207+
* Thus, the data set is linearly separable, while being random
208+
*
209+
* <code>AdaGrad</code> learning algorithm will train on this data set for 100 iterations.
210+
*
211+
* Then it will be tested using the same data set to see if all classifications are correct.
212+
*/
213+
@Test
214+
public void overfittingCompleteTest() {
215+
216+
/* set constant learning rate */
217+
AdaGrad.Parameters p = new AdaGrad.Parameters();
218+
p.learningRateP = 10;
219+
learner.setParameters(p);
220+
221+
/* give a seed to rand */
222+
Random rand = new Random(0);
223+
224+
/** create 10 examples, each with 2 features,
225+
* with values randomly generated from [0, 10]
226+
*/
227+
ArrayList<ArrayList<Double>> dataSet = new ArrayList<ArrayList<Double>>();
228+
229+
for(int i = 0; i < 10; i++) {
230+
ArrayList<Double> eachExample = new ArrayList<Double>();
231+
eachExample.add((double) randInt(rand, 0, 10));
232+
eachExample.add((double) randInt(rand, 0, 10));
233+
eachExample.add(0.0);
234+
dataSet.add(eachExample);
235+
}
236+
237+
/* randomly generate the "correct" weight vector */
238+
ArrayList<Double> weightVector = new ArrayList<>();
239+
weightVector.add((double) randInt(rand, 0, 10));
240+
weightVector.add((double) randInt(rand, 0, 10));
241+
weightVector.add((double) randInt(rand, 0, 10));
242+
243+
/* compute all w*x and set the median as the decision boundary */
244+
double[] resultVector = new double[10];
245+
for (int i = 0; i < 10; i++) {
246+
resultVector[i] = computeDotProduct(dataSet.get(i), weightVector);
247+
}
248+
Arrays.sort(resultVector);
249+
250+
double medium = resultVector[4];
251+
252+
/* for which example w*x >= medium, the label is set to +1, otherwise -1 */
253+
for (int i = 0; i < 10; i++) {
254+
if (computeDotProduct(dataSet.get(i), weightVector) >= medium) {
255+
dataSet.get(i).set(2, 1.0);
256+
}
257+
else {
258+
dataSet.get(i).set(2, -1.0);
259+
}
260+
}
261+
262+
int[] exampleFeatures = {0, 1};
263+
int[] exampleLabels = {0};
264+
265+
double[] exampleValues = {0, 0};
266+
double[] labelValues = {0};
267+
268+
/* train <code>AdaGrad</code> for 100 iterations */
269+
for (int i = 0; i < 100; i++) {
270+
exampleValues[0] = dataSet.get(i % 10).get(0);
271+
exampleValues[1] = dataSet.get(i % 10).get(1);
272+
labelValues[0] = dataSet.get(i % 10).get(2);
273+
274+
learner.learn(exampleFeatures, exampleValues, exampleLabels, labelValues);
275+
}
276+
277+
/* test against the same data set */
278+
int correctNumber = 0;
279+
for (int i = 0; i < 10; i++) {
280+
exampleValues[0] = dataSet.get(i % 10).get(0);
281+
exampleValues[1] = dataSet.get(i % 10).get(1);
282+
283+
double result = learner.realValue(exampleFeatures, exampleValues);
284+
285+
if (result * dataSet.get(i % 10).get(2) > 0) {
286+
correctNumber ++;
287+
}
288+
}
289+
290+
/* test that all classifications are correct */
291+
assertTrue((correctNumber == 10));
292+
}
293+
294+
/**
295+
* Compute the dot product of weight vector and feature vector
296+
* @param x feature vector
297+
* @param w weight vector
298+
* @return dot product result
299+
*/
300+
private double computeDotProduct(ArrayList<Double> x, ArrayList<Double> w) {
301+
double result = 0.0;
302+
for (int i = 0; i < x.size()-1; i++) {
303+
result += x.get(i) * w.get(i);
304+
}
305+
result += w.get(w.size() - 1);
306+
return result;
307+
}
308+
309+
/**
310+
* Returns a pseudo-random number between min and max, inclusive.
311+
* The difference between min and max can be at most
312+
* <code>Integer.MAX_VALUE - 1</code>.
313+
*
314+
* @param rand random instance
315+
* @param min minimum value
316+
* @param max maximum value. Must be greater than min.
317+
* @return integer between min and max, inclusive.
318+
* @see java.util.Random#nextInt(int)
319+
*
320+
* Reference: http://stackoverflow.com/questions/20389890/generating-a-random-number-between-1-and-10-java
321+
*/
322+
private int randInt(Random rand, int min, int max) {
323+
return rand.nextInt((max - min) + 1) + min;
324+
}
133325
}

0 commit comments

Comments
 (0)