From 326f848fb80f3d52ffe3ad6b7f0f25611785bd13 Mon Sep 17 00:00:00 2001 From: Googulator Date: Mon, 12 Apr 2021 23:49:34 +0200 Subject: [PATCH 01/15] Fix flag_predict_probability thread safety issue --- src/main/java/de/bwaldvogel/liblinear/Predict.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Predict.java b/src/main/java/de/bwaldvogel/liblinear/Predict.java index c988712..085fe25 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Predict.java +++ b/src/main/java/de/bwaldvogel/liblinear/Predict.java @@ -21,7 +21,7 @@ public class Predict { - private static boolean flag_predict_probability = false; + private boolean flag_predict_probability = false; private static final Pattern COLON = Pattern.compile(":"); From 548234d026de901044e3828de28ac31bc7958a59 Mon Sep 17 00:00:00 2001 From: Googulator Date: Mon, 12 Apr 2021 23:58:59 +0200 Subject: [PATCH 02/15] Parameter: Allow overriding dual_solver_max_iters --- .../de/bwaldvogel/liblinear/Parameter.java | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Parameter.java b/src/main/java/de/bwaldvogel/liblinear/Parameter.java index 8784a34..5036a33 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Parameter.java +++ b/src/main/java/de/bwaldvogel/liblinear/Parameter.java @@ -11,6 +11,8 @@ public final class Parameter implements Cloneable { double eps; int max_iters = 1000; // maximal iterations + + int dual_solver_max_iters = 300; SolverType solverType; @@ -42,6 +44,14 @@ public Parameter(SolverType solver, double C, int max_iters, double eps) { setMaxIters(max_iters); } + public Parameter(SolverType solver, double C, int max_iters, int dual_solver_max_iters, double eps) { + setSolverType(solver); + setC(C); + setEps(eps); + setMaxIters(max_iters); + setDualSolverMaxIters(dual_solver_max_iters); + } + public Parameter(SolverType solverType, double C, double eps, double p) { setSolverType(solverType); setC(C); @@ -57,6 +67,15 @@ public Parameter(SolverType solverType, double C, double eps, int max_iters, dou setP(p); } + public Parameter(SolverType solverType, double C, double eps, int max_iters, int dual_solver_max_iters, double p) { + setSolverType(solverType); + setC(C); + setEps(eps); + setMaxIters(max_iters); + setDualSolverMaxIters(dual_solver_max_iters); + setP(p); + } + /** *

nr_weight, weight_label, and weight are used to change the penalty * for some classes (If the weight for a class is not changed, it is @@ -144,6 +163,16 @@ public int getMaxIters() { return max_iters; } + public void setDualSolverMaxIters(int iters) { + if (iters <= 0) + throw new IllegalArgumentException("dual solver max iters not be <= 0"); + this.dual_solver_max_iters = iters; + } + + public int getDualSolverMaxIters() { + return dual_solver_max_iters; + } + public void setSolverType(SolverType solverType) { if (solverType == null) throw new IllegalArgumentException("solver type must not be null"); @@ -208,7 +237,7 @@ public boolean isRegularizeBias() { @Override public Parameter clone() { - Parameter clone = new Parameter(solverType, C, eps, max_iters, p); + Parameter clone = new Parameter(solverType, C, eps, max_iters, dual_solver_max_iters, p); clone.weight = weight == null ? null : weight.clone(); clone.weightLabel = weightLabel == null ? null : weightLabel.clone(); clone.init_sol = init_sol; From a01b946f88ad2dff49517bdb304cc5ad65171bb6 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:01:17 +0200 Subject: [PATCH 03/15] Allow setting max_iters and dual_solver_max_iters via command line --- src/main/java/de/bwaldvogel/liblinear/Train.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java index c80a8ba..e1e0a28 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Train.java +++ b/src/main/java/de/bwaldvogel/liblinear/Train.java @@ -185,6 +185,12 @@ public void parse_command_line(String argv[]) { case 'e': param.setEps(atof(argv[i])); break; + case 'm': + param.setMaxIters(atoi(argv[i])); + break; + case 'd': + param.setDualSolverMaxIters(atoi(argv[i])); + break; case 'B': bias = atof(argv[i]); break; From 7158d5f128f6074b1c89de8353b91a6e7eada5a0 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:03:50 +0200 Subject: [PATCH 04/15] Linear: use dual_solver_max_iters from param --- src/main/java/de/bwaldvogel/liblinear/Linear.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Linear.java b/src/main/java/de/bwaldvogel/liblinear/Linear.java index bd7baa7..7b2c6c2 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Linear.java +++ b/src/main/java/de/bwaldvogel/liblinear/Linear.java @@ -2255,7 +2255,7 @@ private static void checkProblemSize(int n, int nr_class) { private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) { SolverType solver_type = param.solverType; - int dual_solver_max_iter = 300; + int dual_solver_max_iter = param.dual_solver_max_iters; int iter; // upstream: (solver_type==L2R_L2LOSS_SVR || solver_type==L2R_L1LOSS_SVR_DUAL || solver_type==L2R_L2LOSS_SVR_DUAL) From ae696b6ccb52fdb5dbf9a4ae8bd16e4b4c814223 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:06:10 +0200 Subject: [PATCH 05/15] Properly fix flag_predict_probability thread safety --- src/main/java/de/bwaldvogel/liblinear/Predict.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Predict.java b/src/main/java/de/bwaldvogel/liblinear/Predict.java index 085fe25..1274f7a 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Predict.java +++ b/src/main/java/de/bwaldvogel/liblinear/Predict.java @@ -21,14 +21,12 @@ public class Predict { - private boolean flag_predict_probability = false; - private static final Pattern COLON = Pattern.compile(":"); /** *

Note: The streams are NOT closed

*/ - static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException { + void doPredict(BufferedReader reader, Writer writer, Model model, boolean flag_predict_probability) throws IOException { int correct = 0; int total = 0; double error = 0; @@ -147,6 +145,7 @@ private static void exit_with_help() { public static void main(String[] argv) throws IOException { int i; + boolean flag_predict_probability = false; // parse options for (i = 0; i < argv.length; i++) { @@ -182,7 +181,7 @@ public static void main(String[] argv) throws IOException { FileOutputStream out = new FileOutputStream(argv[i + 2]); Writer writer = new BufferedWriter(new OutputStreamWriter(out, Linear.FILE_CHARSET))) { Model model = Linear.loadModel(Paths.get(argv[i + 1])); - doPredict(reader, writer, model); + doPredict(reader, writer, model, flag_predict_probability); } } } From 23a76db4ac32f17ee20522b84ffc2256309a2ff9 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:07:51 +0200 Subject: [PATCH 06/15] Fix accidentally removed static --- src/main/java/de/bwaldvogel/liblinear/Predict.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Predict.java b/src/main/java/de/bwaldvogel/liblinear/Predict.java index 1274f7a..f5652d9 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Predict.java +++ b/src/main/java/de/bwaldvogel/liblinear/Predict.java @@ -26,7 +26,7 @@ public class Predict { /** *

Note: The streams are NOT closed

*/ - void doPredict(BufferedReader reader, Writer writer, Model model, boolean flag_predict_probability) throws IOException { + static void doPredict(BufferedReader reader, Writer writer, Model model, boolean flag_predict_probability) throws IOException { int correct = 0; int total = 0; double error = 0; From 92c341b2bc915192ccf0bd42892dd7f785359ba8 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:11:28 +0200 Subject: [PATCH 07/15] Update unit test to match doPredict parameters --- src/test/java/de/bwaldvogel/liblinear/PredictTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/de/bwaldvogel/liblinear/PredictTest.java b/src/test/java/de/bwaldvogel/liblinear/PredictTest.java index ac4cfa7..af8fc94 100644 --- a/src/test/java/de/bwaldvogel/liblinear/PredictTest.java +++ b/src/test/java/de/bwaldvogel/liblinear/PredictTest.java @@ -42,7 +42,7 @@ public void tearDown() { private void testWithLines(StringBuilder sb) throws Exception { try (StringReader stringReader = new StringReader(sb.toString()); BufferedReader reader = new BufferedReader(stringReader)) { - Predict.doPredict(reader, writer, testModel); + Predict.doPredict(reader, writer, testModel, false); } } From 3e10753ccf57d225d4654e11b64c85515916741b Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 00:27:16 +0200 Subject: [PATCH 08/15] Upload artifacts after build --- .github/workflows/gradle.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index e60a755..b3e542c 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -21,6 +21,11 @@ jobs: java-version: ${{ matrix.java }} - name: Build with Gradle run: ./gradlew build + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: Package + path: build/libs publishCoverage: runs-on: ubuntu-latest From 4f20b9b2594c753519f64e97d40c1b84071c56c3 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 03:20:17 +0200 Subject: [PATCH 09/15] Predict: allow caller to recover from invalid argument --- .../java/de/bwaldvogel/liblinear/Predict.java | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Predict.java b/src/main/java/de/bwaldvogel/liblinear/Predict.java index f5652d9..a04a284 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Predict.java +++ b/src/main/java/de/bwaldvogel/liblinear/Predict.java @@ -135,15 +135,14 @@ static void doPredict(BufferedReader reader, Writer writer, Model model, boolean } } - private static void exit_with_help() { + private static void print_help() { System.out.printf("Usage: predict [options] test_file model_file output_file%n" // + "options:%n" // + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only%n" // + "-q quiet mode (no outputs)%n"); - System.exit(1); } - public static void main(String[] argv) throws IOException { + public static int run(String[] argv) throws IOException { int i; boolean flag_predict_probability = false; @@ -157,7 +156,8 @@ public static void main(String[] argv) throws IOException { try { flag_predict_probability = (atoi(argv[i]) != 0); } catch (NumberFormatException e) { - exit_with_help(); + print_help(); + return 1; } break; @@ -168,12 +168,13 @@ public static void main(String[] argv) throws IOException { default: System.err.printf("unknown option: -%d%n", argv[i - 1].charAt(1)); - exit_with_help(); - break; + print_help(); + return 1; } } if (i >= argv.length || argv.length <= i + 2) { - exit_with_help(); + print_help(); + return 1; } try (FileInputStream in = new FileInputStream(argv[i]); @@ -183,5 +184,11 @@ public static void main(String[] argv) throws IOException { Model model = Linear.loadModel(Paths.get(argv[i + 1])); doPredict(reader, writer, model, flag_predict_probability); } + + return 0; + } + + public static void main(String[] argv) throws IOException { + System.exit(run(argv)); } } From 281504788b9753eb163febfd694d1bdd02c74c0d Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 03:26:21 +0200 Subject: [PATCH 10/15] Train: allow caller a chance to recover from invalid argument --- .../java/de/bwaldvogel/liblinear/Train.java | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java index e1e0a28..bda9b6f 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Train.java +++ b/src/main/java/de/bwaldvogel/liblinear/Train.java @@ -21,7 +21,7 @@ public class Train { public static void main(String[] args) throws IOException, InvalidInputDataException { - new Train().run(args); + System.exit(new Train().run(args)); } private double bias = 1; @@ -92,7 +92,7 @@ private void do_cross_validation() { } } - private void exit_with_help() { + private void print_help() { System.out.printf("Usage: train [options] training_set_file [model_file]%n" // + "options:%n" + "-s type : set type of solver (default 1)%n" @@ -136,7 +136,6 @@ private void exit_with_help() { + "-v n: n-fold cross validation mode%n" + "-C : find parameters (C for -s 0, 2 and C, p for -s 11)%n" + "-q : quiet mode (no outputs)%n"); - System.exit(1); } public Problem getProblem() { @@ -151,7 +150,7 @@ public Parameter getParameter() { return param; } - public void parse_command_line(String argv[]) { + public boolean parse_command_line(String argv[]) { int i; // eps: see setting below @@ -164,8 +163,10 @@ public void parse_command_line(String argv[]) { for (i = 0; i < argv.length; i++) { if (argv[i].charAt(0) != '-') break; - if (++i >= argv.length) - exit_with_help(); + if (++i >= argv.length) { + print_help(); + return false; + } switch (argv[i - 1].charAt(1)) { case 's': param.solverType = SolverType.getById(atoi(argv[i])); @@ -205,7 +206,8 @@ public void parse_command_line(String argv[]) { nr_fold = atoi(argv[i]); if (nr_fold < 2) { System.err.println("n-fold cross validation: n must >= 2"); - exit_with_help(); + print_help(); + return false; } break; case 'q': @@ -222,14 +224,17 @@ public void parse_command_line(String argv[]) { break; default: System.err.println("unknown option"); - exit_with_help(); + print_help(); + return false; } } // determine filenames - if (i >= argv.length) - exit_with_help(); + if (i >= argv.length) { + print_help(); + return 1; + } inputFilename = argv[i]; @@ -250,7 +255,8 @@ public void parse_command_line(String argv[]) { param.setSolverType(L2R_L2LOSS_SVC); } else if (param.getSolverType() != L2R_LR && param.getSolverType() != L2R_L2LOSS_SVC && param.getSolverType() != L2R_L2LOSS_SVR) { System.err.printf("Warm-start parameter search only available for -s 0, -s 2 and -s 11%n"); - exit_with_help(); + print_help(); + return false; } } @@ -284,6 +290,8 @@ public void parse_command_line(String argv[]) { throw new IllegalStateException("unknown solver type: " + param.solverType); } } + + return true; } /** @@ -453,8 +461,9 @@ private static Problem constructProblem(List vy, List vx, int return prob; } - private void run(String[] args) throws IOException, InvalidInputDataException { - parse_command_line(args); + private int run(String[] args) throws IOException, InvalidInputDataException { + if (!parse_command_line(args)) + return 1; readProblem(inputFilename); if (find_parameters) { do_find_parameters(); @@ -464,6 +473,8 @@ private void run(String[] args) throws IOException, InvalidInputDataException { Model model = Linear.train(prob, param); Linear.saveModel(Paths.get(modelFilename), model); } + + return 0; } boolean isFindParameters() { From 33b38309b94976b84dc70626b63503ac85f1f002 Mon Sep 17 00:00:00 2001 From: Googulator Date: Tue, 13 Apr 2021 03:29:52 +0200 Subject: [PATCH 11/15] Train: fix type confusion --- src/main/java/de/bwaldvogel/liblinear/Train.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java index bda9b6f..b7d68f2 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Train.java +++ b/src/main/java/de/bwaldvogel/liblinear/Train.java @@ -233,7 +233,7 @@ public boolean parse_command_line(String argv[]) { if (i >= argv.length) { print_help(); - return 1; + return false; } inputFilename = argv[i]; From ac59fb9f0dfe1a12811fd2f6f620ab9405146798 Mon Sep 17 00:00:00 2001 From: Googulator Date: Thu, 15 Apr 2021 22:24:15 +0200 Subject: [PATCH 12/15] Change iteration limit to "-i", reserve "-m" for multicore compatibility --- src/main/java/de/bwaldvogel/liblinear/Train.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java index b7d68f2..7c4f57b 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Train.java +++ b/src/main/java/de/bwaldvogel/liblinear/Train.java @@ -187,6 +187,9 @@ public boolean parse_command_line(String argv[]) { param.setEps(atof(argv[i])); break; case 'm': + // ignore for compatibility with multicore liblinear + break; + case 'i': param.setMaxIters(atoi(argv[i])); break; case 'd': From e6960420ba3e7f6220f823a67f1c54e558cb16c4 Mon Sep 17 00:00:00 2001 From: Googulator Date: Thu, 15 Apr 2021 22:58:39 +0200 Subject: [PATCH 13/15] Make run() public --- src/main/java/de/bwaldvogel/liblinear/Train.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java index 7c4f57b..52f8fae 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Train.java +++ b/src/main/java/de/bwaldvogel/liblinear/Train.java @@ -464,7 +464,7 @@ private static Problem constructProblem(List vy, List vx, int return prob; } - private int run(String[] args) throws IOException, InvalidInputDataException { + public int run(String[] args) throws IOException, InvalidInputDataException { if (!parse_command_line(args)) return 1; readProblem(inputFilename); From 729f0595f56136839233395b43dc61493536dfd5 Mon Sep 17 00:00:00 2001 From: Googulator Date: Wed, 19 May 2021 23:41:31 +0200 Subject: [PATCH 14/15] Unlock (inaccurate) probability estimation for all solvers This isn't perfectly accurate, but should at least be in order. --- src/main/java/de/bwaldvogel/liblinear/Model.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/de/bwaldvogel/liblinear/Model.java b/src/main/java/de/bwaldvogel/liblinear/Model.java index 6e39859..2016e4a 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Model.java +++ b/src/main/java/de/bwaldvogel/liblinear/Model.java @@ -88,7 +88,7 @@ public double[] getFeatureWeights() { * @return true for logistic regression solvers */ public boolean isProbabilityModel() { - return solverType.isLogisticRegressionSolver(); + return true; //solverType.isLogisticRegressionSolver(); } /** From 701fb52ce8bc4d260c7cc2f8487264c52cdc2974 Mon Sep 17 00:00:00 2001 From: Googulator Date: Wed, 19 May 2021 23:46:05 +0200 Subject: [PATCH 15/15] Disable test for predictProbability on non-LR solvers The main code has been patched to allow predictProbability on these solvers, so this test is now useless. --- src/test/java/de/bwaldvogel/liblinear/LinearTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/java/de/bwaldvogel/liblinear/LinearTest.java b/src/test/java/de/bwaldvogel/liblinear/LinearTest.java index 477626d..b2f7cd1 100644 --- a/src/test/java/de/bwaldvogel/liblinear/LinearTest.java +++ b/src/test/java/de/bwaldvogel/liblinear/LinearTest.java @@ -361,7 +361,7 @@ void testTrain_IllegalParameters_InitialSol() { Model model = Linear.train(prob, param); assertThat(model).isNotNull(); } - +/* @Test void testPredictProbabilityWrongSolver() throws Exception { Problem prob = new Problem(); @@ -383,7 +383,7 @@ void testPredictProbabilityWrongSolver() throws Exception { + " This is currently only supported by the following solvers:" + " L2R_LR, L1R_LR, L2R_LR_DUAL"); } - +*/ @Test void testAtoi() { assertThat(Linear.atoi("+25")).isEqualTo(25);