Skip to content

Commit f35f841

Browse files
author
Daniel Khashabi
authored
Merge branch 'master' into neuralnet
2 parents 7a9eba7 + 32587db commit f35f841

20 files changed

Lines changed: 1134 additions & 108 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ lbjava-examples/src/main/java/edu/illinois/cs/cogcomp/lbjava/examples/spam/SpamC
6363
lbjava-examples/src/main/java/edu/illinois/cs/cogcomp/lbjava/examples/regression/MyFeatures.java
6464
lbjava-examples/src/main/java/edu/illinois/cs/cogcomp/lbjava/examples/regression/MyLabel.java
6565
lbjava-examples/src/main/java/edu/illinois/cs/cogcomp/lbjava/examples/regression/SGDClassifier.java
66+
/.metadata/

lbjava-examples/pom.xml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<parent>
44
<artifactId>lbjava-project</artifactId>
55
<groupId>edu.illinois.cs.cogcomp</groupId>
6-
<version>1.2.26</version>
6+
<version>1.3.0</version>
77
</parent>
88

99
<modelVersion>4.0.0</modelVersion>
@@ -27,12 +27,12 @@
2727
<dependency>
2828
<groupId>edu.illinois.cs.cogcomp</groupId>
2929
<artifactId>LBJava</artifactId>
30-
<version>1.2.26</version>
30+
<version>1.3.0</version>
3131
</dependency>
3232
<dependency>
3333
<groupId>edu.illinois.cs.cogcomp</groupId>
3434
<artifactId>lbjava-maven-plugin</artifactId>
35-
<version>1.2.26</version>
35+
<version>1.3.0</version>
3636
</dependency>
3737
</dependencies>
3838

@@ -63,7 +63,7 @@
6363
<plugin>
6464
<groupId>edu.illinois.cs.cogcomp</groupId>
6565
<artifactId>lbjava-maven-plugin</artifactId>
66-
<version>1.2.26</version>
66+
<version>1.3.0</version>
6767
<configuration>
6868
<gspFlag>${project.basedir}/src/main/java</gspFlag>
6969
<dFlag>${project.basedir}/target/classes</dFlag>

lbjava-mvn-plugin/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<parent>
66
<artifactId>lbjava-project</artifactId>
77
<groupId>edu.illinois.cs.cogcomp</groupId>
8-
<version>1.2.26</version>
8+
<version>1.3.0</version>
99
</parent>
1010

1111
<artifactId>lbjava-maven-plugin</artifactId>
@@ -76,7 +76,7 @@
7676
<dependency>
7777
<groupId>edu.illinois.cs.cogcomp</groupId>
7878
<artifactId>LBJava</artifactId>
79-
<version>1.2.26</version>
79+
<version>1.3.0</version>
8080
<type>jar</type>
8181
<scope>compile</scope>
8282
</dependency>

lbjava/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<parent>
44
<artifactId>lbjava-project</artifactId>
55
<groupId>edu.illinois.cs.cogcomp</groupId>
6-
<version>1.2.26</version>
6+
<version>1.3.0</version>
77
</parent>
88

99
<modelVersion>4.0.0</modelVersion>

