diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index 3d5ca0a..236bdf9 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -21,6 +21,11 @@ jobs: java-version: ${{ matrix.java }} - name: Build with Gradle run: ./gradlew build + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: Package + path: build/libs publishCoverage: runs-on: ubuntu-latest diff --git a/src/main/java/de/bwaldvogel/liblinear/Linear.java b/src/main/java/de/bwaldvogel/liblinear/Linear.java index f3347fd..35a9766 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Linear.java +++ b/src/main/java/de/bwaldvogel/liblinear/Linear.java @@ -2252,7 +2252,7 @@ private static void checkProblemSize(int n, int nr_class) { private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) { SolverType solver_type = param.solverType; - int dual_solver_max_iter = 300; + int dual_solver_max_iter = param.dual_solver_max_iters; int iter; // upstream: (solver_type==L2R_L2LOSS_SVR || solver_type==L2R_L1LOSS_SVR_DUAL || solver_type==L2R_L2LOSS_SVR_DUAL) diff --git a/src/main/java/de/bwaldvogel/liblinear/Model.java b/src/main/java/de/bwaldvogel/liblinear/Model.java index 6e39859..2016e4a 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Model.java +++ b/src/main/java/de/bwaldvogel/liblinear/Model.java @@ -88,7 +88,7 @@ public double[] getFeatureWeights() { * @return true for logistic regression solvers */ public boolean isProbabilityModel() { - return solverType.isLogisticRegressionSolver(); + return true; //solverType.isLogisticRegressionSolver(); } /** diff --git a/src/main/java/de/bwaldvogel/liblinear/Parameter.java b/src/main/java/de/bwaldvogel/liblinear/Parameter.java index 6608592..e4ea595 100644 --- a/src/main/java/de/bwaldvogel/liblinear/Parameter.java +++ b/src/main/java/de/bwaldvogel/liblinear/Parameter.java @@ -19,6 +19,8 @@ public final class Parameter implements Cloneable { double eps; int max_iters = 1000; // maximal iterations + + int dual_solver_max_iters = 300; SolverType solverType; @@ -52,6 +54,14 @@ public Parameter(SolverType solver, double C, int max_iters, double eps) { setMaxIters(max_iters); } + public Parameter(SolverType solver, double C, int max_iters, int dual_solver_max_iters, double eps) { + setSolverType(solver); + setC(C); + setEps(eps); + setMaxIters(max_iters); + setDualSolverMaxIters(dual_solver_max_iters); + } + public Parameter(SolverType solverType, double C, double eps, double p) { setSolverType(solverType); setC(C); @@ -67,6 +77,15 @@ public Parameter(SolverType solverType, double C, double eps, int max_iters, dou setP(p); } + public Parameter(SolverType solverType, double C, double eps, int max_iters, int dual_solver_max_iters, double p) { + setSolverType(solverType); + setC(C); + setEps(eps); + setMaxIters(max_iters); + setDualSolverMaxIters(dual_solver_max_iters); + setP(p); + } + /** *
nr_weight, weight_label, and weight are used to change the penalty
* for some classes (If the weight for a class is not changed, it is
@@ -154,6 +173,16 @@ public int getMaxIters() {
return max_iters;
}
+ public void setDualSolverMaxIters(int iters) {
+ if (iters <= 0)
+ throw new IllegalArgumentException("dual solver max iters not be <= 0");
+ this.dual_solver_max_iters = iters;
+ }
+
+ public int getDualSolverMaxIters() {
+ return dual_solver_max_iters;
+ }
+
public void setSolverType(SolverType solverType) {
if (solverType == null)
throw new IllegalArgumentException("solver type must not be null");
@@ -222,7 +251,7 @@ public void setRandom(Random random) {
@Override
public Parameter clone() {
- Parameter clone = new Parameter(solverType, C, eps, max_iters, p);
+ Parameter clone = new Parameter(solverType, C, eps, max_iters, dual_solver_max_iters, p);
clone.weight = weight == null ? null : weight.clone();
clone.weightLabel = weightLabel == null ? null : weightLabel.clone();
clone.init_sol = init_sol;
diff --git a/src/main/java/de/bwaldvogel/liblinear/Predict.java b/src/main/java/de/bwaldvogel/liblinear/Predict.java
index 7d87e9b..9d64b65 100644
--- a/src/main/java/de/bwaldvogel/liblinear/Predict.java
+++ b/src/main/java/de/bwaldvogel/liblinear/Predict.java
@@ -135,15 +135,14 @@ static void doPredict(BufferedReader reader, Writer writer, Model model, boolean
}
}
- private static void exit_with_help() {
+ private static void print_help() {
System.out.printf("Usage: predict [options] test_file model_file output_file%n" //
+ "options:%n" //
+ "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only%n" //
+ "-q quiet mode (no outputs)%n");
- System.exit(1);
}
- public static void main(String[] argv) throws IOException {
+ public static int run(String[] argv) throws IOException {
// Note: This flag is _static_ in predict.c but it causes a thread-safety issue as reported in https://github.com/bwaldvogel/liblinear-java/issues/38
boolean flag_predict_probability = false;
int i;
@@ -158,7 +157,8 @@ public static void main(String[] argv) throws IOException {
try {
flag_predict_probability = (atoi(argv[i]) != 0);
} catch (NumberFormatException e) {
- exit_with_help();
+ print_help();
+ return 1;
}
break;
@@ -169,12 +169,13 @@ public static void main(String[] argv) throws IOException {
default:
System.err.printf("unknown option: -%d%n", argv[i - 1].charAt(1));
- exit_with_help();
- break;
+ print_help();
+ return 1;
}
}
if (i >= argv.length || argv.length <= i + 2) {
- exit_with_help();
+ print_help();
+ return 1;
}
try (FileInputStream in = new FileInputStream(argv[i]);
@@ -184,5 +185,11 @@ public static void main(String[] argv) throws IOException {
Model model = Linear.loadModel(Paths.get(argv[i + 1]));
doPredict(reader, writer, model, flag_predict_probability);
}
+
+ return 0;
+ }
+
+ public static void main(String[] argv) throws IOException {
+ System.exit(run(argv));
}
}
diff --git a/src/main/java/de/bwaldvogel/liblinear/Train.java b/src/main/java/de/bwaldvogel/liblinear/Train.java
index c80a8ba..52f8fae 100644
--- a/src/main/java/de/bwaldvogel/liblinear/Train.java
+++ b/src/main/java/de/bwaldvogel/liblinear/Train.java
@@ -21,7 +21,7 @@
public class Train {
public static void main(String[] args) throws IOException, InvalidInputDataException {
- new Train().run(args);
+ System.exit(new Train().run(args));
}
private double bias = 1;
@@ -92,7 +92,7 @@ private void do_cross_validation() {
}
}
- private void exit_with_help() {
+ private void print_help() {
System.out.printf("Usage: train [options] training_set_file [model_file]%n" //
+ "options:%n"
+ "-s type : set type of solver (default 1)%n"
@@ -136,7 +136,6 @@ private void exit_with_help() {
+ "-v n: n-fold cross validation mode%n"
+ "-C : find parameters (C for -s 0, 2 and C, p for -s 11)%n"
+ "-q : quiet mode (no outputs)%n");
- System.exit(1);
}
public Problem getProblem() {
@@ -151,7 +150,7 @@ public Parameter getParameter() {
return param;
}
- public void parse_command_line(String argv[]) {
+ public boolean parse_command_line(String argv[]) {
int i;
// eps: see setting below
@@ -164,8 +163,10 @@ public void parse_command_line(String argv[]) {
for (i = 0; i < argv.length; i++) {
if (argv[i].charAt(0) != '-')
break;
- if (++i >= argv.length)
- exit_with_help();
+ if (++i >= argv.length) {
+ print_help();
+ return false;
+ }
switch (argv[i - 1].charAt(1)) {
case 's':
param.solverType = SolverType.getById(atoi(argv[i]));
@@ -185,6 +186,15 @@ public void parse_command_line(String argv[]) {
case 'e':
param.setEps(atof(argv[i]));
break;
+ case 'm':
+ // ignore for compatibility with multicore liblinear
+ break;
+ case 'i':
+ param.setMaxIters(atoi(argv[i]));
+ break;
+ case 'd':
+ param.setDualSolverMaxIters(atoi(argv[i]));
+ break;
case 'B':
bias = atof(argv[i]);
break;
@@ -199,7 +209,8 @@ public void parse_command_line(String argv[]) {
nr_fold = atoi(argv[i]);
if (nr_fold < 2) {
System.err.println("n-fold cross validation: n must >= 2");
- exit_with_help();
+ print_help();
+ return false;
}
break;
case 'q':
@@ -216,14 +227,17 @@ public void parse_command_line(String argv[]) {
break;
default:
System.err.println("unknown option");
- exit_with_help();
+ print_help();
+ return false;
}
}
// determine filenames
- if (i >= argv.length)
- exit_with_help();
+ if (i >= argv.length) {
+ print_help();
+ return false;
+ }
inputFilename = argv[i];
@@ -244,7 +258,8 @@ public void parse_command_line(String argv[]) {
param.setSolverType(L2R_L2LOSS_SVC);
} else if (param.getSolverType() != L2R_LR && param.getSolverType() != L2R_L2LOSS_SVC && param.getSolverType() != L2R_L2LOSS_SVR) {
System.err.printf("Warm-start parameter search only available for -s 0, -s 2 and -s 11%n");
- exit_with_help();
+ print_help();
+ return false;
}
}
@@ -278,6 +293,8 @@ public void parse_command_line(String argv[]) {
throw new IllegalStateException("unknown solver type: " + param.solverType);
}
}
+
+ return true;
}
/**
@@ -447,8 +464,9 @@ private static Problem constructProblem(List