33import edu .illinois .cs .cogcomp .lbjava .learn .AdaGrad ;
44import org .junit .Before ;
55import org .junit .Test ;
6+ import java .util .ArrayList ;
7+ import java .util .Arrays ;
8+ import java .util .Random ;
9+
610import static org .junit .Assert .*;
711
812/**
9- * Unit Test Class for <code>AdaGrad</code> class
13+ * Unit tests for <code>AdaGrad</code> class
1014 *
1115 * @author Yiming Jiang
1216 */
@@ -130,4 +134,192 @@ public void testHingeLossLearn() {
130134
131135 assertArrayEquals (exp_w3 , w3 , 0.000001 );
132136 }
137+
138+ /**
139+ * This is a simple test to test for overfitting
140+ *
141+ * The <code>AdaGrad</code> is given with simple data set with 2 features and a label.
142+ * Train the data set for 30 iterations and
143+ * see if the algorithm can classify the same data set correctly.
144+ */
145+ @ Test
146+ public void overfittingSimpleTest () {
147+ /**
148+ * static data set;
149+ * the first 2 numbers are 2 features and the last one is the label;
150+ * this data set is linearly separable
151+ */
152+ double [][] dataSet = new double [][]{
153+ {-2 , -4 , 1 },
154+ {-2 , 0 , 1 },
155+ {0 , 2 , 1 },
156+ {-2 , 2 , 1 },
157+ {0 , 4 , 1 },
158+ {2 , 2 , -1 },
159+ {2 , -2 , -1 },
160+ {0 , -4 , -1 },
161+ {2 , -4 , -1 },
162+ {4 , -2 , -1 }
163+ };
164+
165+ int [] exampleFeatures = {0 , 1 };
166+ int [] exampleLabels = {0 };
167+
168+ double [] exampleValues = {0 , 0 };
169+ double [] labelValues = {0 };
170+
171+ /* train <code>AdaGrad</code> for 30 iterations */
172+ for (int i = 0 ; i < 30 ; i ++) {
173+ exampleValues [0 ] = dataSet [i %10 ][0 ];
174+ exampleValues [1 ] = dataSet [i %10 ][1 ];
175+ labelValues [0 ] = dataSet [i %10 ][2 ];
176+
177+ learner .learn (exampleFeatures , exampleValues , exampleLabels , labelValues );
178+ }
179+
180+ /* test against the same data set */
181+ int correctNumber = 0 ;
182+ for (int i = 0 ; i < 10 ; i ++) {
183+ exampleValues [0 ] = dataSet [i ][0 ];
184+ exampleValues [1 ] = dataSet [i ][1 ];
185+
186+ double result = learner .realValue (exampleFeatures , exampleValues );
187+
188+ if (result * dataSet [i ][2 ] > 0 ) {
189+ correctNumber ++;
190+ }
191+ }
192+ assertEquals (10 , correctNumber );
193+ }
194+
195+ /**
196+ * This is a complete test to test overfitting in <code>AdaGrad</code>
197+ *
198+ * Data set consists of 10 examples, each with 2 features;
199+ * Each feature value is randomly generated from range [0, 10];
200+ *
201+ * A "correct" weight vector is randomly generated;
202+ * Each value is from range [0, 10];
203+ *
204+ * The hyperplane is set by taking the medium of w*x;
205+ * Almost half of examples are labeled as +1; the rest are labeled -1;
206+ *
207+ * Thus, the data set is linearly separable, while being random
208+ *
209+ * <code>AdaGrad</code> learning algorithm will train on this data set for 100 iterations.
210+ *
211+ * Then it will be tested using the same data set to see if all classifications are correct.
212+ */
213+ @ Test
214+ public void overfittingCompleteTest () {
215+
216+ /* set constant learning rate */
217+ AdaGrad .Parameters p = new AdaGrad .Parameters ();
218+ p .learningRateP = 10 ;
219+ learner .setParameters (p );
220+
221+ /* give a seed to rand */
222+ Random rand = new Random (0 );
223+
224+ /** create 10 examples, each with 2 features,
225+ * with values randomly generated from [0, 10]
226+ */
227+ ArrayList <ArrayList <Double >> dataSet = new ArrayList <ArrayList <Double >>();
228+
229+ for (int i = 0 ; i < 10 ; i ++) {
230+ ArrayList <Double > eachExample = new ArrayList <Double >();
231+ eachExample .add ((double ) randInt (rand , 0 , 10 ));
232+ eachExample .add ((double ) randInt (rand , 0 , 10 ));
233+ eachExample .add (0.0 );
234+ dataSet .add (eachExample );
235+ }
236+
237+ /* randomly generate the "correct" weight vector */
238+ ArrayList <Double > weightVector = new ArrayList <>();
239+ weightVector .add ((double ) randInt (rand , 0 , 10 ));
240+ weightVector .add ((double ) randInt (rand , 0 , 10 ));
241+ weightVector .add ((double ) randInt (rand , 0 , 10 ));
242+
243+ /* compute all w*x and set the medium to the decision boundary */
244+ double [] resultVector = new double [10 ];
245+ for (int i = 0 ; i < 10 ; i ++) {
246+ resultVector [i ] = computeDotProduct (dataSet .get (i ), weightVector );
247+ }
248+ Arrays .sort (resultVector );
249+
250+ double medium = resultVector [4 ];
251+
252+ /* for which example w*x >= medium, the label is set to +1, otherwise -1 */
253+ for (int i = 0 ; i < 10 ; i ++) {
254+ if (computeDotProduct (dataSet .get (i ), weightVector ) >= medium ) {
255+ dataSet .get (i ).set (2 , 1.0 );
256+ }
257+ else {
258+ dataSet .get (i ).set (2 , -1.0 );
259+ }
260+ }
261+
262+ int [] exampleFeatures = {0 , 1 };
263+ int [] exampleLabels = {0 };
264+
265+ double [] exampleValues = {0 , 0 };
266+ double [] labelValues = {0 };
267+
268+ /* train <code>AdaGrad</code> for 100 iterations */
269+ for (int i = 0 ; i < 100 ; i ++) {
270+ exampleValues [0 ] = dataSet .get (i % 10 ).get (0 );
271+ exampleValues [1 ] = dataSet .get (i % 10 ).get (1 );
272+ labelValues [0 ] = dataSet .get (i % 10 ).get (2 );
273+
274+ learner .learn (exampleFeatures , exampleValues , exampleLabels , labelValues );
275+ }
276+
277+ /* test against the same data set */
278+ int correctNumber = 0 ;
279+ for (int i = 0 ; i < 10 ; i ++) {
280+ exampleValues [0 ] = dataSet .get (i % 10 ).get (0 );
281+ exampleValues [1 ] = dataSet .get (i % 10 ).get (1 );
282+
283+ double result = learner .realValue (exampleFeatures , exampleValues );
284+
285+ if (result * dataSet .get (i % 10 ).get (2 ) > 0 ) {
286+ correctNumber ++;
287+ }
288+ }
289+
290+ /* test if the all classifications are correct */
291+ assertTrue ((correctNumber == 10 ));
292+ }
293+
294+ /**
295+ * Compute the dot product of weight vector and feature vector
296+ * @param x feature vector
297+ * @param w weight vector
298+ * @return dot product result
299+ */
300+ private double computeDotProduct (ArrayList <Double > x , ArrayList <Double > w ) {
301+ double result = 0.0 ;
302+ for (int i = 0 ; i < x .size ()-1 ; i ++) {
303+ result += x .get (i ) * w .get (i );
304+ }
305+ result += w .get (w .size () - 1 );
306+ return result ;
307+ }
308+
309+ /**
310+ * Returns a pseudo-random number between min and max, inclusive.
311+ * The difference between min and max can be at most
312+ * <code>Integer.MAX_VALUE - 1</code>.
313+ *
314+ * @param rand random instance
315+ * @param min minimim value
316+ * @param max maximim value. Must be greater than min.
317+ * @return integer between min and max, inclusive.
318+ * @see java.util.Random#nextInt(int)
319+ *
320+ * Reference: http://stackoverflow.com/questions/20389890/generating-a-random-number-between-1-and-10-java
321+ */
322+ private int randInt (Random rand , int min , int max ) {
323+ return rand .nextInt ((max - min ) + 1 ) + min ;
324+ }
133325}
0 commit comments