lbjava/src/main/java/edu/illinois/cs/cogcomp/lbjava/Train.java

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -880,53 +880,57 @@ public void run() {
880880
if (!lce.onlyCodeGeneration) {
881881
// If there's a "from" clause, train.
882882
try {
883-
if (lce.parser != null) {
884-
System.out.println("Training " + getName());
885-
if (preExtract) {
886-
preExtractAndPrune();
887-
System.gc();
888-
} else
889-
learner.saveLexicon();
890-
int trainingRounds = 1;
891-
892-
if (tuningParameters) {
893-
String parametersPath = getName();
894-
if (Main.classDirectory != null)
895-
parametersPath =
896-
Main.classDirectory + File.separator + parametersPath;
897-
parametersPath += ".p";
898-
899-
Learner.Parameters bestParameters = tune();
900-
trainingRounds = bestParameters.rounds;
901-
Learner.writeParameters(bestParameters, parametersPath);
902-
System.out.println(" " + getName()
903-
+ ": Training on entire training set");
904-
} else {
905-
if (lce.rounds != null)
906-
trainingRounds = Integer.parseInt(((Constant) lce.rounds).value);
907-
908-
if (lce.K != null) {
909-
int[] rounds = {trainingRounds};
910-
int k = Integer.parseInt(lce.K.value);
911-
double alpha = Double.parseDouble(lce.alpha.value);
912-
trainer.crossValidation(rounds, k, lce.splitPolicy, alpha,
913-
testingMetric, true);
883+
learner.beginTraining();
884+
try {
885+
if (lce.parser != null) {
886+
System.out.println("Training " + getName());
887+
if (preExtract) {
888+
preExtractAndPrune();
889+
System.gc();
890+
} else
891+
learner.saveLexicon();
892+
int trainingRounds = 1;
893+
894+
if (tuningParameters) {
895+
String parametersPath = getName();
896+
if (Main.classDirectory != null)
897+
parametersPath =
898+
Main.classDirectory + File.separator + parametersPath;
899+
parametersPath += ".p";
900+
901+
Learner.Parameters bestParameters = tune();
902+
trainingRounds = bestParameters.rounds;
903+
Learner.writeParameters(bestParameters, parametersPath);
914904
System.out.println(" " + getName()
915905
+ ": Training on entire training set");
906+
} else {
907+
if (lce.rounds != null)
908+
trainingRounds = Integer.parseInt(((Constant) lce.rounds).value);
909+
910+
if (lce.K != null) {
911+
int[] rounds = {trainingRounds};
912+
int k = Integer.parseInt(lce.K.value);
913+
double alpha = Double.parseDouble(lce.alpha.value);
914+
trainer.crossValidation(rounds, k, lce.splitPolicy, alpha,
915+
testingMetric, true);
916+
System.out.println(" " + getName()
917+
+ ": Training on entire training set");
918+
}
916919
}
917-
}
918-
919-
trainer.train(lce.startingRound, trainingRounds);
920-
921-
if (testParser != null) {
922-
System.out.println("Testing " + getName());
923-
new Accuracy(true).test(learner, learner.getLabeler(), testParser);
924-
}
925-
926-
System.out.println("Writing " + getName());
927-
} else
928-
learner.saveLexicon(); // Writes .lex even if lexicon is empty.
929-
920+
trainer.train(lce.startingRound, trainingRounds);
921+
} else
922+
learner.saveLexicon(); // Writes .lex even if lexicon is empty.
923+
} finally {
924+
learner.doneTraining();
925+
}
926+
927+
if (lce.parser != null && testParser != null) {
928+
System.out.println("Testing " + getName());
929+
new Accuracy(true).test(learner, learner.getLabeler(), testParser);
930+
}
931+
932+
// save the final model.
933+
System.out.println("Writing " + getName());
930934
learner.save(); // Doesn't write .lex if lexicon is empty.
931935
} catch (Exception e) {
932936
System.err.println("LBJava ERROR: Exception while training " + getName() + ":");

lbjava/src/main/java/edu/illinois/cs/cogcomp/lbjava/learn/Learner.java

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ public abstract class Learner extends Classifier {
6666

6767
/** The number of candidate examples when a global object is passed here. */
6868
protected int candidates = 1;
69+
70+
/** this is set while training. */
71+
protected boolean intraining = false;
6972

7073
/**
7174
* This constructor is used by the LBJava compiler; it should never be called by a programmer.
@@ -259,7 +262,6 @@ public URL getModelLocation() {
259262
return lcFilePath;
260263
}
261264

262-
263265
/**
264266
* Sets the location of the lexicon as a regular file on this file system.
265267
*
@@ -289,7 +291,6 @@ public URL getLexiconLocation() {
289291
return lexFilePath;
290292
}
291293

292-
293294
/**
294295
* Establishes a new feature counting policy for this learner's lexicon.
295296
*
@@ -304,7 +305,6 @@ public void countFeatures(Lexicon.CountPolicy policy) {
304305
lexicon.countFeatures(policy);
305306
}
306307

307-
308308
/**
309309
* Returns this learner's feature lexicon after discarding any feature counts it may have been
310310
* storing. This method is likely only useful when the lexicon and its counts are currently
@@ -320,7 +320,6 @@ public Lexicon getLexiconDiscardCounts() {
320320
return lexicon;
321321
}
322322

323-
324323
/**
325324
     * Returns a new, empty learner into which all of the parameters that control the behavior of
326325
     * the algorithm have been copied. Here, "empty" means no learning has taken place.
@@ -331,7 +330,6 @@ public Learner emptyClone() {
331330
return clone;
332331
}
333332

334-
335333
/**
336334
* Trains the learning algorithm given an object as an example. By default, this simply converts
337335
* the example object into arrays and passes it to {@link #learn(int[],double[],int[],double[])}
@@ -345,7 +343,6 @@ public void learn(Object example) {
345343
(double[]) exampleArray[3]);
346344
}
347345

348-
349346
/**
350347
* Trains the learning algorithm given a feature vector as an example. This simply converts the
351348
* example object into arrays and passes it to {@link #learn(int[],double[],int[],double[])}.
@@ -633,6 +630,15 @@ public double realValue(int[] f, double[] v) {
633630
+ getClass().getName() + "'.");
634631
}
635632

633+
/**
634+
* Start training, this might involve training many models, for cross validation,
635+
* parameter tuning and so on.
636+
**/
637+
public void beginTraining() {
638+
intraining = true;
639+
}
640+
641+
636642

637643
/**
638644
* Overridden by subclasses to perform any required post-processing computations after all
@@ -642,6 +648,21 @@ public double realValue(int[] f, double[] v) {
642648
public void doneLearning() {}
643649

644650

651+
/**
652+
* Overridden by subclasses to perform any required post-training computations and optimizations,
653+
* in particular, feature subset reduction. This default method does nothing.
654+
*/
655+
public void doneTraining() {
656+
if (intraining) {
657+
intraining = false;
658+
} else {
659+
throw new RuntimeException("calling doneTraining without previously calling beginTraining"
660+
+ " violates the lifecycle contract. Or perhaps the subclass does not call the superclass "
661+
+ "method. Contact the developer.");
662+
}
663+
}
664+
665+
645666
/**
646667
* This method is sometimes called before training begins, although it is not guaranteed to be
647668
* called at all. It allows the number of examples and number of features to be passed to the

lbjava/src/main/java/edu/illinois/cs/cogcomp/lbjava/learn/Lexicon.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99

1010
import java.io.Serializable;
1111
import java.net.URL;
12+
import java.util.Arrays;
1213
import java.util.Collections;
1314
import java.util.HashMap;
1415
import java.util.Map;
1516

1617
import edu.illinois.cs.cogcomp.core.datastructures.vectors.*;
18+
import edu.illinois.cs.cogcomp.lbjava.classify.DiscreteConjunctiveFeature;
1719
import edu.illinois.cs.cogcomp.lbjava.classify.Feature;
20+
import edu.illinois.cs.cogcomp.lbjava.classify.RealConjunctiveFeature;
1821
import edu.illinois.cs.cogcomp.lbjava.util.ByteString;
1922
import edu.illinois.cs.cogcomp.lbjava.util.ClassUtils;
2023
import edu.illinois.cs.cogcomp.lbjava.util.FVector;
@@ -305,7 +308,7 @@ public boolean contains(Feature f) {
305308
*
306309
* @param f The feature to look up.
307310
* @return The integer key that the feature maps to.
308-
**/
311+
**/
309312
public int lookup(Feature f) {
310313
return lookup(f, false, -1);
311314
}
@@ -661,6 +664,36 @@ public void discardPrunedFeatures() {
661664
pruneCutoff = -1;
662665
}
663666

667+
/**
668+
* Discard features at the provided indices. This operation is performed
669+
* last to first so we can do it in place. This method will sort the input
670+
* array.
671+
* @param dumpthese the indexes of the features to dump.
672+
*/
673+
public void discardPrunedFeatures(int [] dumpthese) {
674+
Arrays.sort(dumpthese);
675+
lexiconInv.remove(dumpthese);
676+
677+
// this compresses the FVector
678+
lexiconInv = new FVector(lexiconInv);
679+
if (lexicon != null) {
680+
681+
// reconstitute the lexicon.
682+
lexicon.clear();
683+
for (int i = 0; i < lexiconInv.size();i++) {
684+
lexicon.put(lexiconInv.get(i), new Integer(i));
685+
}
686+
687+
// sanity check, make sure the indices in the lexicon map matches the index in the feature vector
688+
for (int i = 0; i < lexiconInv.size();i++) {
689+
if (i != ((Integer)lexicon.get(lexiconInv.get(i))).intValue()) {
690+
throw new RuntimeException("After optimization pruning, the index in the lexicon did "
691+
+ "not match the inverted index.");
692+
}
693+
}
694+
}
695+
}
696+
664697

665698
/**
666699
* <!-- clone() --> Returns a deep clone of this lexicon implemented as a <code>HashMap</code>.
@@ -742,10 +775,9 @@ public int compare(int i1, int i2) {
742775
ByteString previousBSIdentifier = null;
743776
out.writeInt(indexes.length);
744777
out.writeInt(pruneCutoff);
745-
746778
for (int i = 0; i < indexes.length; ++i) {
747779
Feature f = inverse.get(indexes[i]);
748-
previousClassName =
780+
previousClassName =
749781
f.lexWrite(out, this, previousClassName, previousPackage, previousClassifier,
750782
previousSIdentifier, previousBSIdentifier);
751783
previousPackage = f.getPackage();
@@ -757,7 +789,6 @@ else if (f.hasByteStringIdentifier())
757789

758790
out.writeInt(indexes[i]);
759791
}
760-
761792
if (featureCounts == null)
762793
out.writeInt(0);
763794
else
@@ -801,14 +832,12 @@ public void read(ExceptionlessInputStream in, boolean readCounts) {
801832
pruneCutoff = in.readInt();
802833
lexicon = null;
803834
lexiconInv = new FVector(N);
804-
805835
for (int i = 0; i < N; ++i) {
806836
Feature f =
807837
Feature.lexReadFeature(in, this, previousClass, previousPackage,
808838
previousClassifier, previousSIdentifier, previousBSIdentifier);
809839
int index = in.readInt();
810840
lexiconInv.set(index, f);
811-
812841
previousClass = f.getClass();
813842
previousPackage = f.getPackage();
814843
previousClassifier = f.getGeneratingClassifier();
@@ -817,7 +846,7 @@ public void read(ExceptionlessInputStream in, boolean readCounts) {
817846
else if (f.hasByteStringIdentifier())
818847
previousBSIdentifier = f.getByteStringIdentifier();
819848
}
820-
849+
821850
if (readCounts) {
822851
featureCounts = new IVector();
823852
featureCounts.read(in);

0 commit comments

Comments
 (0)