From 7ffa87a3fd88427297ac1176f8b44fcaf10dad66 Mon Sep 17 00:00:00 2001 From: Jonathon Hare Date: Wed, 17 Jul 2013 15:27:12 +0100 Subject: [PATCH 1/4] adding dense version of liblinear --- .../denseliblinear/ArraySorter.java | 91 + .../denseliblinear/DoubleArrayPointer.java | 27 + .../bwaldvogel/denseliblinear/Function.java | 13 + .../denseliblinear/IntArrayPointer.java | 27 + .../InvalidInputDataException.java | 57 + .../denseliblinear/L2R_L2_SvcFunction.java | 117 + .../denseliblinear/L2R_L2_SvrFunction.java | 67 + .../denseliblinear/L2R_LrFunction.java | 108 + .../de/bwaldvogel/denseliblinear/Linear.java | 1912 +++++++++++++++++ .../de/bwaldvogel/denseliblinear/Model.java | 178 ++ .../bwaldvogel/denseliblinear/Parameter.java | 120 ++ .../de/bwaldvogel/denseliblinear/Predict.java | 193 ++ .../de/bwaldvogel/denseliblinear/Problem.java | 62 + .../denseliblinear/SolverMCSVM_CS.java | 293 +++ .../bwaldvogel/denseliblinear/SolverType.java | 129 ++ .../de/bwaldvogel/denseliblinear/Train.java | 420 ++++ .../de/bwaldvogel/denseliblinear/Tron.java | 260 +++ .../denseliblinear/ArrayPointerTest.java | 63 + .../denseliblinear/ArraySorterTest.java | 58 + .../bwaldvogel/denseliblinear/LinearTest.java | 517 +++++ .../denseliblinear/ParameterTest.java | 127 ++ .../denseliblinear/PredictTest.java | 57 + .../bwaldvogel/denseliblinear/TrainTest.java | 210 ++ 23 files changed, 5106 insertions(+) create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Function.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/InvalidInputDataException.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvcFunction.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java create 
mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Linear.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Model.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Parameter.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Predict.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Problem.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/SolverType.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Train.java create mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Tron.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java create mode 100644 src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java diff --git a/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java b/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java new file mode 100644 index 0000000..06e8e50 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java @@ -0,0 +1,91 @@ +package de.bwaldvogel.denseliblinear; + + +final class ArraySorter { + + /** + *

Sorts the specified array of doubles into descending order.

+ * + * This code is borrowed from Sun's JDK 1.6.0.07 + */ + public static void reversedMergesort(double[] a) { + reversedMergesort(a, 0, a.length); + } + + private static void reversedMergesort(double x[], int off, int len) { + // Insertion sort on smallest arrays + if (len < 7) { + for (int i = off; i < len + off; i++) + for (int j = i; j > off && x[j - 1] < x[j]; j--) + swap(x, j, j - 1); + return; + } + + // Choose a partition element, v + int m = off + (len >> 1); // Small arrays, middle element + if (len > 7) { + int l = off; + int n = off + len - 1; + if (len > 40) { // Big arrays, pseudomedian of 9 + int s = len / 8; + l = med3(x, l, l + s, l + 2 * s); + m = med3(x, m - s, m, m + s); + n = med3(x, n - 2 * s, n - s, n); + } + m = med3(x, l, m, n); // Mid-size, med of 3 + } + double v = x[m]; + + // Establish Invariant: v* (v)* v* + int a = off, b = a, c = off + len - 1, d = c; + while (true) { + while (b <= c && x[b] >= v) { + if (x[b] == v) swap(x, a++, b); + b++; + } + while (c >= b && x[c] <= v) { + if (x[c] == v) swap(x, c, d--); + c--; + } + if (b > c) break; + swap(x, b++, c--); + } + + // Swap partition elements back to middle + int s, n = off + len; + s = Math.min(a - off, b - a); + vecswap(x, off, b - s, s); + s = Math.min(d - c, n - d - 1); + vecswap(x, b, n - s, s); + + // Recursively sort non-partition-elements + if ((s = b - a) > 1) reversedMergesort(x, off, s); + if ((s = d - c) > 1) reversedMergesort(x, n - s, s); + } + + /** + * Swaps x[a] with x[b]. + */ + private static void swap(double x[], int a, int b) { + double t = x[a]; + x[a] = x[b]; + x[b] = t; + } + + /** + * Swaps x[a .. (a+n-1)] with x[b .. (b+n-1)]. + */ + private static void vecswap(double x[], int a, int b, int n) { + for (int i = 0; i < n; i++, a++, b++) + swap(x, a, b); + } + + /** + * Returns the index of the median of the three indexed doubles. + */ + private static int med3(double x[], int a, int b, int c) { + return (x[a] < x[b] ? (x[b] < x[c] ? b : x[a] < x[c] ? 
c : a) : (x[b] > x[c] ? b : x[a] > x[c] ? c : a)); + } + + +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java b/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java new file mode 100644 index 0000000..1f6e1aa --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java @@ -0,0 +1,27 @@ +package de.bwaldvogel.denseliblinear; + + +final class DoubleArrayPointer { + + private final double[] _array; + private int _offset; + + + public void setOffset(int offset) { + if (offset < 0 || offset >= _array.length) throw new IllegalArgumentException("offset must be between 0 and the length of the array"); + _offset = offset; + } + + public DoubleArrayPointer( final double[] array, final int offset ) { + _array = array; + setOffset(offset); + } + + public double get(final int index) { + return _array[_offset + index]; + } + + public void set(final int index, final double value) { + _array[_offset + index] = value; + } +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Function.java b/src/main/java/de/bwaldvogel/denseliblinear/Function.java new file mode 100644 index 0000000..9a15c27 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/Function.java @@ -0,0 +1,13 @@ +package de.bwaldvogel.denseliblinear; + +// origin: tron.h +interface Function { + + double fun(double[] w); + + void grad(double[] w, double[] g); + + void Hv(double[] s, double[] Hs); + + int get_nr_variable(); +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java b/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java new file mode 100644 index 0000000..f8635fd --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java @@ -0,0 +1,27 @@ +package de.bwaldvogel.denseliblinear; + + +final class IntArrayPointer { + + private final int[] _array; + private int _offset; + + + public void setOffset(int offset) { + if (offset < 0 || offset >= _array.length) throw new 
/**
 * Signals invalid content in an input data file and records the file and the
 * line number that the problem was reported for.
 */
public class InvalidInputDataException extends Exception {

    private static final long serialVersionUID = 2945131732407207308L;

    private final File _file;

    private final int _line;

    /**
     * @param message description of the problem
     * @param file the file the problem was found in
     * @param line the line number the problem was found at
     */
    public InvalidInputDataException( String message, File file, int line ) {
        super(message);
        _file = file;
        _line = line;
    }

    /**
     * @param message description of the problem
     * @param filename path of the file the problem was found in
     * @param line the line number the problem was found at
     * @throws NullPointerException if {@code filename} is {@code null}
     */
    public InvalidInputDataException( String message, String filename, int line ) {
        this(message, new File(filename), line);
    }

    /**
     * @param message description of the problem
     * @param file the file the problem was found in
     * @param lineNr the line number the problem was found at
     * @param cause the underlying exception
     */
    public InvalidInputDataException( String message, File file, int lineNr, Exception cause ) {
        super(message, cause);
        _file = file;
        _line = lineNr;
    }

    /**
     * @param message description of the problem
     * @param filename path of the file the problem was found in
     * @param lineNr the line number the problem was found at
     * @param cause the underlying exception
     * @throws NullPointerException if {@code filename} is {@code null}
     */
    public InvalidInputDataException( String message, String filename, int lineNr, Exception cause ) {
        this(message, new File(filename), lineNr, cause);
    }

    /**
     * @return the file this exception was reported for
     */
    public File getFile() {
        return _file;
    }

    /**
     * Returns the path of the file (not just its name, so the method name
     * might be misleading).
     *
     * @return the path of the file, as given by {@link File#getPath()}
     * @deprecated use {@link #getFile()} instead
     */
    @Deprecated
    public String getFilename() {
        return _file.getPath();
    }

    /**
     * @return the line number this exception was reported for
     */
    public int getLine() {
        return _line;
    }

    /**
     * @return the standard exception description followed by
     *         {@code " (file:line)"}
     */
    @Override
    public String toString() {
        return super.toString() + " (" + _file + ":" + _line + ")";
    }

}
wa[i] = C[I[i]] * wa[i]; + + subXTv(wa, Hs); + for (i = 0; i < w_size; i++) + Hs[i] = s[i] + 2 * Hs[i]; + } + + protected void subXTv(double[] v, double[] XTv) { + int i; + final int w_size = get_nr_variable(); + + for (i = 0; i < w_size; i++) + XTv[i] = 0; + + for (i = 0; i < sizeI; i++) { + for (int j = 0; j < prob.x[I[i]].length; j++) { + XTv[j] += v[i] * prob.x[I[i]][j]; + } + } + } + + private void subXv(double[] v, double[] Xv) { + for (int i = 0; i < sizeI; i++) { + Xv[i] = 0; + + for (int j = 0; j < prob.x[I[i]].length; j++) { + Xv[i] += v[j] * prob.x[I[i]][j]; + } + } + } + + protected void Xv(double[] v, double[] Xv) { + for (int i = 0; i < prob.l; i++) { + Xv[i] = 0; + for (int j = 0; j < prob.x[i].length; j++) { + Xv[i] += v[j] * prob.x[i][j]; + } + } + } +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java b/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java new file mode 100644 index 0000000..d4de914 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java @@ -0,0 +1,67 @@ +package de.bwaldvogel.denseliblinear; + +/** + * @since 1.91 + */ +public class L2R_L2_SvrFunction extends L2R_L2_SvcFunction { + + private double p; + + public L2R_L2_SvrFunction( Problem prob, double[] C, double p ) { + super(prob, C); + this.p = p; + } + + @Override + public double fun(double[] w) { + double f = 0; + double[] y = prob.y; + int l = prob.l; + int w_size = get_nr_variable(); + double d; + + Xv(w, z); + + for (int i = 0; i < w_size; i++) + f += w[i] * w[i]; + f /= 2; + for (int i = 0; i < l; i++) { + d = z[i] - y[i]; + if (d < -p) + f += C[i] * (d + p) * (d + p); + else if (d > p) f += C[i] * (d - p) * (d - p); + } + + return f; + } + + @Override + public void grad(double[] w, double[] g) { + double[] y = prob.y; + int l = prob.l; + int w_size = get_nr_variable(); + + sizeI = 0; + for (int i = 0; i < l; i++) { + double d = z[i] - y[i]; + + // generate index set I + if (d < -p) { + z[sizeI] = 
C[i] * (d + p); + I[sizeI] = i; + sizeI++; + } else if (d > p) { + z[sizeI] = C[i] * (d - p); + I[sizeI] = i; + sizeI++; + } + + } + subXTv(z, g); + + for (int i = 0; i < w_size; i++) + g[i] = w[i] + 2 * g[i]; + + } + +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java b/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java new file mode 100644 index 0000000..faf68ce --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java @@ -0,0 +1,108 @@ +package de.bwaldvogel.denseliblinear; + +class L2R_LrFunction implements Function { + + private final double[] C; + private final double[] z; + private final double[] D; + private final Problem prob; + + public L2R_LrFunction(Problem prob, double[] C) { + final int l = prob.l; + + this.prob = prob; + + z = new double[l]; + D = new double[l]; + this.C = C; + } + + private void Xv(double[] v, double[] Xv) { + for (int i = 0; i < prob.l; i++) { + Xv[i] = 0; + for (int j = 0; j < prob.x[i].length; j++) { + Xv[i] += v[j] * prob.x[i][j]; + } + } + } + + private void XTv(double[] v, double[] XTv) { + final int l = prob.l; + final int w_size = get_nr_variable(); + final double[][] x = prob.x; + + for (int i = 0; i < w_size; i++) + XTv[i] = 0; + + for (int i = 0; i < l; i++) { + for (int j = 0; j < prob.x[i].length; j++) { + XTv[j] += v[i] * x[i][j]; + } + } + } + + @Override + public double fun(double[] w) { + int i; + double f = 0; + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + Xv(w, z); + + for (i = 0; i < w_size; i++) + f += w[i] * w[i]; + f /= 2.0; + for (i = 0; i < l; i++) { + final double yz = y[i] * z[i]; + if (yz >= 0) + f += C[i] * Math.log(1 + Math.exp(-yz)); + else + f += C[i] * (-yz + Math.log(1 + Math.exp(yz))); + } + + return (f); + } + + @Override + public void grad(double[] w, double[] g) { + int i; + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + for (i = 0; i < 
l; i++) { + z[i] = 1 / (1 + Math.exp(-y[i] * z[i])); + D[i] = z[i] * (1 - z[i]); + z[i] = C[i] * (z[i] - 1) * y[i]; + } + XTv(z, g); + + for (i = 0; i < w_size; i++) + g[i] = w[i] + g[i]; + } + + @Override + public void Hv(double[] s, double[] Hs) { + int i; + final int l = prob.l; + final int w_size = get_nr_variable(); + final double[] wa = new double[l]; + + Xv(s, wa); + for (i = 0; i < l; i++) + wa[i] = C[i] * D[i] * wa[i]; + + XTv(wa, Hs); + for (i = 0; i < w_size; i++) + Hs[i] = s[i] + Hs[i]; + // delete[] wa; + } + + @Override + public int get_nr_variable() { + return prob.n; + } + +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Linear.java b/src/main/java/de/bwaldvogel/denseliblinear/Linear.java new file mode 100644 index 0000000..f2f8029 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/Linear.java @@ -0,0 +1,1912 @@ +package de.bwaldvogel.denseliblinear; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.Closeable; +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintStream; +import java.io.Reader; +import java.io.Writer; +import java.nio.charset.Charset; +import java.util.Formatter; +import java.util.Locale; +import java.util.Random; +import java.util.regex.Pattern; + +/** + *

Java port of liblinear

+ * + *

+ * The usage should be pretty similar to the C version of liblinear. + *

+ *

+ * Please consider reading the README file of liblinear. + *

+ * + *

+ * The port was done by Benedikt Waldvogel (mail at bwaldvogel.de) + *

+ * + * @version 1.92 + */ +public class Linear { + + static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); + + static final Locale DEFAULT_LOCALE = Locale.ENGLISH; + + private static Object OUTPUT_MUTEX = new Object(); + private static PrintStream DEBUG_OUTPUT = System.out; + + private static final long DEFAULT_RANDOM_SEED = 0L; + static Random random = new Random(DEFAULT_RANDOM_SEED); + + /** + * @param target + * predicted classes + */ + public static void crossValidation(Problem prob, Parameter param, int nr_fold, double[] target) { + int i; + final int[] fold_start = new int[nr_fold + 1]; + final int l = prob.l; + final int[] perm = new int[l]; + + for (i = 0; i < l; i++) + perm[i] = i; + for (i = 0; i < l; i++) { + final int j = i + random.nextInt(l - i); + swap(perm, i, j); + } + for (i = 0; i <= nr_fold; i++) + fold_start[i] = i * l / nr_fold; + + for (i = 0; i < nr_fold; i++) { + final int begin = fold_start[i]; + final int end = fold_start[i + 1]; + int j, k; + final Problem subprob = new Problem(); + + subprob.bias = prob.bias; + subprob.n = prob.n; + subprob.l = l - (end - begin); + subprob.x = new double[subprob.l][]; + subprob.y = new double[subprob.l]; + + k = 0; + for (j = 0; j < begin; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + for (j = end; j < l; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + final Model submodel = train(subprob, param); + for (j = begin; j < end; j++) + target[perm[j]] = predict(submodel, prob.x[perm[j]]); + } + } + + /** used as complex return type */ + private static class GroupClassesReturn { + + final int[] count; + final int[] label; + final int nr_class; + final int[] start; + + GroupClassesReturn(int nr_class, int[] label, int[] start, int[] count) { + this.nr_class = nr_class; + this.label = label; + this.start = start; + this.count = count; + } + } + + private static GroupClassesReturn groupClasses(Problem prob, int[] 
perm) { + final int l = prob.l; + int max_nr_class = 16; + int nr_class = 0; + + int[] label = new int[max_nr_class]; + int[] count = new int[max_nr_class]; + final int[] data_label = new int[l]; + int i; + + for (i = 0; i < l; i++) { + final int this_label = (int) prob.y[i]; + int j; + for (j = 0; j < nr_class; j++) { + if (this_label == label[j]) { + ++count[j]; + break; + } + } + data_label[i] = j; + if (j == nr_class) { + if (nr_class == max_nr_class) { + max_nr_class *= 2; + label = copyOf(label, max_nr_class); + count = copyOf(count, max_nr_class); + } + label[nr_class] = this_label; + count[nr_class] = 1; + ++nr_class; + } + } + + final int[] start = new int[nr_class]; + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + for (i = 0; i < l; i++) { + perm[start[data_label[i]]] = i; + ++start[data_label[i]]; + } + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + + return new GroupClassesReturn(nr_class, label, start, count); + } + + static void info(String message) { + synchronized (OUTPUT_MUTEX) { + if (DEBUG_OUTPUT == null) + return; + DEBUG_OUTPUT.printf(message); + DEBUG_OUTPUT.flush(); + } + } + + static void info(String format, Object... 
args) { + synchronized (OUTPUT_MUTEX) { + if (DEBUG_OUTPUT == null) + return; + DEBUG_OUTPUT.printf(format, args); + DEBUG_OUTPUT.flush(); + } + } + + /** + * @param s + * the string to parse for the double value + * @throws IllegalArgumentException + * if s is empty or represents NaN or Infinity + * @throws NumberFormatException + * see {@link Double#parseDouble(String)} + */ + static double atof(String s) { + if (s == null || s.length() < 1) + throw new IllegalArgumentException("Can't convert empty string to integer"); + final double d = Double.parseDouble(s); + if (Double.isNaN(d) || Double.isInfinite(d)) { + throw new IllegalArgumentException("NaN or Infinity in input: " + s); + } + return (d); + } + + /** + * @param s + * the string to parse for the integer value + * @throws IllegalArgumentException + * if s is empty + * @throws NumberFormatException + * see {@link Integer#parseInt(String)} + */ + static int atoi(String s) throws NumberFormatException { + if (s == null || s.length() < 1) + throw new IllegalArgumentException("Can't convert empty string to integer"); + // Integer.parseInt doesn't accept '+' prefixed strings + if (s.charAt(0) == '+') + s = s.substring(1); + return Integer.parseInt(s); + } + + /** + * Java5 'backport' of Arrays.copyOf + */ + public static double[] copyOf(double[] original, int newLength) { + final double[] copy = new double[newLength]; + System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); + return copy; + } + + /** + * Java5 'backport' of Arrays.copyOf + */ + public static int[] copyOf(int[] original, int newLength) { + final int[] copy = new int[newLength]; + System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); + return copy; + } + + /** + * Loads the model from inputReader. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + * + *

+ * Note: The inputReader is NOT closed after reading or in case of an + * exception. + *

+ */ + public static Model loadModel(Reader inputReader) throws IOException { + final Model model = new Model(); + + model.label = null; + + final Pattern whitespace = Pattern.compile("\\s+"); + + BufferedReader reader = null; + if (inputReader instanceof BufferedReader) { + reader = (BufferedReader) inputReader; + } else { + reader = new BufferedReader(inputReader); + } + + String line = null; + while ((line = reader.readLine()) != null) { + final String[] split = whitespace.split(line); + if (split[0].equals("solver_type")) { + final SolverType solver = SolverType.valueOf(split[1]); + if (solver == null) { + throw new RuntimeException("unknown solver type"); + } + model.solverType = solver; + } else if (split[0].equals("nr_class")) { + model.nr_class = atoi(split[1]); + Integer.parseInt(split[1]); + } else if (split[0].equals("nr_feature")) { + model.nr_feature = atoi(split[1]); + } else if (split[0].equals("bias")) { + model.bias = atof(split[1]); + } else if (split[0].equals("w")) { + break; + } else if (split[0].equals("label")) { + model.label = new int[model.nr_class]; + for (int i = 0; i < model.nr_class; i++) { + model.label[i] = atoi(split[i + 1]); + } + } else { + throw new RuntimeException("unknown text in model file: [" + line + "]"); + } + } + + int w_size = model.nr_feature; + if (model.bias >= 0) + w_size++; + + int nr_w = model.nr_class; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + + model.w = new double[w_size * nr_w]; + final int[] buffer = new int[128]; + + for (int i = 0; i < w_size; i++) { + for (int j = 0; j < nr_w; j++) { + int b = 0; + while (true) { + final int ch = reader.read(); + if (ch == -1) { + throw new EOFException("unexpected EOF"); + } + if (ch == ' ') { + model.w[i * nr_w + j] = atof(new String(buffer, 0, b)); + break; + } else { + buffer[b++] = ch; + } + } + } + } + + return model; + } + + /** + * Loads the model from the file with ISO-8859-1 charset. 
It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + */ + public static Model loadModel(File modelFile) throws IOException { + final BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), + FILE_CHARSET)); + try { + return loadModel(inputReader); + } finally { + inputReader.close(); + } + } + + static void closeQuietly(Closeable c) { + if (c == null) + return; + try { + c.close(); + } catch (final Throwable t) { + } + } + + public static double predict(Model model, double[] x) { + final double[] dec_values = new double[model.nr_class]; + return predictValues(model, x, dec_values); + } + + /** + * @throws IllegalArgumentException + * if model is not probabilistic (see + * {@link Model#isProbabilityModel()}) + */ + public static double predictProbability(Model model, double[] x, double[] prob_estimates) + throws IllegalArgumentException + { + if (!model.isProbabilityModel()) { + final StringBuilder sb = new StringBuilder("probability output is only supported for logistic regression"); + sb.append(". This is currently only supported by the following solvers: "); + int i = 0; + for (final SolverType solverType : SolverType.values()) { + if (solverType.isLogisticRegressionSolver()) { + if (i++ > 0) { + sb.append(", "); + } + sb.append(solverType.name()); + } + } + throw new IllegalArgumentException(sb.toString()); + } + final int nr_class = model.nr_class; + int nr_w; + if (nr_class == 2) + nr_w = 1; + else + nr_w = nr_class; + + final double label = predictValues(model, x, prob_estimates); + for (int i = 0; i < nr_w; i++) + prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i])); + + if (nr_class == 2) // for binary classification + prob_estimates[1] = 1. 
- prob_estimates[0]; + else { + double sum = 0; + for (int i = 0; i < nr_class; i++) + sum += prob_estimates[i]; + + for (int i = 0; i < nr_class; i++) + prob_estimates[i] = prob_estimates[i] / sum; + } + + return label; + } + + public static double predictValues(Model model, double[] x, double[] dec_values) { + int n; + if (model.bias >= 0) + n = model.nr_feature + 1; + else + n = model.nr_feature; + + final double[] w = model.w; + + int nr_w; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + else + nr_w = model.nr_class; + + for (int i = 0; i < nr_w; i++) + dec_values[i] = 0; + + for (int idx = 0; idx < n; idx++) { + // the dimension of testing data may exceed that of training + for (int i = 0; i < nr_w; i++) { + dec_values[i] += w[idx * nr_w + i] * x[idx]; + } + } + + if (model.nr_class == 2) { + if (model.solverType.isSupportVectorRegression()) + return dec_values[0]; + else + return (dec_values[0] > 0) ? model.label[0] : model.label[1]; + } else { + int dec_max_idx = 0; + for (int i = 1; i < model.nr_class; i++) { + if (dec_values[i] > dec_values[dec_max_idx]) + dec_max_idx = i; + } + return model.label[dec_max_idx]; + } + } + + static void printf(Formatter formatter, String format, Object... args) throws IOException { + formatter.format(format, args); + final IOException ioException = formatter.ioException(); + if (ioException != null) + throw ioException; + } + + /** + * Writes the model to the modelOutput. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + * + *

+ * Note: The modelOutput is closed after writing or in case of an

+ */ + public static void saveModel(Writer modelOutput, Model model) throws IOException { + final int nr_feature = model.nr_feature; + int w_size = nr_feature; + if (model.bias >= 0) + w_size++; + + int nr_w = model.nr_class; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + + final Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE); + try { + printf(formatter, "solver_type %s\n", model.solverType.name()); + printf(formatter, "nr_class %d\n", model.nr_class); + + if (model.label != null) { + printf(formatter, "label"); + for (int i = 0; i < model.nr_class; i++) { + printf(formatter, " %d", model.label[i]); + } + printf(formatter, "\n"); + } + + printf(formatter, "nr_feature %d\n", nr_feature); + printf(formatter, "bias %.16g\n", model.bias); + + printf(formatter, "w\n"); + for (int i = 0; i < w_size; i++) { + for (int j = 0; j < nr_w; j++) { + final double value = model.w[i * nr_w + j]; + + /** + * this optimization is the reason for + * {@link Model#equals(double[], double[])} + */ + if (value == 0.0) { + printf(formatter, "%d ", 0); + } else { + printf(formatter, "%.16g ", value); + } + } + printf(formatter, "\n"); + } + + formatter.flush(); + final IOException ioException = formatter.ioException(); + if (ioException != null) + throw ioException; + } finally { + formatter.close(); + } + } + + /** + * Writes the model to the file with ISO-8859-1 charset. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. 
+ */ + public static void saveModel(File modelFile, Model model) throws IOException { + final BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), + FILE_CHARSET)); + saveModel(modelOutput, model); + } + + /* + * this method corresponds to the following define in the C version: #define + * GETI(i) (y[i]+1) + */ + private static int GETI(byte[] y, int i) { + return y[i] + 1; + } + + /** + * A coordinate descent algorithm for L1-loss and L2-loss SVM dual problems + * + *
+	 *  min_\alpha  0.5(\alpha^T (Q + D)\alpha) - e^T \alpha,
+	 *    s.t.      0 <= \alpha_i <= upper_bound_i,
+	 * 
+	 *  where Qij = yi yj xi^T xj and
+	 *  D is a diagonal matrix
+	 * 
+	 * In L1-SVM case:
+	 *     upper_bound_i = Cp if y_i = 1
+	 *      upper_bound_i = Cn if y_i = -1
+	 *      D_ii = 0
+	 * In L2-SVM case:
+	 *      upper_bound_i = INF
+	 *      D_ii = 1/(2*Cp) if y_i = 1
+	 *      D_ii = 1/(2*Cn) if y_i = -1
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Algorithm 3 of Hsieh et al., ICML 2008
+	 * 
+	 */
+	private static void solve_l2r_l1l2_svc(Problem prob, double[] w, double eps, double Cp, double Cn,
+			SolverType solver_type)
+	{
+		final int l = prob.l;
+		final int w_size = prob.n;
+		int i, s, iter = 0;
+		double C, d, G;
+		// QD[i] = diag term + xi^T xi; it is the divisor of the one-variable update below
+		final double[] QD = new double[l];
+		final int max_iter = 1000;
+		// permutation of instance ids; entries [0, active_size) form the active set
+		final int[] index = new int[l];
+		final double[] alpha = new double[l];
+		// labels reduced to +1 / -1 (any prob.y[i] <= 0 counts as -1)
+		final byte[] y = new byte[l];
+		int active_size = l;
+
+		// PG: projected gradient, for shrinking and stopping
+		double PG;
+		double PGmax_old = Double.POSITIVE_INFINITY;
+		double PGmin_old = Double.NEGATIVE_INFINITY;
+		double PGmax_new, PGmin_new;
+
+		// default solver_type: L2R_L2LOSS_SVC_DUAL
+		// NOTE(review): GETI is defined outside this section; the 3-element
+		// layout { Cn-based, 0, Cp-based } suggests it maps y = -1 / +1 to
+		// index 0 / 2 -- confirm against its definition.
+		final double diag[] = new double[] { 0.5 / Cn, 0, 0.5 / Cp };
+		final double upper_bound[] = new double[] { Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY };
+		if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) {
+			diag[0] = 0;
+			diag[2] = 0;
+			upper_bound[0] = Cn;
+			upper_bound[2] = Cp;
+		}
+
+		for (i = 0; i < l; i++) {
+			if (prob.y[i] > 0) {
+				y[i] = +1;
+			} else {
+				y[i] = -1;
+			}
+		}
+
+		// Initial alpha can be set here. Note that
+		// 0 <= alpha[i] <= upper_bound[GETI(i)]
+		for (i = 0; i < l; i++)
+			alpha[i] = 0;
+
+		for (i = 0; i < w_size; i++)
+			w[i] = 0;
+		// dense rows: every one of the w_size entries of prob.x[i] is visited
+		for (i = 0; i < l; i++) {
+			QD[i] = diag[GETI(y, i)];
+
+			for (int j = 0; j < w_size; j++) {
+				final double val = prob.x[i][j];
+				QD[i] += val * val;
+				w[j] += y[i] * alpha[i] * val;
+			}
+			index[i] = i;
+		}
+
+		while (iter < max_iter) {
+			PGmax_new = Double.NEGATIVE_INFINITY;
+			PGmin_new = Double.POSITIVE_INFINITY;
+
+			// random permutation of the active set
+			for (i = 0; i < active_size; i++) {
+				final int j = i + random.nextInt(active_size - i);
+				swap(index, i, j);
+			}
+
+			for (s = 0; s < active_size; s++) {
+				i = index[s];
+				G = 0;
+				final byte yi = y[i];
+
+				for (int j = 0; j < w_size; j++) {
+					G += w[j] * prob.x[i][j];
+				}
+				G = G * yi - 1;
+
+				C = upper_bound[GETI(y, i)];
+				G += alpha[i] * diag[GETI(y, i)];
+
+				PG = 0;
+				if (alpha[i] == 0) {
+					if (G > PGmax_old) {
+						// shrink: move instance behind the active range and
+						// re-examine the element swapped into slot s
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					} else if (G < 0) {
+						PG = G;
+					}
+				} else if (alpha[i] == C) {
+					if (G < PGmin_old) {
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					} else if (G > 0) {
+						PG = G;
+					}
+				} else {
+					PG = G;
+				}
+
+				PGmax_new = Math.max(PGmax_new, PG);
+				PGmin_new = Math.min(PGmin_new, PG);
+
+				if (Math.abs(PG) > 1.0e-12) {
+					final double alpha_old = alpha[i];
+					// one-variable step, clipped to [0, C]
+					alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C);
+					d = (alpha[i] - alpha_old) * yi;
+
+					for (int j = 0; j < w_size; j++) {
+						w[j] += d * prob.x[i][j];
+					}
+				}
+			}
+
+			iter++;
+			if (iter % 10 == 0)
+				info(".");
+
+			if (PGmax_new - PGmin_new <= eps) {
+				if (active_size == l)
+					break;
+				else {
+					// converged on the shrunken set only: reactivate all
+					// instances and run another pass
+					active_size = l;
+					info("*");
+					PGmax_old = Double.POSITIVE_INFINITY;
+					PGmin_old = Double.NEGATIVE_INFINITY;
+					continue;
+				}
+			}
+			PGmax_old = PGmax_new;
+			PGmin_old = PGmin_new;
+			if (PGmax_old <= 0)
+				PGmax_old = Double.POSITIVE_INFINITY;
+			if (PGmin_old >= 0)
+				PGmin_old = Double.NEGATIVE_INFINITY;
+		}
+
+		info("%noptimization finished, #iter = %d%n", iter);
+		if (iter >= max_iter)
+			info("%nWARNING: reaching max number of iterations%nUsing -s 2 may be faster (also see FAQ)%n%n");
+
+		// calculate objective value
+
+		double v = 0;
+		int nSV = 0;
+		for (i = 0; i < w_size; i++)
+			v += w[i] * w[i];
+		for (i = 0; i < l; i++) {
+			v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2);
+			if (alpha[i] > 0)
+				++nSV;
+		}
+		info("Objective value = %g%n", v / 2);
+		info("nSV = %d%n", nSV);
+	}
+
+	// To support weights for instances, use GETI(i) (i)
+	// As written this always returns 0, so lambda[] and upper_bound[] in the
+	// SVR solver are single-element arrays shared by every instance.
+	private static int GETI_SVR(int i) {
+		return 0;
+	}
+
+	/**
+	 * A coordinate descent algorithm for L1-loss and L2-loss epsilon-SVR dual
+	 * problem
+	 * 
+	 *  min_\beta  0.5\beta^T (Q + diag(lambda)) \beta + p \sum_{i=1}^l|\beta_i| - \sum_{i=1}^l yi\beta_i,
+	 *    s.t.      -upper_bound_i <= \beta_i <= upper_bound_i,
+	 * 
+	 *  where Qij = xi^T xj and D is a diagonal matrix
+	 * 
+	 * (The signs of the p and y terms above follow the objective value and
+	 * gradient that this method actually computes.)
+	 * 
+	 * In L1-SVM case:
+	 *         upper_bound_i = C
+	 *         lambda_i = 0
+	 * In L2-SVM case:
+	 *         upper_bound_i = INF
+	 *         lambda_i = 1/(2*C)
+	 * 
+	 * Given: x, y, p, C
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Algorithm 4 of Ho and Lin, 2012
+	 * 
+	 * @param prob  row-format problem; prob.x[i] is read as a dense row of prob.n values
+	 * @param w     output weight vector; its first prob.n entries are overwritten
+	 * @param param supplies C, p, eps and the solver type (L1- or L2-loss)
+	 */
+	private static void solve_l2r_l1l2_svr(Problem prob, double[] w, Parameter param) {
+		final int l = prob.l;
+		final double C = param.C;
+		final double p = param.p;
+		final int w_size = prob.n;
+		final double eps = param.eps;
+		int i, s, iter = 0;
+		final int max_iter = 1000;
+		int active_size = l;
+		// permutation of instance ids; entries [0, active_size) form the active set
+		final int[] index = new int[l];
+
+		double d, G, H;
+		double Gmax_old = Double.POSITIVE_INFINITY;
+		double Gmax_new, Gnorm1_new;
+		double Gnorm1_init = 0; // initialize to 0 to get rid of Eclipse
+								// warning/error
+		final double[] beta = new double[l];
+		// QD[i] = xi^T xi
+		final double[] QD = new double[l];
+		final double[] y = prob.y;
+
+		// L2R_L2LOSS_SVR_DUAL
+		final double[] lambda = new double[] { 0.5 / C };
+		final double[] upper_bound = new double[] { Double.POSITIVE_INFINITY };
+
+		if (param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL) {
+			lambda[0] = 0;
+			upper_bound[0] = C;
+		}
+
+		// Initial beta can be set here. Note that
+		// -upper_bound <= beta[i] <= upper_bound
+		for (i = 0; i < l; i++)
+			beta[i] = 0;
+
+		for (i = 0; i < w_size; i++)
+			w[i] = 0;
+		for (i = 0; i < l; i++) {
+			QD[i] = 0;
+			for (int j = 0; j < w_size; j++) {
+				final double val = prob.x[i][j];
+				QD[i] += val * val;
+				w[j] += beta[i] * val;
+			}
+
+			index[i] = i;
+		}
+
+		while (iter < max_iter) {
+			Gmax_new = 0;
+			Gnorm1_new = 0;
+
+			// random permutation of the active set
+			for (i = 0; i < active_size; i++) {
+				final int j = i + random.nextInt(active_size - i);
+				swap(index, i, j);
+			}
+
+			for (s = 0; s < active_size; s++) {
+				i = index[s];
+				G = -y[i] + lambda[GETI_SVR(i)] * beta[i];
+				H = QD[i] + lambda[GETI_SVR(i)];
+
+				for (int ind = 0; ind < w_size; ind++) {
+					final double val = prob.x[i][ind];
+					G += val * w[ind];
+				}
+
+				// Gp / Gn: G shifted by +p / -p, used for the beta > 0 and
+				// beta < 0 sides of the |beta| term respectively
+				final double Gp = G + p;
+				final double Gn = G - p;
+				double violation = 0;
+				if (beta[i] == 0) {
+					if (Gp < 0)
+						violation = -Gp;
+					else if (Gn > 0)
+						violation = Gn;
+					else if (Gp > Gmax_old && Gn < -Gmax_old) {
+						// shrink this instance out of the active set
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					}
+				} else if (beta[i] >= upper_bound[GETI_SVR(i)]) {
+					if (Gp > 0)
+						violation = Gp;
+					else if (Gp < -Gmax_old) {
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					}
+				} else if (beta[i] <= -upper_bound[GETI_SVR(i)]) {
+					if (Gn < 0)
+						violation = -Gn;
+					else if (Gn > Gmax_old) {
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					}
+				} else if (beta[i] > 0)
+					violation = Math.abs(Gp);
+				else
+					violation = Math.abs(Gn);
+
+				Gmax_new = Math.max(Gmax_new, violation);
+				Gnorm1_new += violation;
+
+				// obtain Newton direction d
+				if (Gp < H * beta[i])
+					d = -Gp / H;
+				else if (Gn > H * beta[i])
+					d = -Gn / H;
+				else
+					d = -beta[i];
+
+				if (Math.abs(d) < 1.0e-12)
+					continue;
+
+				final double beta_old = beta[i];
+				// clip the step to the box, then recompute the step actually taken
+				beta[i] = Math.min(Math.max(beta[i] + d, -upper_bound[GETI_SVR(i)]), upper_bound[GETI_SVR(i)]);
+				d = beta[i] - beta_old;
+
+				if (d != 0) {
+					for (int j = 0; j < w_size; j++) {
+						w[j] += d * prob.x[i][j];
+					}
+				}
+			}
+
+			if (iter == 0)
+				Gnorm1_init = Gnorm1_new;
+			iter++;
+			if (iter % 10 == 0)
+				info(".");
+
+			// stop when the summed violation dropped below eps times its
+			// value from the first pass
+			if (Gnorm1_new <= eps * Gnorm1_init) {
+				if (active_size == l)
+					break;
+				else {
+					active_size = l;
+					info("*");
+					Gmax_old = Double.POSITIVE_INFINITY;
+					continue;
+				}
+			}
+
+			Gmax_old = Gmax_new;
+		}
+
+		info("%noptimization finished, #iter = %d%n", iter);
+		if (iter >= max_iter)
+			info("%nWARNING: reaching max number of iterations%nUsing -s 11 may be faster%n%n");
+
+		// calculate objective value
+		double v = 0;
+		int nSV = 0;
+		for (i = 0; i < w_size; i++)
+			v += w[i] * w[i];
+		v = 0.5 * v;
+		for (i = 0; i < l; i++) {
+			v += p * Math.abs(beta[i]) - y[i] * beta[i] + 0.5 * lambda[GETI_SVR(i)] * beta[i] * beta[i];
+			if (beta[i] != 0)
+				nSV++;
+		}
+
+		info("Objective value = %g%n", v);
+		info("nSV = %d%n", nSV);
+	}
+
+	/**
+	 * A coordinate descent algorithm for the dual of L2-regularized logistic
+	 * regression problems
+	 * 
+	 * 
+	 *  min_\alpha  0.5(\alpha^T Q \alpha) + \sum \alpha_i log (\alpha_i) + (upper_bound_i - \alpha_i) log (upper_bound_i - \alpha_i) ,
+	 *     s.t.      0 <= \alpha_i <= upper_bound_i,
+	 * 
+	 *  where Qij = yi yj xi^T xj and
+	 *  upper_bound_i = Cp if y_i = 1
+	 *  upper_bound_i = Cn if y_i = -1
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Algorithm 5 of Yu et al., MLJ 2010
+	 * 
+	 * 
+	 * Dense variant: each prob.x[i] is read as a row of prob.n values.
+	 * 
+	 * @param prob row-format problem
+	 * @param w    output weight vector; its first prob.n entries are overwritten
+	 * @param eps  stopping tolerance on the largest sub-problem gradient of a pass
+	 * @param Cp   upper bound used for instances with prob.y[i] > 0
+	 * @param Cn   upper bound used for all other instances
+	 * @since 1.7
+	 */
+	private static void solve_l2r_lr_dual(Problem prob, double w[], double eps, double Cp, double Cn) {
+		final int l = prob.l;
+		final int w_size = prob.n;
+		int i, s, iter = 0;
+		final double xTx[] = new double[l];
+		final int max_iter = 1000;
+		final int index[] = new int[l];
+		final double alpha[] = new double[2 * l]; // store alpha and C - alpha
+		final byte y[] = new byte[l];
+		final int max_inner_iter = 100; // for inner Newton
+		// inner tolerance; tightened below when a pass needed few Newton steps
+		double innereps = 1e-2;
+		final double innereps_min = Math.min(1e-8, eps);
+		final double upper_bound[] = new double[] { Cn, 0, Cp };
+
+		for (i = 0; i < l; i++) {
+			if (prob.y[i] > 0) {
+				y[i] = +1;
+			} else {
+				y[i] = -1;
+			}
+		}
+
+		// Initial alpha can be set here. Note that
+		// 0 < alpha[i] < upper_bound[GETI(i)]
+		// alpha[2*i] + alpha[2*i+1] = upper_bound[GETI(i)]
+		for (i = 0; i < l; i++) {
+			alpha[2 * i] = Math.min(0.001 * upper_bound[GETI(y, i)], 1e-8);
+			alpha[2 * i + 1] = upper_bound[GETI(y, i)] - alpha[2 * i];
+		}
+
+		for (i = 0; i < w_size; i++)
+			w[i] = 0;
+		for (i = 0; i < l; i++) {
+			xTx[i] = 0;
+			for (int j = 0; j < w_size; j++) {
+				final double val = prob.x[i][j];
+				xTx[i] += val * val;
+				w[j] += y[i] * alpha[2 * i] * val;
+			}
+			index[i] = i;
+		}
+
+		while (iter < max_iter) {
+			// random permutation of all instances (this solver does no shrinking)
+			for (i = 0; i < l; i++) {
+				final int j = i + random.nextInt(l - i);
+				swap(index, i, j);
+			}
+			int newton_iter = 0;
+			double Gmax = 0;
+			for (s = 0; s < l; s++) {
+				i = index[s];
+				final byte yi = y[i];
+				final double C = upper_bound[GETI(y, i)];
+				double ywTx = 0;
+				final double xisq = xTx[i];
+				for (int j = 0; j < w_size; j++) {
+					ywTx += w[j] * prob.x[i][j];
+				}
+				ywTx *= y[i];
+				final double a = xisq, b = ywTx;
+
+				// Decide to minimize g_1(z) or g_2(z)
+				int ind1 = 2 * i, ind2 = 2 * i + 1, sign = 1;
+				if (0.5 * a * (alpha[ind2] - alpha[ind1]) + b < 0) {
+					ind1 = 2 * i + 1;
+					ind2 = 2 * i;
+					sign = -1;
+				}
+
+				// g_t(z) = z*log(z) + (C-z)*log(C-z) + 0.5a(z-alpha_old)^2 +
+				// sign*b(z-alpha_old)
+				final double alpha_old = alpha[ind1];
+				double z = alpha_old;
+				// start from a smaller z when alpha_old lies in the upper half of (0, C)
+				if (C - z < 0.5 * C)
+					z = 0.1 * z;
+				double gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
+				Gmax = Math.max(Gmax, Math.abs(gp));
+
+				// Newton method on the sub-problem
+				final double eta = 0.1; // xi in the paper
+				int inner_iter = 0;
+				while (inner_iter <= max_inner_iter) {
+					if (Math.abs(gp) < innereps)
+						break;
+					final double gpp = a + C / (C - z) / z;
+					final double tmpz = z - gp / gpp;
+					if (tmpz <= 0)
+						z *= eta;
+					else
+						// tmpz in (0, C)
+						z = tmpz;
+					gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
+					newton_iter++;
+					inner_iter++;
+				}
+
+				if (inner_iter > 0) // update w
+				{
+					alpha[ind1] = z;
+					alpha[ind2] = C - z;
+					for (int j = 0; j < w_size; j++) {
+						w[j] += sign * (z - alpha_old) * yi * prob.x[i][j];
+					}
+				}
+			}
+
+			iter++;
+			if (iter % 10 == 0)
+				info(".");
+
+			if (Gmax < eps)
+				break;
+
+			if (newton_iter <= l / 10) {
+				innereps = Math.max(innereps_min, 0.1 * innereps);
+			}
+
+		}
+
+		info("%noptimization finished, #iter = %d%n", iter);
+		if (iter >= max_iter)
+			info("%nWARNING: reaching max number of iterations%nUsing -s 0 may be faster (also see FAQ)%n%n");
+
+		// calculate objective value
+
+		double v = 0;
+		for (i = 0; i < w_size; i++)
+			v += w[i] * w[i];
+		v *= 0.5;
+		for (i = 0; i < l; i++)
+			v += alpha[2 * i] * Math.log(alpha[2 * i]) + alpha[2 * i + 1] * Math.log(alpha[2 * i + 1])
+					- upper_bound[GETI(y, i)]
+					* Math.log(upper_bound[GETI(y, i)]);
+		info("Objective value = %g%n", v);
+	}
+
+	/**
+	 * A coordinate descent algorithm for L1-regularized L2-loss support vector
+	 * classification
+	 * 
+	 * 
+	 *  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Yuan et al. (2010) and appendix of LIBLINEAR paper, Fan et al. (2008)
+	 * 
+	 * 
+	 * Dense variant: prob_col must be in column format, i.e. prob_col.x[j][ind]
+	 * is the value of feature j for instance ind, so every column holds one
+	 * entry per instance. The columns are temporarily multiplied by the +1/-1
+	 * label of each instance while the solver runs and are restored before
+	 * returning.
+	 * 
+	 * @param prob_col column-format problem (modified in place during the run, restored at the end)
+	 * @param w        output weight vector; its first prob_col.n entries are overwritten
+	 * @param eps      stopping tolerance
+	 * @param Cp       penalty for instances with prob_col.y[i] > 0
+	 * @param Cn       penalty for all other instances
+	 * @since 1.5
+	 */
+	private static void solve_l1r_l2_svc(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
+		final int l = prob_col.l;
+		final int w_size = prob_col.n;
+		int j, s, iter = 0;
+		final int max_iter = 1000;
+		int active_size = w_size;
+		final int max_num_linesearch = 20;
+
+		final double sigma = 0.01;
+		double d, G_loss, G, H;
+		double Gmax_old = Double.POSITIVE_INFINITY;
+		double Gmax_new, Gnorm1_new;
+		double Gnorm1_init = 0; // eclipse moans this variable might not be
+								// initialized
+		double d_old, d_diff;
+		double loss_old = 0; // eclipse moans this variable might not be
+								// initialized
+		double loss_new;
+		double appxcond, cond;
+
+		// permutation of feature ids; entries [0, active_size) form the active set
+		final int[] index = new int[w_size];
+		final byte[] y = new byte[l];
+		final double[] b = new double[l]; // b = 1-ywTx
+		final double[] xj_sq = new double[w_size];
+
+		final double[] C = new double[] { Cn, 0, Cp };
+
+		// Initial w can be set here.
+		for (j = 0; j < w_size; j++)
+			w[j] = 0;
+
+		for (j = 0; j < l; j++) {
+			b[j] = 1;
+			if (prob_col.y[j] > 0)
+				y[j] = 1;
+			else
+				y[j] = -1;
+		}
+		for (j = 0; j < w_size; j++) {
+			index[j] = j;
+			xj_sq[j] = 0;
+			// ind runs over the instances stored in column j (not over the
+			// features): y[] and b[] have one entry per instance
+			for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+				prob_col.x[j][ind] = prob_col.x[j][ind] * y[ind]; // x->value
+																	// stores
+																	// yi*xij
+				final double val = prob_col.x[j][ind];
+				b[ind] -= w[j] * val;
+
+				xj_sq[j] += C[GETI(y, ind)] * val * val;
+			}
+		}
+
+		while (iter < max_iter) {
+			Gmax_new = 0;
+			Gnorm1_new = 0;
+
+			// random permutation of the active features
+			for (j = 0; j < active_size; j++) {
+				final int i = j + random.nextInt(active_size - j);
+				swap(index, i, j);
+			}
+
+			for (s = 0; s < active_size; s++) {
+				j = index[s];
+				G_loss = 0;
+				H = 0;
+
+				// gradient / Hessian of the loss w.r.t. w[j]: only instances
+				// with a positive margin violation b[ind] contribute
+				for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+					if (b[ind] > 0) {
+						final double val = prob_col.x[j][ind];
+						final double tmp = C[GETI(y, ind)] * val;
+						G_loss -= tmp * b[ind];
+						H += tmp * val;
+					}
+				}
+				G_loss *= 2;
+
+				G = G_loss;
+				H *= 2;
+				H = Math.max(H, 1e-12);
+
+				final double Gp = G + 1;
+				final double Gn = G - 1;
+				double violation = 0;
+				if (w[j] == 0) {
+					if (Gp < 0)
+						violation = -Gp;
+					else if (Gn > 0)
+						violation = Gn;
+					else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
+						// shrink feature j out of the active set
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					}
+				} else if (w[j] > 0)
+					violation = Math.abs(Gp);
+				else
+					violation = Math.abs(Gn);
+
+				Gmax_new = Math.max(Gmax_new, violation);
+				Gnorm1_new += violation;
+
+				// obtain Newton direction d
+				if (Gp < H * w[j])
+					d = -Gp / H;
+				else if (Gn > H * w[j])
+					d = -Gn / H;
+				else
+					d = -w[j];
+
+				if (Math.abs(d) < 1.0e-12)
+					continue;
+
+				double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d;
+				d_old = 0;
+				int num_linesearch;
+				for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
+					d_diff = d_old - d;
+					cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta;
+
+					// cheap sufficient condition first; fall through to the
+					// exact loss difference only when it fails
+					appxcond = xj_sq[j] * d * d + G_loss * d + cond;
+					if (appxcond <= 0) {
+						for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+							b[ind] += d_diff * prob_col.x[j][ind];
+						}
+						break;
+					}
+
+					if (num_linesearch == 0) {
+						loss_old = 0;
+						loss_new = 0;
+						for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+							if (b[ind] > 0) {
+								loss_old += C[GETI(y, ind)] * b[ind] * b[ind];
+							}
+							final double b_new = b[ind] + d_diff * prob_col.x[j][ind];
+							b[ind] = b_new;
+							if (b_new > 0) {
+								loss_new += C[GETI(y, ind)] * b_new * b_new;
+							}
+						}
+					} else {
+						loss_new = 0;
+						for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+							final double b_new = b[ind] + d_diff * prob_col.x[j][ind];
+							b[ind] = b_new;
+							if (b_new > 0) {
+								loss_new += C[GETI(y, ind)] * b_new * b_new;
+							}
+						}
+					}
+
+					cond = cond + loss_new - loss_old;
+					if (cond <= 0)
+						break;
+					else {
+						d_old = d;
+						d *= 0.5;
+						delta *= 0.5;
+					}
+				}
+
+				w[j] += d;
+
+				// recompute b[] if line search takes too many steps
+				if (num_linesearch >= max_num_linesearch) {
+					info("#");
+					for (int i = 0; i < l; i++)
+						b[i] = 1;
+
+					// b = 1 - sum_i w[i] * (y .* column i): each feature i must
+					// contribute its own column, not the column of feature j
+					for (int i = 0; i < w_size; i++) {
+						if (w[i] == 0)
+							continue;
+						for (int ind = 0; ind < prob_col.x[i].length; ind++) {
+							b[ind] -= w[i] * prob_col.x[i][ind];
+						}
+					}
+				}
+			}
+
+			if (iter == 0) {
+				Gnorm1_init = Gnorm1_new;
+			}
+			iter++;
+			if (iter % 10 == 0)
+				info(".");
+
+			// NOTE(review): this compares the max violation (Gmax_new) with
+			// the initial 1-norm (Gnorm1_init), whereas the other solvers in
+			// this file compare Gnorm1_new with Gnorm1_init -- confirm intent.
+			if (Gmax_new <= eps * Gnorm1_init) {
+				if (active_size == w_size)
+					break;
+				else {
+					active_size = w_size;
+					info("*");
+					Gmax_old = Double.POSITIVE_INFINITY;
+					continue;
+				}
+			}
+
+			Gmax_old = Gmax_new;
+		}
+
+		info("%noptimization finished, #iter = %d%n", iter);
+		if (iter >= max_iter)
+			info("%nWARNING: reaching max number of iterations%n");
+
+		// calculate objective value
+
+		double v = 0;
+		int nnz = 0;
+		for (j = 0; j < w_size; j++) {
+			// undo the label scaling with the same +1/-1 factor that was
+			// applied at the start (y[ind] * y[ind] == 1)
+			for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+				prob_col.x[j][ind] = prob_col.x[j][ind] * y[ind]; // restore
+																	// x->value
+			}
+			if (w[j] != 0) {
+				v += Math.abs(w[j]);
+				nnz++;
+			}
+		}
+		for (j = 0; j < l; j++)
+			if (b[j] > 0)
+				v += C[GETI(y, j)] * b[j] * b[j];
+
+		info("Objective value = %g%n", v);
+		info("#nonzeros/#features = %d/%d%n", nnz, w_size);
+	}
+
+	/**
+	 * A coordinate descent algorithm for L1-regularized logistic regression
+	 * problems
+	 * 
+	 * 
+	 *  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Yuan et al. (2011) and appendix of LIBLINEAR paper, Fan et al. (2008)
+	 * 
+	 * 
+	 * Dense variant: prob_col must be in column format, i.e. prob_col.x[j][ind]
+	 * is the value of feature j for instance ind.
+	 * 
+	 * @param prob_col column-format problem (read only)
+	 * @param w        output weight vector; its first prob_col.n entries are overwritten
+	 * @param eps      stopping tolerance of the outer Newton iteration
+	 * @param Cp       penalty for instances with prob_col.y[i] > 0
+	 * @param Cn       penalty for all other instances
+	 * @since 1.5
+	 */
+	private static void solve_l1r_lr(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
+		final int l = prob_col.l;
+		final int w_size = prob_col.n;
+		int j, s, newton_iter = 0, iter = 0;
+		final int max_newton_iter = 100;
+		final int max_iter = 1000;
+		final int max_num_linesearch = 20;
+		int active_size;
+		int QP_active_size;
+
+		final double nu = 1e-12;
+		double inner_eps = 1;
+		final double sigma = 0.01;
+		double w_norm, w_norm_new;
+		double z, G, H;
+		double Gnorm1_init = 0; // eclipse moans this variable might not be
+								// initialized
+		double Gmax_old = Double.POSITIVE_INFINITY;
+		double Gmax_new, Gnorm1_new;
+		double QP_Gmax_old = Double.POSITIVE_INFINITY;
+		double QP_Gmax_new, QP_Gnorm1_new;
+		double delta, negsum_xTd, cond;
+
+		// permutation of feature ids shared by the outer and the inner active set
+		final int[] index = new int[w_size];
+		final byte[] y = new byte[l];
+		final double[] Hdiag = new double[w_size];
+		final double[] Grad = new double[w_size];
+		// wpd = w + d, the iterate of the inner quadratic sub-problem
+		final double[] wpd = new double[w_size];
+		final double[] xjneg_sum = new double[w_size];
+		final double[] xTd = new double[l];
+		final double[] exp_wTx = new double[l];
+		final double[] exp_wTx_new = new double[l];
+		final double[] tau = new double[l];
+		final double[] D = new double[l];
+
+		final double[] C = { Cn, 0, Cp };
+
+		// Initial w can be set here.
+		for (j = 0; j < w_size; j++)
+			w[j] = 0;
+
+		for (j = 0; j < l; j++) {
+			if (prob_col.y[j] > 0)
+				y[j] = 1;
+			else
+				y[j] = -1;
+
+			exp_wTx[j] = 0;
+		}
+
+		w_norm = 0;
+		for (j = 0; j < w_size; j++) {
+			w_norm += Math.abs(w[j]);
+			wpd[j] = w[j];
+			index[j] = j;
+			xjneg_sum[j] = 0;
+			for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+				final double val = prob_col.x[j][ind];
+				exp_wTx[ind] += w[j] * val;
+				if (y[ind] == -1) {
+					xjneg_sum[j] += C[GETI(y, ind)] * val;
+				}
+			}
+		}
+		for (j = 0; j < l; j++) {
+			exp_wTx[j] = Math.exp(exp_wTx[j]);
+			final double tau_tmp = 1 / (1 + exp_wTx[j]);
+			tau[j] = C[GETI(y, j)] * tau_tmp;
+			D[j] = C[GETI(y, j)] * exp_wTx[j] * tau_tmp * tau_tmp;
+		}
+
+		while (newton_iter < max_newton_iter) {
+			Gmax_new = 0;
+			Gnorm1_new = 0;
+			active_size = w_size;
+
+			for (s = 0; s < active_size; s++) {
+				j = index[s];
+				Hdiag[j] = nu;
+				Grad[j] = 0;
+
+				double tmp = 0;
+				for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+					Hdiag[j] += prob_col.x[j][ind] * prob_col.x[j][ind] * D[ind];
+					tmp += prob_col.x[j][ind] * tau[ind];
+				}
+				Grad[j] = -tmp + xjneg_sum[j];
+
+				final double Gp = Grad[j] + 1;
+				final double Gn = Grad[j] - 1;
+				double violation = 0;
+				if (w[j] == 0) {
+					if (Gp < 0)
+						violation = -Gp;
+					else if (Gn > 0)
+						violation = Gn;
+					// outer-level shrinking
+					else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
+						active_size--;
+						swap(index, s, active_size);
+						s--;
+						continue;
+					}
+				} else if (w[j] > 0)
+					violation = Math.abs(Gp);
+				else
+					violation = Math.abs(Gn);
+
+				Gmax_new = Math.max(Gmax_new, violation);
+				Gnorm1_new += violation;
+			}
+
+			if (newton_iter == 0)
+				Gnorm1_init = Gnorm1_new;
+
+			if (Gnorm1_new <= eps * Gnorm1_init)
+				break;
+
+			iter = 0;
+			QP_Gmax_old = Double.POSITIVE_INFINITY;
+			QP_active_size = active_size;
+
+			for (int i = 0; i < l; i++)
+				xTd[i] = 0;
+
+			// optimize QP over wpd
+			while (iter < max_iter) {
+				QP_Gmax_new = 0;
+				QP_Gnorm1_new = 0;
+
+				// random permutation of the inner active set: position j is
+				// swapped with a uniformly chosen position in [j, QP_active_size),
+				// the same scheme every other solver in this file uses
+				for (j = 0; j < QP_active_size; j++) {
+					final int i = j + random.nextInt(QP_active_size - j);
+					swap(index, i, j);
+				}
+
+				for (s = 0; s < QP_active_size; s++) {
+					j = index[s];
+					H = Hdiag[j];
+
+					G = Grad[j] + (wpd[j] - w[j]) * nu;
+					for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+						G += prob_col.x[j][ind] * D[ind] * xTd[ind];
+					}
+
+					final double Gp = G + 1;
+					final double Gn = G - 1;
+					double violation = 0;
+					if (wpd[j] == 0) {
+						if (Gp < 0)
+							violation = -Gp;
+						else if (Gn > 0)
+							violation = Gn;
+						// inner-level shrinking
+						else if (Gp > QP_Gmax_old / l && Gn < -QP_Gmax_old / l) {
+							QP_active_size--;
+							swap(index, s, QP_active_size);
+							s--;
+							continue;
+						}
+					} else if (wpd[j] > 0)
+						violation = Math.abs(Gp);
+					else
+						violation = Math.abs(Gn);
+
+					QP_Gmax_new = Math.max(QP_Gmax_new, violation);
+					QP_Gnorm1_new += violation;
+
+					// obtain solution of one-variable problem
+					if (Gp < H * wpd[j])
+						z = -Gp / H;
+					else if (Gn > H * wpd[j])
+						z = -Gn / H;
+					else
+						z = -wpd[j];
+
+					if (Math.abs(z) < 1.0e-12)
+						continue;
+					z = Math.min(Math.max(z, -10.0), 10.0);
+
+					wpd[j] += z;
+
+					for (int ind = 0; ind < prob_col.x[j].length; ind++) {
+						xTd[ind] += prob_col.x[j][ind] * z;
+					}
+				}
+
+				iter++;
+
+				if (QP_Gnorm1_new <= inner_eps * Gnorm1_init) {
+					// inner stopping
+					if (QP_active_size == active_size)
+						break;
+					// active set reactivation
+					else {
+						QP_active_size = active_size;
+						QP_Gmax_old = Double.POSITIVE_INFINITY;
+						continue;
+					}
+				}
+
+				QP_Gmax_old = QP_Gmax_new;
+			}
+
+			if (iter >= max_iter)
+				info("WARNING: reaching max number of inner iterations%n");
+
+			delta = 0;
+			w_norm_new = 0;
+			for (j = 0; j < w_size; j++) {
+				delta += Grad[j] * (wpd[j] - w[j]);
+				if (wpd[j] != 0)
+					w_norm_new += Math.abs(wpd[j]);
+			}
+			delta += (w_norm_new - w_norm);
+
+			negsum_xTd = 0;
+			for (int i = 0; i < l; i++)
+				if (y[i] == -1)
+					negsum_xTd += C[GETI(y, i)] * xTd[i];
+
+			// backtracking line search between w and wpd
+			int num_linesearch;
+			for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
+				cond = w_norm_new - w_norm + negsum_xTd - sigma * delta;
+
+				for (int i = 0; i < l; i++) {
+					final double exp_xTd = Math.exp(xTd[i]);
+					exp_wTx_new[i] = exp_wTx[i] * exp_xTd;
+					cond += C[GETI(y, i)] * Math.log((1 + exp_wTx_new[i]) / (exp_xTd + exp_wTx_new[i]));
+				}
+
+				if (cond <= 0) {
+					w_norm = w_norm_new;
+					for (j = 0; j < w_size; j++)
+						w[j] = wpd[j];
+					for (int i = 0; i < l; i++) {
+						exp_wTx[i] = exp_wTx_new[i];
+						final double tau_tmp = 1 / (1 + exp_wTx[i]);
+						tau[i] = C[GETI(y, i)] * tau_tmp;
+						D[i] = C[GETI(y, i)] * exp_wTx[i] * tau_tmp * tau_tmp;
+					}
+					break;
+				} else {
+					// halve the step: move wpd half way back towards w
+					w_norm_new = 0;
+					for (j = 0; j < w_size; j++) {
+						wpd[j] = (w[j] + wpd[j]) * 0.5;
+						if (wpd[j] != 0)
+							w_norm_new += Math.abs(wpd[j]);
+					}
+					delta *= 0.5;
+					negsum_xTd *= 0.5;
+					for (int i = 0; i < l; i++)
+						xTd[i] *= 0.5;
+				}
+			}
+
+			// Recompute some info due to too many line search steps
+			if (num_linesearch >= max_num_linesearch) {
+				for (int i = 0; i < l; i++)
+					exp_wTx[i] = 0;
+
+				// iterate over column i itself; j equals w_size at this point
+				// and must not be used as a column index
+				for (int i = 0; i < w_size; i++) {
+					if (w[i] == 0)
+						continue;
+					for (int ind = 0; ind < prob_col.x[i].length; ind++) {
+						exp_wTx[ind] += w[i] * prob_col.x[i][ind];
+					}
+				}
+
+				for (int i = 0; i < l; i++)
+					exp_wTx[i] = Math.exp(exp_wTx[i]);
+			}
+
+			// tighten the inner tolerance once the QP converged in a single cycle
+			if (iter == 1)
+				inner_eps *= 0.25;
+
+			newton_iter++;
+			Gmax_old = Gmax_new;
+
+			info("iter %3d  #CD cycles %d%n", newton_iter, iter);
+		}
+
+		info("=========================%n");
+		info("optimization finished, #iter = %d%n", newton_iter);
+		if (newton_iter >= max_newton_iter)
+			info("WARNING: reaching max number of iterations%n");
+
+		// calculate objective value
+
+		double v = 0;
+		int nnz = 0;
+		for (j = 0; j < w_size; j++)
+			if (w[j] != 0) {
+				v += Math.abs(w[j]);
+				nnz++;
+			}
+		for (j = 0; j < l; j++)
+			if (y[j] == 1)
+				v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]);
+			else
+				v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]);
+
+		info("Objective value = %g%n", v);
+		info("#nonzeros/#features = %d/%d%n", nnz, w_size);
+	}
+
+	// transpose matrix X from row format to column format
+	static Problem
transpose(Problem prob) {
+		final int l = prob.l;
+		final int n = prob.n;
+		final Problem prob_col = new Problem();
+		prob_col.l = l;
+		prob_col.n = n;
+		prob_col.y = new double[l];
+		// column format: one array of l instance values per feature
+		prob_col.x = new double[n][];
+
+		for (int i = 0; i < l; i++)
+			prob_col.y[i] = prob.y[i];
+
+		for (int i = 0; i < n; i++) {
+			prob_col.x[i] = new double[l];
+		}
+
+		for (int i = 0; i < l; i++) {
+			for (int j = 0; j < n; j++) {
+				prob_col.x[j][i] = prob.x[i][j];
+			}
+		}
+
+		// only l, n, y and x are set here; no other field of prob is copied
+		return prob_col;
+	}
+
+	/** Exchanges the elements at idxA and idxB of the given array in place. */
+	static void swap(double[] array, int idxA, int idxB) {
+		final double temp = array[idxA];
+		array[idxA] = array[idxB];
+		array[idxB] = temp;
+	}
+
+	/** Exchanges the elements at idxA and idxB of the given array in place. */
+	static void swap(int[] array, int idxA, int idxB) {
+		final int temp = array[idxA];
+		array[idxA] = array[idxB];
+		array[idxB] = temp;
+	}
+
+	/** Exchanges the elements at idxA and idxB through the pointer's get/set accessors. */
+	static void swap(IntArrayPointer array, int idxA, int idxB) {
+		final int temp = array.get(idxA);
+		array.set(idxA, array.get(idxB));
+		array.set(idxB, temp);
+	}
+
+	/**
+	 * Trains a model for the given problem with the solver selected in param.
+	 * 
+	 * The three SVR solver types are trained directly on prob. All other
+	 * solver types first group the instances by class; MCSVM_CS is solved as
+	 * one multi-class problem, two classes yield a single weight vector, and
+	 * more classes are trained one-against-the-rest with the weights
+	 * interleaved as w[feature * nr_class + class].
+	 * 
+	 * @param prob  training data; rows of prob.x are read as dense arrays of prob.n values
+	 * @param param solver type, C, eps, p and optional per-class weights
+	 * @return the trained model
+	 * @throws IllegalArgumentException
+	 *             if prob or param is null, if prob has zero features or zero
+	 *             instances, if the weight array would be too large, or if a
+	 *             weight label in param does not occur in the data
+	 */
+	public static Model train(Problem prob, Parameter param) {
+
+		if (prob == null)
+			throw new IllegalArgumentException("problem must not be null");
+		if (param == null)
+			throw new IllegalArgumentException("parameter must not be null");
+
+		if (prob.n == 0)
+			throw new IllegalArgumentException("problem has zero features");
+		if (prob.l == 0)
+			throw new IllegalArgumentException("problem has zero instances");
+
+		final int l = prob.l;
+		final int n = prob.n;
+		final int w_size = prob.n;
+		final Model model = new Model();
+
+		// with a bias term, prob.n counts the bias column as one extra feature
+		if (prob.bias >= 0)
+			model.nr_feature = n - 1;
+		else
+			model.nr_feature = n;
+
+		model.solverType = param.solverType;
+		model.bias = prob.bias;
+
+		if (param.solverType == SolverType.L2R_L2LOSS_SVR || //
+			param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL || //
+			param.solverType == SolverType.L2R_L2LOSS_SVR_DUAL)
+		{
+			// regression: single weight vector, no labels
+			model.w = new double[w_size];
+			model.nr_class = 2;
+			model.label = null;
+
+			checkProblemSize(n, model.nr_class);
+
+			train_one(prob, param, model.w, 0, 0);
+		} else {
+			final int[] perm = new int[l];
+
+			// group training data of the same class
+			final GroupClassesReturn rv = groupClasses(prob, perm);
+			final int nr_class = rv.nr_class;
+			final int[] label = rv.label;
+			final int[] start = rv.start;
+			final int[] count = rv.count;
+
+			checkProblemSize(n, nr_class);
+
+			model.nr_class = nr_class;
+			model.label = new int[nr_class];
+			for (int i = 0; i < nr_class; i++)
+				model.label[i] = label[i];
+
+			// calculate weighted C
+			final double[] weighted_C = new double[nr_class];
+			for (int i = 0; i < nr_class; i++)
+				weighted_C[i] = param.C;
+			for (int i = 0; i < param.getNumWeights(); i++) {
+				int j;
+				for (j = 0; j < nr_class; j++)
+					if (param.weightLabel[i] == label[j])
+						break;
+
+				if (j == nr_class)
+					throw new IllegalArgumentException("class label " + param.weightLabel[i]
+							+ " specified in weight is not found");
+				weighted_C[j] *= param.weight[i];
+			}
+
+			// constructing the subproblem
+			// rows are shared with prob (no copy), reordered by perm
+			final double[][] x = new double[l][];
+			for (int i = 0; i < l; i++)
+				x[i] = prob.x[perm[i]];
+
+			final Problem sub_prob = new Problem();
+			sub_prob.l = l;
+			sub_prob.n = n;
+			sub_prob.x = new double[sub_prob.l][];
+			sub_prob.y = new double[sub_prob.l];
+
+			for (int k = 0; k < sub_prob.l; k++)
+				sub_prob.x[k] = x[k];
+
+			// multi-class svm by Crammer and Singer
+			if (param.solverType == SolverType.MCSVM_CS) {
+				model.w = new double[n * nr_class];
+				for (int i = 0; i < nr_class; i++) {
+					for (int j = start[i]; j < start[i] + count[i]; j++) {
+						sub_prob.y[j] = i;
+					}
+				}
+
+				final SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
+				solver.solve(model.w);
+			} else {
+				if (nr_class == 2) {
+					model.w = new double[w_size];
+
+					// first class -> +1, second class -> -1
+					final int e0 = start[0] + count[0];
+					int k = 0;
+					for (; k < e0; k++)
+						sub_prob.y[k] = +1;
+					for (; k < sub_prob.l; k++)
+						sub_prob.y[k] = -1;
+
+					train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]);
+				} else {
+					model.w = new double[w_size * nr_class];
+					final double[] w = new double[w_size];
+					for (int i = 0; i < nr_class; i++) {
+						final int si = start[i];
+						final int ei = si + count[i];
+
+						// one-against-the-rest: class i is +1, everything else -1
+						int k = 0;
+						for (; k < si; k++)
+							sub_prob.y[k] = -1;
+						for (; k < ei; k++)
+							sub_prob.y[k] = +1;
+						for (; k < sub_prob.l; k++)
+							sub_prob.y[k] = -1;
+
+						train_one(sub_prob, param, w, weighted_C[i], param.C);
+
+						for (int j = 0; j < n; j++)
+							model.w[j * nr_class + i] = w[j];
+					}
+				}
+			}
+		}
+		return model;
+	}
+
+	/**
+	 * verify the size and throw an exception early if the problem is too large
+	 * 
+	 * NOTE(review): both call sites in train pass prob.n (the number of
+	 * features) as n, while the message below speaks of 'number of
+	 * instances' -- confirm which wording is intended.
+	 */
+	private static void checkProblemSize(int n, int nr_class) {
+		if (n >= Integer.MAX_VALUE / nr_class || n * nr_class < 0) {
+			throw new IllegalArgumentException("'number of classes' * 'number of instances' is too large: " + nr_class
+					+ "*" + n);
+		}
+	}
+
+	/**
+	 * Trains one binary (or regression) problem by dispatching on
+	 * param.solverType; the result is written into w.
+	 * 
+	 * @param Cp penalty for instances with prob.y[i] > 0
+	 * @param Cn penalty for all other instances
+	 * @throws IllegalStateException if the solver type is not handled by the switch
+	 */
+	private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) {
+		final double eps = param.eps;
+		int pos = 0;
+		for (int i = 0; i < prob.l; i++)
+			if (prob.y[i] > 0) {
+				pos++;
+			}
+		final int neg = prob.l - pos;
+
+		// tolerance for the primal solvers, scaled by the share of the smaller class
+		final double primal_solver_tol = eps * Math.max(Math.min(pos, neg), 1) / prob.l;
+
+		Function fun_obj = null;
+		switch (param.solverType) {
+		case L2R_LR: {
+			final double[] C = new double[prob.l];
+			for (int i = 0; i < prob.l; i++) {
+				if (prob.y[i] > 0)
+					C[i] = Cp;
+				else
+					C[i] = Cn;
+			}
+			fun_obj = new L2R_LrFunction(prob, C);
+			final Tron tron_obj = new Tron(fun_obj, primal_solver_tol);
+			tron_obj.tron(w);
+			break;
+		}
+		case L2R_L2LOSS_SVC: {
+			final double[] C = new double[prob.l];
+			for (int i = 0; i < prob.l; i++) {
+				if (prob.y[i] > 0)
+					C[i] = Cp;
+				else
+					C[i] = Cn;
+			}
+			fun_obj = new L2R_L2_SvcFunction(prob, C);
+			final Tron tron_obj = new Tron(fun_obj, primal_solver_tol);
+			tron_obj.tron(w);
+			break;
+		}
+		case L2R_L2LOSS_SVC_DUAL:
+			solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L2LOSS_SVC_DUAL);
+			break;
+		case L2R_L1LOSS_SVC_DUAL:
+			solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL);
+			break;
+		case L1R_L2LOSS_SVC: {
+			// the L1-regularized solvers work on the column-format copy
+			final Problem prob_col = transpose(prob);
+			solve_l1r_l2_svc(prob_col, w, primal_solver_tol, Cp, Cn);
+			break;
+		}
+		case L1R_LR: {
+			final Problem prob_col = transpose(prob);
+			solve_l1r_lr(prob_col, w, primal_solver_tol, Cp, Cn);
+			break;
+		}
+		case L2R_LR_DUAL:
+			solve_l2r_lr_dual(prob, w, eps, Cp, Cn);
+			break;
+		case L2R_L2LOSS_SVR: {
+			// regression ignores Cp / Cn and uses param.C for every instance
+			final double[] C = new double[prob.l];
+			for (int i = 0; i < prob.l; i++)
+				C[i] = param.C;
+
+			fun_obj = new L2R_L2_SvrFunction(prob, C, param.p);
+			final Tron tron_obj = new Tron(fun_obj, param.eps);
+			tron_obj.tron(w);
+			break;
+		}
+		case L2R_L1LOSS_SVR_DUAL:
+		case L2R_L2LOSS_SVR_DUAL:
+			solve_l2r_l1l2_svr(prob, w, param);
+			break;
+
+		default:
+			throw new IllegalStateException("unknown solver type: " + param.solverType);
+		}
+	}
+
+	/** Turns debug output off by setting the debug stream to null. */
+	public static void disableDebugOutput() {
+		setDebugOutput(null);
+	}
+
+	/** Sends debug output to System.out. */
+	public static void enableDebugOutput() {
+		setDebugOutput(System.out);
+	}
+
+	/** Replaces the debug output stream while holding OUTPUT_MUTEX. */
+	public static void setDebugOutput(PrintStream debugOutput) {
+		synchronized (OUTPUT_MUTEX) {
+			DEBUG_OUTPUT = debugOutput;
+		}
+	}
+
+	/**
+	 * resets the PRNG
+	 * 
+	 * this is i.a. needed for regression testing (eg. the Weka wrapper)
+	 */
+	public static void resetRandom() {
+		random = new Random(DEFAULT_RANDOM_SEED);
+	}
+}
diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Model.java b/src/main/java/de/bwaldvogel/denseliblinear/Model.java
new file mode 100644
index 0000000..670d858
--- /dev/null
+++ b/src/main/java/de/bwaldvogel/denseliblinear/Model.java
@@ -0,0 +1,178 @@
+package de.bwaldvogel.denseliblinear;
+
+import static de.bwaldvogel.denseliblinear.Linear.copyOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.Serializable;
+import java.io.Writer;
+import java.util.Arrays;
+
+
+/**
+ * 

Model stores the model obtained from the training procedure

+ * + *

use {@link Linear#loadModel(File)} and {@link Linear#saveModel(File, Model)} to load/save it

+ */
+public final class Model implements Serializable {
+
+    private static final long serialVersionUID = -6456047576741854834L;
+
+    /** bias value; when it is >= 0 the weight array holds one extra feature (see the getFeatureWeights() javadoc) */
+    double bias;
+
+    /** label of each class */
+    int[] label;
+
+    /** number of classes */
+    int nr_class;
+
+    /** number of features (the extra bias feature is not counted here) */
+    int nr_feature;
+
+    /** the solver this model belongs to; decides e.g. whether probability output is available */
+    SolverType solverType;
+
+    /** feature weight array */
+    double[] w;
+
+    /**
+     * @return number of classes
+     */
+    public int getNrClass() {
+        return nr_class;
+    }
+
+    /**
+     * @return number of features
+     */
+    public int getNrFeature() {
+        return nr_feature;
+    }
+
+    /**
+     * @return a copy of the class labels, limited to nr_class entries, so callers
+     *         cannot modify the array held by the model
+     */
+    public int[] getLabels() {
+        return copyOf(label, nr_class);
+    }
+
+    /**
+     * The nr_feature*nr_class array w gives feature weights. We use one
+     * against the rest for multi-class classification, so each feature
+     * index corresponds to nr_class weight values. Weights are
+     * organized in the following way
+     *
+     *
+     * +------------------+------------------+------------+
+     * | nr_class weights | nr_class weights |  ...
+     * | for 1st feature  | for 2nd feature  |
+     * +------------------+------------------+------------+
+     * 
+     *
+     * If bias >= 0, x becomes [x; bias]. The number of features is
+     * increased by one, so w is a (nr_feature+1)*nr_class array. The
+     * value of bias is stored in the variable bias.
+     * @see #getBias()
+     * @return a copy of the feature weight array as described
+     */
+    public double[] getFeatureWeights() {
+        return Linear.copyOf(w, w.length);
+    }
+
+    /**
+     * @return true for logistic regression solvers
+     */
+    public boolean isProbabilityModel() {
+        return solverType.isLogisticRegressionSolver();
+    }
+
+    /**
+     * @see #getFeatureWeights()
+     */
+    public double getBias() {
+        return bias;
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder("Model");
+        sb.append(" bias=").append(bias);
+        sb.append(" nr_class=").append(nr_class);
+        sb.append(" nr_feature=").append(nr_feature);
+        sb.append(" solverType=").append(solverType);
+        return sb.toString();
+    }
+
+    /**
+     * Hash code consistent with {@link #equals(Object)}: the weights are hashed
+     * with {@link #weightsHashCode(double[])} so that models which only differ
+     * in the sign of a zero weight (and are therefore equal) hash identically.
+     */
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        long temp;
+        temp = Double.doubleToLongBits(bias);
+        result = prime * result + (int)(temp ^ (temp >>> 32));
+        result = prime * result + Arrays.hashCode(label);
+        result = prime * result + nr_class;
+        result = prime * result + nr_feature;
+        result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
+        // not Arrays.hashCode(w): it distinguishes 0.0 from -0.0, equals(double[], double[]) does not
+        result = prime * result + weightsHashCode(w);
+        return result;
+    }
+
+    /**
+     * Same scheme as {@link Arrays#hashCode(double[])}, except that -0.0 is
+     * hashed like 0.0.
+     *
+     * @param a the weight array, may be null
+     * @return 0 for null, otherwise the combined hash of all elements
+     */
+    private static int weightsHashCode(double[] a) {
+        if (a == null) return 0;
+
+        int result = 1;
+        for (int i = 0; i < a.length; i++) {
+            // adding 0.0 turns -0.0 into 0.0 and leaves every other value unchanged
+            long bits = Double.doubleToLongBits(a[i] + 0.0);
+            result = 31 * result + (int)(bits ^ (bits >>> 32));
+        }
+        return result;
+    }
+
+    /**
+     * Two models are equal if bias, labels, class/feature counts, solver type
+     * and weights match; weights are compared with
+     * {@link #equals(double[], double[])}.
+     */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null) return false;
+        if (getClass() != obj.getClass()) return false;
+        Model other = (Model)obj;
+        if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
+        if (!Arrays.equals(label, other.label)) return false;
+        if (nr_class != other.nr_class) return false;
+        if (nr_feature != other.nr_feature) return false;
+        if (solverType == null) {
+            if (other.solverType != null) return false;
+        } else if (!solverType.equals(other.solverType)) return false;
+        if (!equals(w, other.w)) return false;
+        return true;
+    }
+
+    /**
+     * don't use {@link Arrays#equals(double[], double[])} here, cause 0.0 and -0.0 should be handled the same
+     *
+     * @see Linear#saveModel(java.io.Writer, Model)
+     */
+    protected static boolean equals(double[] a, double[] a2) {
+        if (a == a2) return true;
+        if (a == null || a2 == null) return false;
+
+        int length = a.length;
+        if (a2.length != length) return false;
+
+        // primitive comparison on purpose: 0.0 == -0.0 is true here
+        for (int i = 0; i < length; i++)
+            if (a[i] != a2[i]) return false;
+
+        return true;
+    }
+
+    /**
+     * see {@link Linear#saveModel(java.io.File, Model)}
+     */
+    public void save(File file) throws IOException {
+        Linear.saveModel(file, this);
+    }
+
+    /**
+     * see {@link Linear#saveModel(Writer, Model)}
+     */
+    public void save(Writer writer) throws IOException {
+        Linear.saveModel(writer, this);
+    }
+
+    /**
+     * see {@link Linear#loadModel(File)}
+     */
+    public static Model load(File file) throws IOException {
+        return Linear.loadModel(file);
+    }
+
+    /**
+     * see {@link Linear#loadModel(Reader)}
+     */
+    public static Model load(Reader inputReader) throws IOException {
+        return Linear.loadModel(inputReader);
+    }
+}
diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java b/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java
new file mode 100644
index 0000000..012b0a1
--- /dev/null
+++ b/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java
@@ -0,0 +1,120 @@
+package de.bwaldvogel.denseliblinear;
+
+import static de.bwaldvogel.denseliblinear.Linear.copyOf;
+
+
+public final class Parameter {
+
+    double C;
+
+    /** stopping criteria */
+    double eps;
+
+    SolverType solverType;
+
+    double[] weight = null;
+
+    int[] weightLabel = null;
+
+    double p;
+
+    public Parameter( SolverType solver, double C, double eps ) {
+        this(solver, C, eps, 0.1);
+    }
+
+    public Parameter( SolverType solverType, double C, double eps, double p ) {
+        setSolverType(solverType);
+        setC(C);
+        setEps(eps);
+        setP(p);
+    }
+
+    /**
+     *

nr_weight, weight_label, and weight are used to change the penalty + * for some classes (if the weight for a class is not changed, it is + * set to 1). This is useful for training a classifier using unbalanced + * input data or with asymmetric misclassification costs.

+ * + *

Each weight[i] corresponds to weight_label[i], meaning that + * the penalty of class weight_label[i] is scaled by a factor of weight[i].

+ * + *

If you do not want to change penalty for any of the classes, + * just set nr_weight to 0.

+     */
+    public void setWeights(double[] weights, int[] weightLabels) {
+        if (weights == null) throw new IllegalArgumentException("'weight' must not be null");
+        if (weightLabels == null || weightLabels.length != weights.length)
+            throw new IllegalArgumentException("'weightLabels' must have same length as 'weight'");
+        // defensive copies: later changes to the caller's arrays must not affect this parameter set
+        this.weightLabel = copyOf(weightLabels, weightLabels.length);
+        this.weight = copyOf(weights, weights.length);
+    }
+
+    /**
+     * @return a copy of the class weights; an empty array if no weights have been set
+     * @see #setWeights(double[], int[])
+     */
+    public double[] getWeights() {
+        // weight stays null until weights are set; mirror getNumWeights() instead of failing with a NPE
+        if (weight == null) return new double[0];
+        return copyOf(weight, weight.length);
+    }
+
+    /**
+     * @return a copy of the labels the weights belong to; an empty array if no weights have been set
+     * @see #setWeights(double[], int[])
+     */
+    public int[] getWeightLabels() {
+        if (weightLabel == null) return new int[0];
+        return copyOf(weightLabel, weightLabel.length);
+    }
+
+    /**
+     * the number of weights
+     * @see #setWeights(double[], int[])
+     */
+    public int getNumWeights() {
+        if (weight == null) return 0;
+        return weight.length;
+    }
+
+    /**
+     * C is the cost of constraints violation. (we usually use 1 to 1000)
+     *
+     * @throws IllegalArgumentException if C is not greater than 0 (this includes NaN)
+     */
+    public void setC(double C) {
+        // written as !(C > 0) so that NaN is rejected as well
+        if (!(C > 0)) throw new IllegalArgumentException("C must not be <= 0");
+        this.C = C;
+    }
+
+    public double getC() {
+        return C;
+    }
+
+    /**
+     * eps is the stopping criterion. (we usually use 0.01).
+     *
+     * @throws IllegalArgumentException if eps is not greater than 0 (this includes NaN)
+     */
+    public void setEps(double eps) {
+        if (!(eps > 0)) throw new IllegalArgumentException("eps must not be <= 0");
+        this.eps = eps;
+    }
+
+    public double getEps() {
+        return eps;
+    }
+
+    /**
+     * @throws IllegalArgumentException if solverType is null
+     */
+    public void setSolverType(SolverType solverType) {
+        if (solverType == null) throw new IllegalArgumentException("solver type must not be null");
+        this.solverType = solverType;
+    }
+
+    public SolverType getSolverType() {
+        return solverType;
+    }
+
+
+    /**
+     * set the epsilon in loss function of epsilon-SVR (default 0.1)
+     *
+     * @throws IllegalArgumentException if p is negative or NaN
+     */
+    public void setP(double p) {
+        if (!(p >= 0)) throw new IllegalArgumentException("p must not be less than 0");
+        this.p = p;
+    }
+
+    public double getP() {
+        return p;
+    }
+}
diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Predict.java b/src/main/java/de/bwaldvogel/denseliblinear/Predict.java
new file mode 100644
index 0000000..4b60801
--- /dev/null
+++ b/src/main/java/de/bwaldvogel/denseliblinear/Predict.java
@@ -0,0 +1,193 @@
+package de.bwaldvogel.denseliblinear;
+
+import static de.bwaldvogel.denseliblinear.Linear.atof;
+import static de.bwaldvogel.denseliblinear.Linear.atoi;
+import static de.bwaldvogel.denseliblinear.Linear.closeQuietly;
+import static de.bwaldvogel.denseliblinear.Linear.info;
+import static de.bwaldvogel.denseliblinear.Linear.printf;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.Formatter;
+import java.util.NoSuchElementException;
+import java.util.StringTokenizer;
+import java.util.regex.Pattern;
+
+public class Predict {
+
+    private static boolean flag_predict_probability = false;
+
+    private static final Pattern COLON = Pattern.compile(":");
+
+    /**
+     *

+ * Note: The streams are NOT closed + *

+     */
+    static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException {
+        int correct = 0;
+        int total = 0;
+        double error = 0;
+        double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
+
+        final int nr_class = model.getNrClass();
+        double[] prob_estimates = null;
+        int n;
+        final int nr_feature = model.getNrFeature();
+        // with a bias term every dense instance carries one extra trailing component
+        if (model.bias >= 0)
+            n = nr_feature + 1;
+        else
+            n = nr_feature;
+
+        if (flag_predict_probability && !model.isProbabilityModel()) {
+            throw new IllegalArgumentException("probability output is only supported for logistic regression");
+        }
+
+        final Formatter out = new Formatter(writer);
+
+        if (flag_predict_probability) {
+            final int[] labels = model.getLabels();
+            prob_estimates = new double[nr_class];
+
+            printf(out, "labels");
+            for (int j = 0; j < nr_class; j++)
+                printf(out, " %d", labels[j]);
+            printf(out, "\n");
+        }
+
+        String line = null;
+        while ((line = reader.readLine()) != null) {
+            final double[] nodes = new double[n];
+            final StringTokenizer st = new StringTokenizer(line, " \t\n");
+            double target_label;
+            try {
+                final String label = st.nextToken();
+                target_label = atof(label);
+            } catch (final NoSuchElementException e) {
+                throw new RuntimeException("Wrong input format at line " + (total + 1), e);
+            }
+
+            while (st.hasMoreTokens()) {
+                final String[] split = COLON.split(st.nextToken(), 2);
+                if (split == null || split.length < 2) {
+                    throw new RuntimeException("Wrong input format at line " + (total + 1));
+                }
+
+                try {
+                    final int idx = atoi(split[0]);
+                    final double val = atof(split[1]);
+
+                    // feature indices are 1-based; anything below 1 would address
+                    // nodes[-1] or lower, so report it as malformed input instead
+                    if (idx < 1) {
+                        throw new RuntimeException("Wrong input format at line " + (total + 1));
+                    }
+
+                    // feature indices larger than those in training are not
+                    // used
+                    if (idx <= nr_feature) {
+                        nodes[idx - 1] = val;
+                    }
+                } catch (final NumberFormatException e) {
+                    throw new RuntimeException("Wrong input format at line " + (total + 1), e);
+                }
+            }
+
+            if (model.bias >= 0) {
+                nodes[n - 1] = model.bias;
+            }
+
+            double predict_label;
+
+            if (flag_predict_probability) {
+                assert prob_estimates != null;
+                predict_label = Linear.predictProbability(model, nodes, prob_estimates);
+                printf(out, "%g", predict_label);
+                for (int j = 0; j < model.nr_class; j++)
+                    printf(out, " %g", prob_estimates[j]);
+                printf(out, "\n");
+            } else {
+                predict_label = Linear.predict(model, nodes);
+                printf(out, "%g\n", predict_label);
+            }
+
+            if (predict_label == target_label) {
+                ++correct;
+            }
+
+            // running sums for the regression statistics reported below
+            error += (predict_label - target_label) * (predict_label - target_label);
+            sump += predict_label;
+            sumt += target_label;
+            sumpp += predict_label * predict_label;
+            sumtt += target_label * target_label;
+            sumpt += predict_label * target_label;
+            ++total;
+        }
+
+        if (model.solverType.isSupportVectorRegression()) //
+        {
+            info("Mean squared error = %g (regression)%n", error / total);
+            info("Squared correlation coefficient = %g (regression)%n", //
+                ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt))
+                    / ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt)));
+        } else {
+            info("Accuracy = %g%% (%d/%d)%n", (double) correct / total * 100, correct, total);
+        }
+    }
+
+    /**
+     * Prints the command line usage and terminates the JVM with exit code 1.
+     */
+    private static void exit_with_help() {
+        System.out
+            .printf("Usage: predict [options] test_file model_file output_file%n" //
+                + "options:%n" //
+                + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only%n" //
+                + "-q quiet mode (no outputs)%n");
+        System.exit(1);
+    }
+
+    /**
+     * Command line entry point: parses the options, then reads the test file,
+     * loads the model and writes the predictions to the output file.
+     *
+     * @param argv [options] test_file model_file output_file
+     */
+    public static void main(String[] argv) throws IOException {
+        int i;
+
+        // parse options
+        for (i = 0; i < argv.length; i++) {
+            // an empty argument or a bare "-" is not an option; it is taken as the first file name
+            if (argv[i].length() < 2 || argv[i].charAt(0) != '-')
+                break;
+            // every option is followed by at least one more argument (its value
+            // or a file name), same check as in Train.parse_command_line
+            if (++i >= argv.length)
+                exit_with_help();
+            switch (argv[i - 1].charAt(1)) {
+                case 'b':
+                    try {
+                        flag_predict_probability = (atoi(argv[i]) != 0);
+                    } catch (final NumberFormatException e) {
+                        exit_with_help();
+                    }
+                    break;
+
+                case 'q':
+                    // -q takes no value: step back so the next argument is parsed normally
+                    i--;
+                    Linear.disableDebugOutput();
+                    break;
+
+                default:
+                    // %c, not %d: the argument is a char and %d would throw IllegalFormatConversionException
+                    System.err.printf("unknown option: -%c%n", argv[i - 1].charAt(1));
+                    exit_with_help();
+                    break;
+            }
+        }
+        if (i >= argv.length || argv.length <= i + 2) {
+            exit_with_help();
+        }
+
+        BufferedReader reader = null;
+        Writer writer = null;
+        try {
+            reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), Linear.FILE_CHARSET));
+            writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), Linear.FILE_CHARSET));
+
+            final Model model = Linear.loadModel(new File(argv[i + 1]));
+            doPredict(reader, writer, model);
+        } finally {
+            closeQuietly(reader);
+            closeQuietly(writer);
+        }
+    }
+}
diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Problem.java b/src/main/java/de/bwaldvogel/denseliblinear/Problem.java
new file mode 100644
index 0000000..2ede79d
--- /dev/null
+++ b/src/main/java/de/bwaldvogel/denseliblinear/Problem.java
@@ -0,0 +1,62 @@
+package de.bwaldvogel.denseliblinear;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ *

+ * Describes the problem + *

+ * + * For example, if we have the following training data: + * + *
+ *  LABEL       ATTR1   ATTR2   ATTR3   ATTR4   ATTR5
+ *  -----       -----   -----   -----   -----   -----
+ *  1           0       0.1     0.2     0       0
+ *  2           0       0.1     0.3    -1.2     0
+ *  1           0.4     0       0       0       0
+ *  2           0       0.1     0       1.4     0.5
+ *  3          -0.1    -0.2     0.1     1.1     0.1
+ * 
+ *  and bias = 1, then the components of problem are:
+ * 
+ *  l = 5
+ *  n = 6
+ * 
+ *  y -> 1 2 1 2 3
+ * 
+ *  x -> [ ] -> [0,    0.1,  0.2,  0,    0,   1]
+ *       [ ] -> [0,    0.1,  0.3, -1.2,  0,   1]
+ *       [ ] -> [0.4,  0,    0,    0,    0,   1]
+ *       [ ] -> [0,    0.1,  0,    1.4,  0.5, 1]
+ *       [ ] -> [-0.1, -0.2, 0.1,  1.1,  0.1, 1]
+ * 
+ */
+public class Problem {
+
+    /** the number of training instances */
+    public int l;
+
+    /** the number of features (including the bias feature if bias >= 0) */
+    public int n;
+
+    /** an array containing the target values, one per training instance */
+    public double[] y;
+
+    /**
+     * dense array of features: one row per training instance; the value of the
+     * feature with (1-based) index j is stored at position j-1 of the row
+     */
+    public double[][] x;
+
+    /**
+     * If bias >= 0, we assume that one additional feature is added to the
+     * end of each data instance
+     */
+    public double bias;
+
+    /**
+     * see {@link Train#readProblem(File, double)}
+     */
+    public static Problem readFromFile(File file, double bias) throws IOException, InvalidInputDataException {
+        return Train.readProblem(file, bias);
+    }
+}
diff --git a/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java b/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java
new file mode 100644
index 0000000..20c13e9
--- /dev/null
+++ b/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java
@@ -0,0 +1,293 @@
+package de.bwaldvogel.denseliblinear;
+
+import static de.bwaldvogel.denseliblinear.Linear.copyOf;
+import static de.bwaldvogel.denseliblinear.Linear.info;
+import static de.bwaldvogel.denseliblinear.Linear.swap;
+
+/**
+ * A coordinate descent algorithm for multi-class support vector machines by
+ * Crammer and Singer
+ *
+ *
+ * min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
+ * s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
+ * 
+ * where e^m_i = 0 if y_i = m,
+ * e^m_i = 1 if y_i != m,
+ * C^m_i = C if m = y_i,
+ * C^m_i = 0 if m != y_i,
+ * and w_m(\alpha) = \sum_i \alpha^m_i x_i
+ * 
+ * Given:
+ * x, y, C
+ * eps is the stopping tolerance
+ * 
+ * solution will be put in w
+ * 
+ * See Appendix of LIBLINEAR paper, Fan et al. (2008)
+ * 
+ */ +class SolverMCSVM_CS { + + private final double[] B; + private final double[] C; + private final double eps; + private final double[] G; + private final int max_iter; + private final int w_size, l; + private final int nr_class; + private final Problem prob; + + public SolverMCSVM_CS(Problem prob, int nr_class, double[] C) { + this(prob, nr_class, C, 0.1); + } + + public SolverMCSVM_CS(Problem prob, int nr_class, double[] C, double eps) { + this(prob, nr_class, C, eps, 100000); + } + + public SolverMCSVM_CS(Problem prob, int nr_class, double[] weighted_C, double eps, int max_iter) { + this.w_size = prob.n; + this.l = prob.l; + this.nr_class = nr_class; + this.eps = eps; + this.max_iter = max_iter; + this.prob = prob; + this.C = weighted_C; + this.B = new double[nr_class]; + this.G = new double[nr_class]; + } + + private int GETI(int i) { + return (int) prob.y[i]; + } + + private boolean be_shrunk(int i, int m, int yi, double alpha_i, double minG) { + double bound = 0; + if (m == yi) + bound = C[GETI(i)]; + if (alpha_i == bound && G[m] < minG) + return true; + return false; + } + + public void solve(double[] w) { + int i, m, s; + int iter = 0; + final double[] alpha = new double[l * nr_class]; + final double[] alpha_new = new double[nr_class]; + final int[] index = new int[l]; + final double[] QD = new double[l]; + final int[] d_ind = new int[nr_class]; + final double[] d_val = new double[nr_class]; + final int[] alpha_index = new int[nr_class * l]; + final int[] y_index = new int[l]; + int active_size = l; + final int[] active_size_i = new int[l]; + double eps_shrink = Math.max(10.0 * eps, 1.0); // stopping tolerance for + // shrinking + boolean start_from_all = true; + + // Initial alpha can be set here. 
Note that + // sum_m alpha[i*nr_class+m] = 0, for all i=1,...,l-1 + // alpha[i*nr_class+m] <= C[GETI(i)] if prob->y[i] == m + // alpha[i*nr_class+m] <= 0 if prob->y[i] != m + // If initial alpha isn't zero, uncomment the for loop below to + // initialize w + for (i = 0; i < l * nr_class; i++) + alpha[i] = 0; + + for (i = 0; i < w_size * nr_class; i++) + w[i] = 0; + for (i = 0; i < l; i++) { + for (m = 0; m < nr_class; m++) + alpha_index[i * nr_class + m] = m; + QD[i] = 0; + for (final double val : prob.x[i]) { + QD[i] += val * val; + + // Uncomment the for loop if initial alpha isn't zero + // for(m=0; mindex-1)*nr_class+m] += alpha[i*nr_class+m]*val; + } + active_size_i[i] = nr_class; + y_index[i] = (int) prob.y[i]; + index[i] = i; + } + + final DoubleArrayPointer alpha_i = new DoubleArrayPointer(alpha, 0); + final IntArrayPointer alpha_index_i = new IntArrayPointer(alpha_index, 0); + + while (iter < max_iter) { + double stopping = Double.NEGATIVE_INFINITY; + + for (i = 0; i < active_size; i++) { + // int j = i+rand()%(active_size-i); + final int j = i + Linear.random.nextInt(active_size - i); + swap(index, i, j); + } + for (s = 0; s < active_size; s++) { + + i = index[s]; + final double Ai = QD[i]; + // double *alpha_i = &alpha[i*nr_class]; + alpha_i.setOffset(i * nr_class); + + // int *alpha_index_i = &alpha_index[i*nr_class]; + alpha_index_i.setOffset(i * nr_class); + + if (Ai > 0) { + for (m = 0; m < active_size_i[i]; m++) + G[m] = 1; + if (y_index[i] < active_size_i[i]) + G[y_index[i]] = 0; + + for (int ind = 0; ind < prob.x[i].length; ind++) { + // double *w_i = &w[ind*nr_class]; + final int w_offset = ind * nr_class; + for (m = 0; m < active_size_i[i]; m++) + // G[m] += w_i[alpha_index_i[m]]*(prob.x[i][ind); + G[m] += w[w_offset + alpha_index_i.get(m)] * prob.x[i][ind]; + + } + + double minG = Double.POSITIVE_INFINITY; + double maxG = Double.NEGATIVE_INFINITY; + for (m = 0; m < active_size_i[i]; m++) { + if (alpha_i.get(alpha_index_i.get(m)) < 0 && G[m] < 
minG) + minG = G[m]; + if (G[m] > maxG) + maxG = G[m]; + } + if (y_index[i] < active_size_i[i]) { + if (alpha_i.get((int) prob.y[i]) < C[GETI(i)] && G[y_index[i]] < minG) { + minG = G[y_index[i]]; + } + } + + for (m = 0; m < active_size_i[i]; m++) { + if (be_shrunk(i, m, y_index[i], alpha_i.get(alpha_index_i.get(m)), minG)) { + active_size_i[i]--; + while (active_size_i[i] > m) { + if (!be_shrunk(i, active_size_i[i], y_index[i], + alpha_i.get(alpha_index_i.get(active_size_i[i])), minG)) + { + swap(alpha_index_i, m, active_size_i[i]); + swap(G, m, active_size_i[i]); + if (y_index[i] == active_size_i[i]) + y_index[i] = m; + else if (y_index[i] == m) + y_index[i] = active_size_i[i]; + break; + } + active_size_i[i]--; + } + } + } + + if (active_size_i[i] <= 1) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + + if (maxG - minG <= 1e-12) + continue; + else + stopping = Math.max(maxG - minG, stopping); + + for (m = 0; m < active_size_i[i]; m++) + B[m] = G[m] - Ai * alpha_i.get(alpha_index_i.get(m)); + + solve_sub_problem(Ai, y_index[i], C[GETI(i)], active_size_i[i], alpha_new); + int nz_d = 0; + for (m = 0; m < active_size_i[i]; m++) { + final double d = alpha_new[m] - alpha_i.get(alpha_index_i.get(m)); + alpha_i.set(alpha_index_i.get(m), alpha_new[m]); + if (Math.abs(d) >= 1e-12) { + d_ind[nz_d] = alpha_index_i.get(m); + d_val[nz_d] = d; + nz_d++; + } + } + + for (int ind = 0; ind < prob.x[i].length; ind++) { + // double *w_i = &w[ind*nr_class]; + final int w_offset = ind * nr_class; + for (m = 0; m < nz_d; m++) { + w[w_offset + d_ind[m]] += d_val[m] * prob.x[i][ind]; + } + } + } + } + + iter++; + + if (iter % 10 == 0) { + info("."); + } + + if (stopping < eps_shrink) { + if (stopping < eps && start_from_all == true) + break; + else { + active_size = l; + for (i = 0; i < l; i++) + active_size_i[i] = nr_class; + info("*"); + eps_shrink = Math.max(eps_shrink / 2, eps); + start_from_all = true; + } + } else + start_from_all = false; + } + + 
info("%noptimization finished, #iter = %d%n", iter); + if (iter >= max_iter) + info("%nWARNING: reaching max number of iterations%n"); + + // calculate objective value + double v = 0; + int nSV = 0; + for (i = 0; i < w_size * nr_class; i++) + v += w[i] * w[i]; + v = 0.5 * v; + for (i = 0; i < l * nr_class; i++) { + v += alpha[i]; + if (Math.abs(alpha[i]) > 0) + nSV++; + } + for (i = 0; i < l; i++) + v -= alpha[i * nr_class + (int) prob.y[i]]; + info("Objective value = %f%n", v); + info("nSV = %d%n", nSV); + + } + + private void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double[] alpha_new) { + + int r; + assert active_i <= B.length; // no padding + final double[] D = copyOf(B, active_i); + // clone(D, B, active_i); + + if (yi < active_i) + D[yi] += A_i * C_yi; + + // qsort(D, active_i, sizeof(double), compare_double); + ArraySorter.reversedMergesort(D); + + double beta = D[0] - A_i * C_yi; + for (r = 1; r < active_i && beta < r * D[r]; r++) + beta += D[r]; + beta /= r; + + for (r = 0; r < active_i; r++) { + if (r == yi) + alpha_new[r] = Math.min(C_yi, (beta - B[r]) / A_i); + else + alpha_new[r] = Math.min(0.0, (beta - B[r]) / A_i); + } + } +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java b/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java new file mode 100644 index 0000000..a792743 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java @@ -0,0 +1,129 @@ +package de.bwaldvogel.denseliblinear; + +import java.util.HashMap; +import java.util.Map; + + +public enum SolverType { + + /** + * L2-regularized logistic regression (primal) + * + * (fka L2_LR) + */ + L2R_LR(0, true, false), + + /** + * L2-regularized L2-loss support vector classification (dual) + * + * (fka L2LOSS_SVM_DUAL) + */ + L2R_L2LOSS_SVC_DUAL(1, false, false), + + /** + * L2-regularized L2-loss support vector classification (primal) + * + * (fka L2LOSS_SVM) + */ + L2R_L2LOSS_SVC(2, false, false), + + /** + * L2-regularized 
L1-loss support vector classification (dual) + * + * (fka L1LOSS_SVM_DUAL) + */ + L2R_L1LOSS_SVC_DUAL(3, false, false), + + /** + * multi-class support vector classification by Crammer and Singer + */ + MCSVM_CS(4, false, false), + + /** + * L1-regularized L2-loss support vector classification + * + * @since 1.5 + */ + L1R_L2LOSS_SVC(5, false, false), + + /** + * L1-regularized logistic regression + * + * @since 1.5 + */ + L1R_LR(6, true, false), + + /** + * L2-regularized logistic regression (dual) + * + * @since 1.7 + */ + L2R_LR_DUAL(7, true, false), + + /** + * L2-regularized L2-loss support vector regression (dual) + * + * @since 1.91 + */ + L2R_L2LOSS_SVR(11, false, true), + + /** + * L2-regularized L1-loss support vector regression (dual) + * + * @since 1.91 + */ + L2R_L2LOSS_SVR_DUAL(12, false, true), + + /** + * L2-regularized L2-loss support vector regression (primal) + * + * @since 1.91 + */ + L2R_L1LOSS_SVR_DUAL(13, false, true), + + ; + + private final boolean logisticRegressionSolver; + private final boolean supportVectorRegression; + private final int id; + + private SolverType( int id, boolean logisticRegressionSolver, boolean supportVectorRegression ) { + this.id = id; + this.logisticRegressionSolver = logisticRegressionSolver; + this.supportVectorRegression = supportVectorRegression; + } + + private static Map SOLVERS_BY_ID = new HashMap(); + static { + for (SolverType solverType : SolverType.values()) { + SolverType old = SOLVERS_BY_ID.put(Integer.valueOf(solverType.getId()), solverType); + if (old != null) throw new Error("duplicate solver type ID: " + solverType.getId()); + } + } + + public int getId() { + return id; + } + + public static SolverType getById(int id) { + SolverType solverType = SOLVERS_BY_ID.get(Integer.valueOf(id)); + if (solverType == null) { + throw new RuntimeException("found no solvertype for id " + id); + } + return solverType; + } + + /** + * @since 1.9 + */ + public boolean isLogisticRegressionSolver() { + return 
logisticRegressionSolver; + } + + /** + * @since 1.91 + */ + public boolean isSupportVectorRegression() { + return supportVectorRegression; + } +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Train.java b/src/main/java/de/bwaldvogel/denseliblinear/Train.java new file mode 100644 index 0000000..2a690fc --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/Train.java @@ -0,0 +1,420 @@ +package de.bwaldvogel.denseliblinear; + +import static de.bwaldvogel.denseliblinear.Linear.atof; +import static de.bwaldvogel.denseliblinear.Linear.atoi; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.StringTokenizer; + +public class Train { + + public static void main(String[] args) throws IOException, InvalidInputDataException { + new Train().run(args); + } + + private double bias = 1; + private boolean cross_validation = false; + private String inputFilename; + private String modelFilename; + private int nr_fold; + private Parameter param = null; + private Problem prob = null; + + private void do_cross_validation() { + + double total_error = 0; + double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; + final double[] target = new double[prob.l]; + + long start, stop; + start = System.currentTimeMillis(); + Linear.crossValidation(prob, param, nr_fold, target); + stop = System.currentTimeMillis(); + System.out.println("time: " + (stop - start) + " ms"); + + if (param.solverType.isSupportVectorRegression()) { + for (int i = 0; i < prob.l; i++) { + final double y = prob.y[i]; + final double v = target[i]; + total_error += (v - y) * (v - y); + sumv += v; + sumy += y; + sumvv += v * v; + sumyy += y * y; + sumvy += v * y; + } + System.out.printf("Cross Validation Mean squared error = %g%n", total_error / prob.l); + System.out.printf("Cross Validation Squared correlation coefficient = 
%g%n", // + ((prob.l * sumvy - sumv * sumy) * (prob.l * sumvy - sumv * sumy)) + / ((prob.l * sumvv - sumv * sumv) * (prob.l * sumyy - sumy * sumy))); + } else { + int total_correct = 0; + for (int i = 0; i < prob.l; i++) + if (target[i] == prob.y[i]) + ++total_correct; + + System.out.printf("correct: %d%n", total_correct); + System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l); + } + } + + private void exit_with_help() { + System.out.printf("Usage: train [options] training_set_file [model_file]%n" // + + "options:%n" + + "-s type : set type of solver (default 1)%n" + + " for multi-class classification%n" + + " 0 -- L2-regularized logistic regression (primal)%n" + + " 1 -- L2-regularized L2-loss support vector classification (dual)%n" + + " 2 -- L2-regularized L2-loss support vector classification (primal)%n" + + " 3 -- L2-regularized L1-loss support vector classification (dual)%n" + + " 4 -- support vector classification by Crammer and Singer%n" + + " 5 -- L1-regularized L2-loss support vector classification%n" + + " 6 -- L1-regularized logistic regression%n" + + " 7 -- L2-regularized logistic regression (dual)%n" + + " for regression%n" + + " 11 -- L2-regularized L2-loss support vector regression (primal)%n" + + " 12 -- L2-regularized L2-loss support vector regression (dual)%n" + + " 13 -- L2-regularized L1-loss support vector regression (dual)%n" + + "-c cost : set the parameter C (default 1)%n" + + "-p epsilon : set the epsilon in loss function of SVR (default 0.1)%n" + + "-e epsilon : set tolerance of termination criterion%n" + + " -s 0 and 2%n" + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n" + + " where f is the primal function and pos/neg are # of%n" + + " positive/negative data (default 0.01)%n" + " -s 11%n" + + " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)%n" + + " -s 1, 3, 4 and 7%n" + " Dual maximal violation <= eps; similar to libsvm (default 0.1)%n" + + " -s 5 and 6%n" + + " |f'(w)|_1 <= 
eps*min(pos,neg)/l*|f'(w0)|_1,%n" + + " where f is the primal function (default 0.01)%n" + + " -s 12 and 13\n" + + " |f'(alpha)|_1 <= eps |f'(alpha0)|,\n" + + " where f is the dual function (default 0.1)\n" + + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n" + + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n" + + "-v n: n-fold cross validation mode%n" + + "-q : quiet mode (no outputs)%n"); + System.exit(1); + } + + Problem getProblem() { + return prob; + } + + double getBias() { + return bias; + } + + Parameter getParameter() { + return param; + } + + void parse_command_line(String argv[]) { + int i; + + // eps: see setting below + param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY, 0.1); + // default values + bias = -1; + cross_validation = false; + + // parse options + for (i = 0; i < argv.length; i++) { + if (argv[i].charAt(0) != '-') + break; + if (++i >= argv.length) + exit_with_help(); + switch (argv[i - 1].charAt(1)) { + case 's': + param.solverType = SolverType.getById(atoi(argv[i])); + break; + case 'c': + param.setC(atof(argv[i])); + break; + case 'p': + param.setP(atof(argv[i])); + break; + case 'e': + param.setEps(atof(argv[i])); + break; + case 'B': + bias = atof(argv[i]); + break; + case 'w': + final int weightLabel = atoi(argv[i - 1].substring(2)); + final double weight = atof(argv[i]); + param.weightLabel = addToArray(param.weightLabel, weightLabel); + param.weight = addToArray(param.weight, weight); + break; + case 'v': + cross_validation = true; + nr_fold = atoi(argv[i]); + if (nr_fold < 2) { + System.err.println("n-fold cross validation: n must >= 2"); + exit_with_help(); + } + break; + case 'q': + i--; + Linear.disableDebugOutput(); + break; + default: + System.err.println("unknown option"); + exit_with_help(); + } + } + + // determine filenames + + if (i >= argv.length) + exit_with_help(); + + inputFilename = argv[i]; + 
+ if (i < argv.length - 1) + modelFilename = argv[i + 1]; + else { + int p = argv[i].lastIndexOf('/'); + ++p; // whew... + modelFilename = argv[i].substring(p) + ".model"; + } + + if (param.eps == Double.POSITIVE_INFINITY) { + switch (param.solverType) { + case L2R_LR: + case L2R_L2LOSS_SVC: + param.setEps(0.01); + break; + case L2R_L2LOSS_SVR: + param.setEps(0.001); + break; + case L2R_L2LOSS_SVC_DUAL: + case L2R_L1LOSS_SVC_DUAL: + case MCSVM_CS: + case L2R_LR_DUAL: + param.setEps(0.1); + break; + case L1R_L2LOSS_SVC: + case L1R_LR: + param.setEps(0.01); + break; + case L2R_L1LOSS_SVR_DUAL: + case L2R_L2LOSS_SVR_DUAL: + param.setEps(0.1); + break; + default: + throw new IllegalStateException("unknown solver type: " + param.solverType); + } + } + } + + /** + * reads a problem from LibSVM format + * + * @param file + * the SVM file + * @throws IOException + * obviously in case of any I/O exception ;) + * @throws InvalidInputDataException + * if the input file is not correctly formatted + */ + static int readProblemFeatureDim(File file) throws IOException, InvalidInputDataException { + final BufferedReader fp = new BufferedReader(new FileReader(file)); + int max_index = 0; + int lineNr = 0; + + try { + while (true) { + final String line = fp.readLine(); + if (line == null) + break; + lineNr++; + + final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); + String token; + try { + token = st.nextToken(); + } catch (final NoSuchElementException e) { + throw new InvalidInputDataException("empty line", file, lineNr, e); + } + + final int m = st.countTokens() / 2; + + int indexBefore = 0; + for (int j = 0; j < m; j++) { + token = st.nextToken(); + int index; + try { + index = atoi(token); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); + } + + // assert that indices are valid and sorted + if (index < 0) + throw new InvalidInputDataException("invalid index: " + index, file, lineNr); + 
if (index <= indexBefore) + throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr); + indexBefore = index; + + token = st.nextToken(); + + if (index > max_index) { + max_index = index; + } + } + } + + return max_index; + } finally { + fp.close(); + } + } + + /** + * reads a problem from LibSVM format + * + * @param file + * the SVM file + * @throws IOException + * obviously in case of any I/O exception ;) + * @throws InvalidInputDataException + * if the input file is not correctly formatted + */ + public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException { + final BufferedReader fp = new BufferedReader(new FileReader(file)); + final List vy = new ArrayList(); + final List vx = new ArrayList(); + + int lineNr = 0; + + final int w_size = readProblemFeatureDim(file); + + try { + while (true) { + final String line = fp.readLine(); + if (line == null) + break; + lineNr++; + + final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); + String token; + try { + token = st.nextToken(); + } catch (final NoSuchElementException e) { + throw new InvalidInputDataException("empty line", file, lineNr, e); + } + + try { + vy.add(atof(token)); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e); + } + + final int m = st.countTokens() / 2; + double[] x; + if (bias >= 0) { + x = new double[w_size + 1]; + } else { + x = new double[w_size]; + } + int indexBefore = 0; + for (int j = 0; j < m; j++) { + + token = st.nextToken(); + int index; + try { + index = atoi(token); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); + } + + // assert that indices are valid and sorted + if (index < 0) + throw new InvalidInputDataException("invalid index: " + index, file, lineNr); + if (index <= indexBefore) + throw new InvalidInputDataException("indices must 
be sorted in ascending order", file, lineNr); + indexBefore = index; + + token = st.nextToken(); + try { + final double value = atof(token); + x[index - 1] = value; + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid value: " + token, file, lineNr); + } + } + + vx.add(x); + } + + return constructProblem(vy, vx, w_size, bias); + } finally { + fp.close(); + } + } + + void readProblem(String filename) throws IOException, InvalidInputDataException { + prob = Train.readProblem(new File(filename), bias); + } + + private static int[] addToArray(int[] array, int newElement) { + final int length = array != null ? array.length : 0; + final int[] newArray = new int[length + 1]; + if (array != null && length > 0) { + System.arraycopy(array, 0, newArray, 0, length); + } + newArray[length] = newElement; + return newArray; + } + + private static double[] addToArray(double[] array, double newElement) { + final int length = array != null ? array.length : 0; + final double[] newArray = new double[length + 1]; + if (array != null && length > 0) { + System.arraycopy(array, 0, newArray, 0, length); + } + newArray[length] = newElement; + return newArray; + } + + private static Problem constructProblem(List vy, List vx, int max_index, double bias) { + final Problem prob = new Problem(); + prob.bias = bias; + prob.l = vy.size(); + prob.n = max_index; + if (bias >= 0) { + prob.n++; + } + prob.x = new double[prob.l][]; + for (int i = 0; i < prob.l; i++) { + prob.x[i] = vx.get(i); + + if (bias >= 0) { + prob.x[i][max_index] = bias; + } + } + + prob.y = new double[prob.l]; + for (int i = 0; i < prob.l; i++) + prob.y[i] = vy.get(i).doubleValue(); + + return prob; + } + + private void run(String[] args) throws IOException, InvalidInputDataException { + parse_command_line(args); + readProblem(inputFilename); + if (cross_validation) + do_cross_validation(); + else { + final Model model = Linear.train(prob, param); + Linear.saveModel(new File(modelFilename), 
model); + } + } +} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Tron.java b/src/main/java/de/bwaldvogel/denseliblinear/Tron.java new file mode 100644 index 0000000..1235175 --- /dev/null +++ b/src/main/java/de/bwaldvogel/denseliblinear/Tron.java @@ -0,0 +1,260 @@ +package de.bwaldvogel.denseliblinear; + +import static de.bwaldvogel.denseliblinear.Linear.info; + +/** + * Trust Region Newton Method optimization + */ +class Tron { + + private final Function fun_obj; + + private final double eps; + + private final int max_iter; + + public Tron( final Function fun_obj ) { + this(fun_obj, 0.1); + } + + public Tron( final Function fun_obj, double eps ) { + this(fun_obj, eps, 1000); + } + + public Tron( final Function fun_obj, double eps, int max_iter ) { + this.fun_obj = fun_obj; + this.eps = eps; + this.max_iter = max_iter; + } + + void tron(double[] w) { + // Parameters for updating the iterates. + double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75; + + // Parameters for updating the trust region size delta. + double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4; + + int n = fun_obj.get_nr_variable(); + int i, cg_iter; + double delta, snorm, one = 1.0; + double alpha, f, fnew, prered, actred, gs; + int search = 1, iter = 1; + double[] s = new double[n]; + double[] r = new double[n]; + double[] w_new = new double[n]; + double[] g = new double[n]; + + for (i = 0; i < n; i++) + w[i] = 0; + + f = fun_obj.fun(w); + fun_obj.grad(w, g); + delta = euclideanNorm(g); + double gnorm1 = delta; + double gnorm = gnorm1; + + if (gnorm <= eps * gnorm1) search = 0; + + iter = 1; + + while (iter <= max_iter && search != 0) { + cg_iter = trcg(delta, g, s, r); + + System.arraycopy(w, 0, w_new, 0, n); + daxpy(one, s, w_new); + + gs = dot(g, s); + prered = -0.5 * (gs - dot(s, r)); + fnew = fun_obj.fun(w_new); + + // Compute the actual reduction. + actred = f - fnew; + + // On the first iteration, adjust the initial step bound. 
+ snorm = euclideanNorm(s); + if (iter == 1) delta = Math.min(delta, snorm); + + // Compute prediction alpha*snorm of the step. + if (fnew - f - gs <= 0) + alpha = sigma3; + else + alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs))); + + // Update the trust region bound according to the ratio of actual to + // predicted reduction. + if (actred < eta0 * prered) + delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2 * delta); + else if (actred < eta1 * prered) + delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma2 * delta)); + else if (actred < eta2 * prered) + delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma3 * delta)); + else + delta = Math.max(delta, Math.min(alpha * snorm, sigma3 * delta)); + + info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d%n", iter, actred, prered, delta, f, gnorm, cg_iter); + + if (actred > eta0 * prered) { + iter++; + System.arraycopy(w_new, 0, w, 0, n); + f = fnew; + fun_obj.grad(w, g); + + gnorm = euclideanNorm(g); + if (gnorm <= eps * gnorm1) break; + } + if (f < -1.0e+32) { + info("WARNING: f < -1.0e+32%n"); + break; + } + if (Math.abs(actred) <= 0 && prered <= 0) { + info("WARNING: actred and prered <= 0%n"); + break; + } + if (Math.abs(actred) <= 1.0e-12 * Math.abs(f) && Math.abs(prered) <= 1.0e-12 * Math.abs(f)) { + info("WARNING: actred and prered too small%n"); + break; + } + } + } + + private int trcg(double delta, double[] g, double[] s, double[] r) { + int n = fun_obj.get_nr_variable(); + double one = 1; + double[] d = new double[n]; + double[] Hd = new double[n]; + double rTr, rnewTrnew, cgtol; + + for (int i = 0; i < n; i++) { + s[i] = 0; + r[i] = -g[i]; + d[i] = r[i]; + } + cgtol = 0.1 * euclideanNorm(g); + + int cg_iter = 0; + rTr = dot(r, r); + + while (true) { + if (euclideanNorm(r) <= cgtol) break; + cg_iter++; + fun_obj.Hv(d, Hd); + + double alpha = rTr / dot(d, Hd); + daxpy(alpha, d, s); + if (euclideanNorm(s) > delta) { + info("cg reaches trust region boundary%n"); 
+ alpha = -alpha; + daxpy(alpha, d, s); + + double std = dot(s, d); + double sts = dot(s, s); + double dtd = dot(d, d); + double dsq = delta * delta; + double rad = Math.sqrt(std * std + dtd * (dsq - sts)); + if (std >= 0) + alpha = (dsq - sts) / (std + rad); + else + alpha = (rad - std) / dtd; + daxpy(alpha, d, s); + alpha = -alpha; + daxpy(alpha, Hd, r); + break; + } + alpha = -alpha; + daxpy(alpha, Hd, r); + rnewTrnew = dot(r, r); + double beta = rnewTrnew / rTr; + scale(beta, d); + daxpy(one, r, d); + rTr = rnewTrnew; + } + + return (cg_iter); + } + + /** + * constant times a vector plus a vector + * + *
+     * vector2 += constant * vector1
+     * 
+ * + * @since 1.8 + */ + private static void daxpy(double constant, double vector1[], double vector2[]) { + if (constant == 0) return; + + assert vector1.length == vector2.length; + for (int i = 0; i < vector1.length; i++) { + vector2[i] += constant * vector1[i]; + } + } + + /** + * returns the dot product of two vectors + * + * @since 1.8 + */ + private static double dot(double vector1[], double vector2[]) { + + double product = 0; + assert vector1.length == vector2.length; + for (int i = 0; i < vector1.length; i++) { + product += vector1[i] * vector2[i]; + } + return product; + + } + + /** + * returns the euclidean norm of a vector + * + * @since 1.8 + */ + private static double euclideanNorm(double vector[]) { + + int n = vector.length; + + if (n < 1) { + return 0; + } + + if (n == 1) { + return Math.abs(vector[0]); + } + + // this algorithm is (often) more accurate than just summing up the squares and taking the square-root afterwards + + double scale = 0; // scaling factor that is factored out + double sum = 1; // basic sum of squares from which scale has been factored out + for (int i = 0; i < n; i++) { + if (vector[i] != 0) { + double abs = Math.abs(vector[i]); + // try to get the best scaling factor + if (scale < abs) { + double t = scale / abs; + sum = 1 + sum * (t * t); + scale = abs; + } else { + double t = abs / scale; + sum += t * t; + } + } + } + + return scale * Math.sqrt(sum); + } + + /** + * scales a vector by a constant + * + * @since 1.8 + */ + private static void scale(double constant, double vector[]) { + if (constant == 1.0) return; + for (int i = 0; i < vector.length; i++) { + vector[i] *= constant; + } + + } +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java new file mode 100644 index 0000000..d524c82 --- /dev/null +++ b/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java @@ -0,0 +1,63 @@ +package de.bwaldvogel.denseliblinear; + 
+import static org.fest.assertions.Assertions.assertThat; +import static org.fest.assertions.Fail.fail; + +import org.junit.Test; + +import de.bwaldvogel.denseliblinear.DoubleArrayPointer; +import de.bwaldvogel.denseliblinear.IntArrayPointer; + + +public class ArrayPointerTest { + + @Test + public void testGetIntArrayPointer() { + int[] foo = new int[] {1, 2, 3, 4, 6}; + IntArrayPointer pFoo = new IntArrayPointer(foo, 2); + assertThat(pFoo.get(0)).isEqualTo(3); + assertThat(pFoo.get(1)).isEqualTo(4); + assertThat(pFoo.get(2)).isEqualTo(6); + try { + pFoo.get(3); + fail("ArrayIndexOutOfBoundsException expected"); + } catch (ArrayIndexOutOfBoundsException e) {} + } + + @Test + public void testSetIntArrayPointer() { + int[] foo = new int[] {1, 2, 3, 4, 6}; + IntArrayPointer pFoo = new IntArrayPointer(foo, 2); + pFoo.set(2, 5); + assertThat(foo).isEqualTo(new int[] {1, 2, 3, 4, 5}); + try { + pFoo.set(3, 0); + fail("ArrayIndexOutOfBoundsException expected"); + } catch (ArrayIndexOutOfBoundsException e) {} + } + + @Test + public void testGetDoubleArrayPointer() { + double[] foo = new double[] {1, 2, 3, 4, 6}; + DoubleArrayPointer pFoo = new DoubleArrayPointer(foo, 2); + assertThat(pFoo.get(0)).isEqualTo(3); + assertThat(pFoo.get(1)).isEqualTo(4); + assertThat(pFoo.get(2)).isEqualTo(6); + try { + pFoo.get(3); + fail("ArrayIndexOutOfBoundsException expected"); + } catch (ArrayIndexOutOfBoundsException e) {} + } + + @Test + public void testSetDoubleArrayPointer() { + double[] foo = new double[] {1, 2, 3, 4, 6}; + DoubleArrayPointer pFoo = new DoubleArrayPointer(foo, 2); + pFoo.set(2, 5); + assertThat(foo).isEqualTo(new double[] {1, 2, 3, 4, 5}); + try { + pFoo.set(3, 0); + fail("ArrayIndexOutOfBoundsException expected"); + } catch (ArrayIndexOutOfBoundsException e) {} + } +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java new file mode 100644 index 0000000..38fb53a --- /dev/null 
+++ b/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java @@ -0,0 +1,58 @@ +package de.bwaldvogel.denseliblinear; + +import static de.bwaldvogel.denseliblinear.Linear.swap; +import static org.fest.assertions.Assertions.assertThat; + +import java.util.Random; + +import org.junit.Test; + +import de.bwaldvogel.denseliblinear.ArraySorter; + + +public class ArraySorterTest { + + private Random random = new Random(); + + private void assertDescendingOrder(double[] array) { + double before = array[0]; + for (double d : array) { + // accept that case + if (d == 0.0 && before == -0.0) continue; + + assertThat(d).isLessThanOrEqualTo(before); + before = d; + } + } + + private void shuffleArray(double[] array) { + + for (int i = 0; i < array.length; i++) { + int j = random.nextInt(array.length); + swap(array, i, j); + } + } + + @Test + public void testReversedMergesort() { + + for (int k = 1; k <= 16 * 8096; k *= 2) { + // create random array + double[] array = new double[k]; + for (int i = 0; i < array.length; i++) { + array[i] = random.nextDouble(); + } + + ArraySorter.reversedMergesort(array); + assertDescendingOrder(array); + } + } + + @Test + public void testReversedMergesortWithMeanValues() { + double[] array = new double[] {1.0, -0.0, -1.1, 2.0, 3.0, 0.0, 4.0, -0.0, 0.0}; + shuffleArray(array); + ArraySorter.reversedMergesort(array); + assertDescendingOrder(array); + } +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java b/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java new file mode 100644 index 0000000..753715f --- /dev/null +++ b/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java @@ -0,0 +1,517 @@ +package de.bwaldvogel.denseliblinear; + +import static org.fest.assertions.Assertions.assertThat; +import static org.fest.assertions.Fail.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.File; +import 
java.io.IOException; +import java.io.Writer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.TreeSet; + +import org.fest.assertions.Delta; +import org.junit.BeforeClass; +import org.junit.Test; +import org.powermock.api.mockito.PowerMockito; + +public class LinearTest { + + private static Random random = new Random(12345); + + @BeforeClass + public static void disableDebugOutput() { + // Linear.disableDebugOutput(); + } + + public static Model createRandomModel() { + final Model model = new Model(); + model.solverType = SolverType.L2R_LR; + model.bias = 2; + model.label = new int[] { 1, Integer.MAX_VALUE, 2 }; + model.w = new double[model.label.length * 300]; + for (int i = 0; i < model.w.length; i++) { + // precision should be at least 1e-4 + model.w[i] = Math.round(random.nextDouble() * 100000.0) / 10000.0; + } + + // force at least one value to be zero + model.w[random.nextInt(model.w.length)] = 0.0; + model.w[random.nextInt(model.w.length)] = -0.0; + + model.nr_feature = model.w.length / model.label.length - 1; + model.nr_class = model.label.length; + return model; + } + + public static Problem createRandomProblem(int numClasses) { + final Problem prob = new Problem(); + prob.bias = -1; + prob.l = random.nextInt(100) + 1; + prob.n = random.nextInt(100) + 1; + prob.x = new double[prob.l][]; + prob.y = new double[prob.l]; + + for (int i = 0; i < prob.l; i++) { + + prob.y[i] = random.nextInt(numClasses); + + final Set randomNumbers = new TreeSet(); + final int num = random.nextInt(prob.n); + for (int j = 0; j < num; j++) { + randomNumbers.add(random.nextInt(prob.n)); + } + final List randomIndices = new ArrayList(randomNumbers); + Collections.sort(randomIndices); + + prob.x[i] = new double[prob.n]; + for (int j = 0; j < randomIndices.size(); j++) { + prob.x[i][randomIndices.get(j)] = random.nextDouble(); + } + } + return prob; + } + + /** + * create a very simple 
problem and check if the clearly separated examples + * are recognized as such + */ + @Test + public void testTrainPredict() { + final Problem prob = new Problem(); + prob.bias = -1; + prob.l = 4; + prob.n = 4; + prob.x = new double[4][4]; + + prob.x[0][0] = 1; + prob.x[0][1] = 1; + + prob.x[1][2] = 1; + prob.x[2][2] = 1; + + prob.x[3][0] = 2; + prob.x[3][1] = 1; + prob.x[3][3] = 1; + + prob.y = new double[4]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + + for (final SolverType solver : SolverType.values()) { + for (double C = 0.1; C <= 100.; C *= 1.2) { + // compared the behavior with the C version + if (C < 0.2) + if (solver == SolverType.L1R_L2LOSS_SVC) + continue; + if (C < 0.7) + if (solver == SolverType.L1R_LR) + continue; + + if (solver.isSupportVectorRegression()) { + continue; + } + + final Parameter param = new Parameter(solver, C, 0.1, 0.1); + final Model model = Linear.train(prob, param); + + final double[] featureWeights = model.getFeatureWeights(); + if (solver == SolverType.MCSVM_CS) { + assertThat(featureWeights.length).isEqualTo(8); + } else { + assertThat(featureWeights.length).isEqualTo(4); + } + + int i = 0; + for (final double value : prob.y) { + final double prediction = Linear.predict(model, prob.x[i]); + assertThat(prediction).as("prediction with solver " + solver).isEqualTo(value); + if (model.isProbabilityModel()) { + final double[] estimates = new double[model.getNrClass()]; + final double probabilityPrediction = Linear.predictProbability(model, prob.x[i], estimates); + assertThat(probabilityPrediction).isEqualTo(prediction); + assertThat(estimates[(int) probabilityPrediction]).isGreaterThanOrEqualTo( + 1.0 / model.getNrClass()); + double estimationSum = 0; + for (final double estimate : estimates) { + estimationSum += estimate; + } + assertThat(estimationSum).isEqualTo(1.0, Delta.delta(0.001)); + } + i++; + } + } + } + } + + @Test + public void testCrossValidation() throws Exception { + + final int numClasses = 
random.nextInt(10) + 1; + + final Problem prob = createRandomProblem(numClasses); + + final Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.01); + final int nr_fold = 10; + final double[] target = new double[prob.l]; + Linear.crossValidation(prob, param, nr_fold, target); + + for (final double clazz : target) { + assertThat(clazz).isGreaterThanOrEqualTo(0).isLessThan(numClasses); + } + } + + @Test + public void testLoadSaveModel() throws Exception { + + Model model = null; + for (final SolverType solverType : SolverType.values()) { + model = createRandomModel(); + model.solverType = solverType; + + final File tempFile = File.createTempFile("liblinear", "modeltest"); + tempFile.deleteOnExit(); + Linear.saveModel(tempFile, model); + + final Model loadedModel = Linear.loadModel(tempFile); + assertThat(loadedModel).isEqualTo(model); + } + } + + @Test + public void testPredictProbabilityWrongSolver() throws Exception { + final Problem prob = new Problem(); + prob.l = 1; + prob.n = 1; + prob.x = new double[prob.l][prob.n]; + prob.y = new double[prob.l]; + for (int i = 0; i < prob.l; i++) { + prob.y[i] = i; + } + + final SolverType solverType = SolverType.L2R_L1LOSS_SVC_DUAL; + final Parameter param = new Parameter(solverType, 10, 0.1); + final Model model = Linear.train(prob, param); + try { + Linear.predictProbability(model, prob.x[0], new double[1]); + fail("IllegalArgumentException expected"); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage()).isEqualTo("probability output is only supported for logistic regression." 
// + + " This is currently only supported by the following solvers:" // + + " L2R_LR, L1R_LR, L2R_LR_DUAL"); + } + } + + @Test + public void testRealloc() { + + int[] f = new int[] { 1, 2, 3 }; + f = Linear.copyOf(f, 5); + f[3] = 4; + f[4] = 5; + assertThat(f).isEqualTo(new int[] { 1, 2, 3, 4, 5 }); + } + + @Test + public void testAtoi() { + assertThat(Linear.atoi("+25")).isEqualTo(25); + assertThat(Linear.atoi("-345345")).isEqualTo(-345345); + assertThat(Linear.atoi("+0")).isEqualTo(0); + assertThat(Linear.atoi("0")).isEqualTo(0); + assertThat(Linear.atoi("2147483647")).isEqualTo(Integer.MAX_VALUE); + assertThat(Linear.atoi("-2147483648")).isEqualTo(Integer.MIN_VALUE); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData() { + Linear.atoi("+"); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData2() { + Linear.atoi("abc"); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData3() { + Linear.atoi(" "); + } + + @Test + public void testAtof() { + assertThat(Linear.atof("+25")).isEqualTo(25); + assertThat(Linear.atof("-25.12345678")).isEqualTo(-25.12345678); + assertThat(Linear.atof("0.345345299")).isEqualTo(0.345345299); + } + + @Test(expected = NumberFormatException.class) + public void testAtofInvalidData() { + Linear.atof("0.5t"); + } + + @Test + public void testSaveModelWithIOException() throws Exception { + final Model model = createRandomModel(); + + final Writer out = PowerMockito.mock(Writer.class); + + final IOException ioException = new IOException("some reason"); + + doThrow(ioException).when(out).flush(); + + try { + Linear.saveModel(out, model); + fail("IOException expected"); + } catch (final IOException e) { + assertThat(e).isEqualTo(ioException); + } + + verify(out).flush(); + verify(out, times(1)).close(); + } + + /** + * compared input/output values with the C version (1.51) + * + *
+	 * IN:
+	 * res prob.l = 4
+	 * res prob.n = 4
+	 * 0: (2,1) (4,1)
+	 * 1: (1,1)
+	 * 2: (3,1)
+	 * 3: (2,2) (3,1) (4,1)
+	 * 
+	 * TRANSPOSED:
+	 * 
+	 * res prob.l = 4
+	 * res prob.n = 4
+	 * 0: (2,1)
+	 * 1: (1,1) (4,2)
+	 * 2: (3,1) (4,1)
+	 * 3: (1,1) (4,1)
+	 * 
+ */ + @Test + public void testTranspose() throws Exception { + final Problem prob = new Problem(); + prob.bias = -1; + prob.l = 4; + prob.n = 4; + prob.x = new double[4][4]; + + prob.x[0][1] = 1; + prob.x[0][3] = 1; + + prob.x[1][0] = 1; + prob.x[2][2] = 1; + + prob.x[3][1] = 2; + prob.x[3][2] = 1; + prob.x[3][3] = 1; + + prob.y = new double[4]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + + final Problem transposed = Linear.transpose(prob); + + assertThat(transposed.x[0].length).isEqualTo(4); + assertThat(transposed.x[1].length).isEqualTo(4); + assertThat(transposed.x[2].length).isEqualTo(4); + assertThat(transposed.x[3].length).isEqualTo(4); + + assertThat(transposed.x[0][1]).isEqualTo(1); + + assertThat(transposed.x[1][0]).isEqualTo(1); + assertThat(transposed.x[1][3]).isEqualTo(2); + + assertThat(transposed.x[2][2]).isEqualTo(1); + assertThat(transposed.x[2][3]).isEqualTo(1); + + assertThat(transposed.x[3][0]).isEqualTo(1); + assertThat(transposed.x[3][3]).isEqualTo(1); + + assertThat(transposed.y).isEqualTo(prob.y); + } + + /** + * + * compared input/output values with the C version (1.51) + * + *
+	 * IN:
+	 * res prob.l = 5
+	 * res prob.n = 10
+	 * 0: (1,7) (3,3) (5,2)
+	 * 1: (2,1) (4,5) (5,3) (7,4) (8,2)
+	 * 2: (1,9) (3,1) (5,1) (10,7)
+	 * 3: (1,2) (2,2) (3,9) (4,7) (5,8) (6,1) (7,5) (8,4)
+	 * 4: (3,1) (10,3)
+	 * 
+	 * TRANSPOSED:
+	 * 
+	 * res prob.l = 5
+	 * res prob.n = 10
+	 * 0: (1,7) (3,9) (4,2)
+	 * 1: (2,1) (4,2)
+	 * 2: (1,3) (3,1) (4,9) (5,1)
+	 * 3: (2,5) (4,7)
+	 * 4: (1,2) (2,3) (3,1) (4,8)
+	 * 5: (4,1)
+	 * 6: (2,4) (4,5)
+	 * 7: (2,2) (4,4)
+	 * 8:
+	 * 9: (3,7) (5,3)
+	 * 
+ */ + @Test + public void testTranspose2() throws Exception { + final Problem prob = new Problem(); + prob.bias = -1; + prob.l = 5; + prob.n = 10; + prob.x = new double[5][10]; + + prob.x[0][0] = 7; + prob.x[0][2] = 3; + prob.x[0][4] = 2; + + prob.x[1][1] = 1; + prob.x[1][3] = 5; + prob.x[1][4] = 3; + prob.x[1][6] = 4; + prob.x[1][7] = 2; + + prob.x[2][0] = 9; + prob.x[2][2] = 1; + prob.x[2][4] = 1; + prob.x[2][9] = 7; + + prob.x[3][0] = 2; + prob.x[3][1] = 2; + prob.x[3][2] = 9; + prob.x[3][3] = 7; + prob.x[3][4] = 8; + prob.x[3][5] = 1; + prob.x[3][6] = 5; + prob.x[3][7] = 4; + + prob.x[4][2] = 1; + prob.x[4][9] = 3; + + prob.y = new double[5]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + prob.y[4] = 1; + + final Problem transposed = Linear.transpose(prob); + + assertThat(transposed.x[0]).hasSize(5); + assertThat(transposed.x[1]).hasSize(5); + assertThat(transposed.x[2]).hasSize(5); + assertThat(transposed.x[3]).hasSize(5); + assertThat(transposed.x[4]).hasSize(5); + assertThat(transposed.x[5]).hasSize(5); + assertThat(transposed.x[7]).hasSize(5); + assertThat(transposed.x[7]).hasSize(5); + assertThat(transposed.x[8]).hasSize(5); + assertThat(transposed.x[9]).hasSize(5); + + assertThat(transposed.x[0][0]).isEqualTo(7); + assertThat(transposed.x[0][2]).isEqualTo(9); + assertThat(transposed.x[0][3]).isEqualTo(2); + + assertThat(transposed.x[1][1]).isEqualTo(1); + assertThat(transposed.x[1][3]).isEqualTo(2); + + assertThat(transposed.x[2][0]).isEqualTo(3); + assertThat(transposed.x[2][2]).isEqualTo(1); + assertThat(transposed.x[2][3]).isEqualTo(9); + assertThat(transposed.x[2][4]).isEqualTo(1); + + assertThat(transposed.x[3][1]).isEqualTo(5); + assertThat(transposed.x[3][3]).isEqualTo(7); + + assertThat(transposed.x[4][0]).isEqualTo(2); + assertThat(transposed.x[4][1]).isEqualTo(3); + assertThat(transposed.x[4][2]).isEqualTo(1); + assertThat(transposed.x[4][3]).isEqualTo(8); + + assertThat(transposed.x[5][3]).isEqualTo(1); + + 
assertThat(transposed.x[6][1]).isEqualTo(4); + assertThat(transposed.x[6][3]).isEqualTo(5); + + assertThat(transposed.x[7][1]).isEqualTo(2); + assertThat(transposed.x[7][3]).isEqualTo(4); + + assertThat(transposed.x[9][2]).isEqualTo(7); + assertThat(transposed.x[9][4]).isEqualTo(3); + + assertThat(transposed.y).isEqualTo(prob.y); + } + + /** + * compared input/output values with the C version (1.51) + * + * IN: res prob.l = 3 res prob.n = 4 0: (1,2) (3,1) (4,3) 1: (1,9) (2,7) + * (3,3) (4,3) 2: (2,1) + * + * TRANSPOSED: + * + * res prob.l = 3 * res prob.n = 4 0: (1,2) (2,9) 1: (2,7) (3,1) 2: (1,1) + * (2,3) 3: (1,3) (2,3) + * + */ + @Test + public void testTranspose3() throws Exception { + + final Problem prob = new Problem(); + prob.l = 3; + prob.n = 4; + prob.y = new double[3]; + prob.x = new double[3][4]; + + prob.x[0][0] = 2; + prob.x[0][2] = 1; + prob.x[0][3] = 3; + prob.x[1][0] = 9; + prob.x[1][1] = 7; + prob.x[1][2] = 3; + prob.x[1][3] = 3; + + prob.x[2][1] = 1; + + final Problem transposed = Linear.transpose(prob); + assertThat(transposed.x).hasSize(4); + assertThat(transposed.x[0]).hasSize(3); + assertThat(transposed.x[1]).hasSize(3); + assertThat(transposed.x[2]).hasSize(3); + assertThat(transposed.x[3]).hasSize(3); + + assertThat(transposed.x[0][0]).isEqualTo(2); + assertThat(transposed.x[0][1]).isEqualTo(9); + + assertThat(transposed.x[1][1]).isEqualTo(7); + assertThat(transposed.x[1][2]).isEqualTo(1); + + assertThat(transposed.x[2][0]).isEqualTo(1); + assertThat(transposed.x[2][1]).isEqualTo(3); + + assertThat(transposed.x[3][0]).isEqualTo(3); + assertThat(transposed.x[3][1]).isEqualTo(3); + } +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java new file mode 100644 index 0000000..c0c55d6 --- /dev/null +++ b/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java @@ -0,0 +1,127 @@ +package de.bwaldvogel.denseliblinear; + +import static 
org.fest.assertions.Assertions.assertThat; +import static org.junit.Assert.fail; + +import org.junit.Before; +import org.junit.Test; + +import de.bwaldvogel.denseliblinear.Parameter; +import de.bwaldvogel.denseliblinear.SolverType; + + +public class ParameterTest { + + private Parameter _param; + + @Before + public void setUp() { + _param = new Parameter(SolverType.L2R_L1LOSS_SVC_DUAL, 100, 1e-3); + } + + @Test + public void testSetWeights() { + + assertThat(_param.weight).isNull(); + assertThat(_param.getNumWeights()).isEqualTo(0); + + double[] weights = new double[] {0, 1, 2, 3, 4, 5}; + int[] weightLabels = new int[] {1, 1, 1, 1, 2, 3}; + _param.setWeights(weights, weightLabels); + + assertThat(_param.getNumWeights()).isEqualTo(6); + + // assert parameter uses a copy + weights[0]++; + assertThat(_param.getWeights()[0]).isEqualTo(0); + weightLabels[0]++; + assertThat(_param.getWeightLabels()[0]).isEqualTo(1); + + weights = new double[] {0, 1, 2, 3, 4, 5}; + weightLabels = new int[] {1}; + try { + _param.setWeights(weights, weightLabels); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("same").contains("length"); + } + } + + @Test + public void testGetWeights() { + double[] weights = new double[] {0, 1, 2, 3, 4, 5}; + int[] weightLabels = new int[] {1, 1, 1, 1, 2, 3}; + _param.setWeights(weights, weightLabels); + + assertThat(_param.getWeights()).isEqualTo(weights); + _param.getWeights()[0]++; // shouldn't change the parameter as we should get a copy + assertThat(_param.getWeights()).isEqualTo(weights); + + assertThat(_param.getWeightLabels()).isEqualTo(weightLabels); + _param.getWeightLabels()[0]++; // shouldn't change the parameter as we should get a copy + assertThat(_param.getWeightLabels()[0]).isEqualTo(1); + } + + @Test + public void testSetC() { + _param.setC(0.0001); + assertThat(_param.getC()).isEqualTo(0.0001); + _param.setC(1); + _param.setC(100); + 
assertThat(_param.getC()).isEqualTo(100); + _param.setC(Double.MAX_VALUE); + + try { + _param.setC(-1); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); + } + + try { + _param.setC(0); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); + } + } + + @Test + public void testSetEps() { + _param.setEps(0.0001); + assertThat(_param.getEps()).isEqualTo(0.0001); + _param.setEps(1); + _param.setEps(100); + assertThat(_param.getEps()).isEqualTo(100); + _param.setEps(Double.MAX_VALUE); + + try { + _param.setEps(-1); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); + } + + try { + _param.setEps(0); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); + } + } + + @Test + public void testSetSolverType() { + for (SolverType type : SolverType.values()) { + _param.setSolverType(type); + assertThat(_param.getSolverType()).isEqualTo(type); + } + try { + _param.setSolverType(null); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("must").contains("not").contains("null"); + } + } + +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java b/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java new file mode 100644 index 0000000..7c411fe --- /dev/null +++ b/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java @@ -0,0 +1,57 @@ +package de.bwaldvogel.denseliblinear; + +import static org.fest.assertions.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import java.io.BufferedReader; +import java.io.PrintStream; 
+import java.io.StringReader; +import java.io.StringWriter; +import java.io.Writer; + +import org.junit.Before; +import org.junit.Test; + +import de.bwaldvogel.denseliblinear.Model; +import de.bwaldvogel.denseliblinear.Predict; + + +public class PredictTest { + + private Model testModel = LinearTest.createRandomModel(); + private StringBuilder sb = new StringBuilder(); + private Writer writer = new StringWriter(); + + @Before + public void setUp() { + System.setOut(mock(PrintStream.class)); // dev/null + assertThat(testModel.getNrClass()).isGreaterThanOrEqualTo(2); + assertThat(testModel.getNrFeature()).isGreaterThanOrEqualTo(10); + } + + private void testWithLines(StringBuilder sb) throws Exception { + BufferedReader reader = new BufferedReader(new StringReader(sb.toString())); + + Predict.doPredict(reader, writer, testModel); + } + + @Test(expected = RuntimeException.class) + public void testDoPredictCorruptLine() throws Exception { + sb.append(testModel.label[0]).append(" abc").append("\n"); + testWithLines(sb); + } + + @Test(expected = RuntimeException.class) + public void testDoPredictCorruptLine2() throws Exception { + sb.append(testModel.label[0]).append(" 1:").append("\n"); + testWithLines(sb); + } + + @Test + public void testDoPredict() throws Exception { + sb.append(testModel.label[0]).append(" 1:0.32393").append("\n"); + sb.append(testModel.label[1]).append(" 2:-71.555 9:88223").append("\n"); + testWithLines(sb); + assertThat(writer.toString()).isNotEmpty(); + } +} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java b/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java new file mode 100644 index 0000000..75558c9 --- /dev/null +++ b/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java @@ -0,0 +1,210 @@ +package de.bwaldvogel.denseliblinear; + +import static org.fest.assertions.Assertions.assertThat; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.util.ArrayList; +import 
java.util.Collection;

import org.junit.Test;

/**
 * Tests for the command-line parsing and problem-file reading of {@link Train}.
 */
public class TrainTest {

    /**
     * Writes the given lines, each terminated by a newline, to a fresh temporary
     * file that is scheduled for deletion on JVM exit.
     *
     * @param lines the lines of the problem file
     * @return the file that was written
     */
    private static File writeTempFile(Collection<String> lines) throws Exception {
        final File file = File.createTempFile("svm", "test");
        file.deleteOnExit();

        final BufferedWriter writer = new BufferedWriter(new FileWriter(file));
        try {
            for (final String line : lines)
                writer.append(line).append("\n");
        } finally {
            writer.close();
        }
        return file;
    }

    @Test
    public void testParseCommandLine() {
        final Train train = new Train();

        for (final SolverType solver : SolverType.values()) {
            train.parse_command_line(new String[] { "-B", "5.3", "-s", "" + solver.getId(), "-p", "0.01", "model-filename" });
            final Parameter param = train.getParameter();
            assertThat(param.solverType).isEqualTo(solver);
            // check default eps
            if (solver.getId() == 0 || solver.getId() == 2 //
                || solver.getId() == 5 || solver.getId() == 6)
            {
                assertThat(param.eps).isEqualTo(0.01);
            } else if (solver.getId() == 7) {
                assertThat(param.eps).isEqualTo(0.1);
            } else if (solver.getId() == 11) {
                assertThat(param.eps).isEqualTo(0.001);
            } else {
                assertThat(param.eps).isEqualTo(0.1);
            }
            // check if bias is set
            assertThat(train.getBias()).isEqualTo(5.3);
            assertThat(param.p).isEqualTo(0.01);
        }
    }

    /**
     * Regression test for https://github.com/bwaldvogel/liblinear-java/issues/4
     */
    @Test
    public void testParseWeights() throws Exception {
        final Train train = new Train();
        train.parse_command_line(new String[] { "-v", "10", "-c", "10", "-w1", "1.234", "model-filename" });
        Parameter parameter = train.getParameter();
        assertThat(parameter.weightLabel).isEqualTo(new int[] { 1 });
        assertThat(parameter.weight).isEqualTo(new double[] { 1.234 });

        train.parse_command_line(new String[] { "-w1", "1.234", "-w2", "0.12", "-w3", "7", "model-filename" });
        parameter = train.getParameter();
        assertThat(parameter.weightLabel).isEqualTo(new int[] { 1, 2, 3 });
        assertThat(parameter.weight).isEqualTo(new double[] { 1.234, 0.12, 7 });
    }

    @Test
    public void testReadProblem() throws Exception {
        final Collection<String> lines = new ArrayList<String>();
        lines.add("1 1:1 3:1 4:1 6:1");
        lines.add("2 2:1 3:1 5:1 7:1");
        lines.add("1 3:1 5:1");
        lines.add("1 1:1 4:1 7:1");
        lines.add("2 4:1 5:1 7:1");
        final File file = writeTempFile(lines);

        final Train train = new Train();
        train.readProblem(file.getAbsolutePath());

        final Problem prob = train.getProblem();
        assertThat(prob.bias).isEqualTo(1);
        assertThat(prob.y).hasSize(lines.size());
        assertThat(prob.y).isEqualTo(new double[] { 1, 2, 1, 1, 2 });
        assertThat(prob.n).isEqualTo(8);
        assertThat(prob.l).isEqualTo(prob.y.length);
        assertThat(prob.x).hasSize(prob.y.length);

        for (final double[] nodes : prob.x) {

            assertThat(nodes.length).isLessThanOrEqualTo(prob.n);
            for (int ind = 0; ind < prob.n; ind++) {
                // the last dense column is expected to carry the bias term
                if (prob.bias >= 0 && ind == prob.n - 1) {
                    assertThat(nodes[ind]).isEqualTo(prob.bias);
                } else {
                    assertThat(ind).isLessThan(prob.n);
                }
            }
        }
    }

    /**
     * unit-test for Issue #1
     * (http://github.com/bwaldvogel/liblinear-java/issues#issue/1)
     */
    @Test
    public void testReadProblemEmptyLine() throws Exception {
        final Collection<String> lines = new ArrayList<String>();
        lines.add("1 1:1 3:1 4:1 6:1");
        lines.add("2 ");
        final File file = writeTempFile(lines);

        final Problem prob = Train.readProblem(file, -1.0);
        assertThat(prob.bias).isEqualTo(-1);
        assertThat(prob.y).hasSize(lines.size());
        assertThat(prob.y).isEqualTo(new double[] { 1, 2 });
        assertThat(prob.n).isEqualTo(6);
        assertThat(prob.l).isEqualTo(prob.y.length);
        assertThat(prob.x).hasSize(prob.y.length);

        assertThat(prob.x[0]).hasSize(6);
        assertThat(prob.x[1]).hasSize(6);
    }

    @Test(expected = InvalidInputDataException.class)
    public void testReadUnsortedProblem() throws Exception {
        final Collection<String> lines = new ArrayList<String>();
        lines.add("1 1:1 3:1 4:1 6:1");
        lines.add("2 2:1 3:1 5:1 7:1");
        lines.add("1 3:1 5:1 4:1"); // here's the mistake: not correctly sorted
        final File file = writeTempFile(lines);

        final Train train = new Train();
        train.readProblem(file.getAbsolutePath());
    }

    @Test(expected = InvalidInputDataException.class)
    public void testReadProblemWithInvalidIndex() throws Exception {
        final Collection<String> lines = new ArrayList<String>();
        lines.add("1 1:1 3:1 4:1 6:1");
        lines.add("2 2:1 3:1 5:1 -4:1");
        final File file = writeTempFile(lines);

        final Train train = new Train();
        train.readProblem(file.getAbsolutePath());
    }

    @Test(expected = InvalidInputDataException.class)
    public void testReadWrongProblem() throws Exception {
        final Collection<String> lines = new ArrayList<String>();
        lines.add("1 1:1 3:1 4:1 6:1");
        lines.add("2 2:1 3:1 5:1 7:1");
        lines.add("1 3:1 5:a"); // here's the mistake: incomplete line
        final File file = writeTempFile(lines);

        final Train train = new Train();
        train.readProblem(file.getAbsolutePath());
    }
}
From 76fa4e436cd641c01d1ca7846494dd424cab659f Mon Sep 17
00:00:00 2001 From: Jonathon Hare Date: Wed, 17 Jul 2013 16:19:46 +0100 Subject: [PATCH 2/4] refactoring to reduce code duplication --- .../liblinear/DenseL2R_L2_SvcFunction.java | 117 + .../liblinear/DenseL2R_L2_SvrFunction.java | 67 + .../liblinear/DenseL2R_LrFunction.java | 108 + .../de/bwaldvogel/liblinear/DenseLinear.java | 1912 +++++++++++++++++ .../de/bwaldvogel/liblinear/DensePredict.java | 194 ++ .../de/bwaldvogel/liblinear/DenseProblem.java | 62 + .../liblinear/DenseSolverMCSVM_CS.java | 293 +++ .../de/bwaldvogel/liblinear/DenseTrain.java | 420 ++++ .../bwaldvogel/liblinear/DenseLinearTest.java | 517 +++++ .../bwaldvogel/liblinear/DenseTrainTest.java | 213 ++ 10 files changed, 3903 insertions(+) create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvcFunction.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvrFunction.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseL2R_LrFunction.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseLinear.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DensePredict.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseProblem.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseSolverMCSVM_CS.java create mode 100644 src/main/java/de/bwaldvogel/liblinear/DenseTrain.java create mode 100644 src/test/java/de/bwaldvogel/liblinear/DenseLinearTest.java create mode 100644 src/test/java/de/bwaldvogel/liblinear/DenseTrainTest.java diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvcFunction.java b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvcFunction.java new file mode 100644 index 0000000..8f00f82 --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvcFunction.java @@ -0,0 +1,117 @@ +package de.bwaldvogel.liblinear; + +class DenseL2R_L2_SvcFunction implements Function { + + protected final DenseProblem prob; + protected final double[] C; + protected final int[] I; + protected 
final double[] z; + + protected int sizeI; + + public DenseL2R_L2_SvcFunction(DenseProblem prob, double[] C) { + final int l = prob.l; + + this.prob = prob; + + z = new double[l]; + I = new int[l]; + this.C = C; + } + + @Override + public double fun(double[] w) { + int i; + double f = 0; + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + Xv(w, z); + + for (i = 0; i < w_size; i++) + f += w[i] * w[i]; + f /= 2.0; + for (i = 0; i < l; i++) { + z[i] = y[i] * z[i]; + final double d = 1 - z[i]; + if (d > 0) + f += C[i] * d * d; + } + + return (f); + } + + @Override + public int get_nr_variable() { + return prob.n; + } + + @Override + public void grad(double[] w, double[] g) { + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + sizeI = 0; + for (int i = 0; i < l; i++) { + if (z[i] < 1) { + z[sizeI] = C[i] * y[i] * (z[i] - 1); + I[sizeI] = i; + sizeI++; + } + } + subXTv(z, g); + + for (int i = 0; i < w_size; i++) + g[i] = w[i] + 2 * g[i]; + } + + @Override + public void Hv(double[] s, double[] Hs) { + int i; + final int w_size = get_nr_variable(); + final double[] wa = new double[sizeI]; + + subXv(s, wa); + for (i = 0; i < sizeI; i++) + wa[i] = C[I[i]] * wa[i]; + + subXTv(wa, Hs); + for (i = 0; i < w_size; i++) + Hs[i] = s[i] + 2 * Hs[i]; + } + + protected void subXTv(double[] v, double[] XTv) { + int i; + final int w_size = get_nr_variable(); + + for (i = 0; i < w_size; i++) + XTv[i] = 0; + + for (i = 0; i < sizeI; i++) { + for (int j = 0; j < prob.x[I[i]].length; j++) { + XTv[j] += v[i] * prob.x[I[i]][j]; + } + } + } + + private void subXv(double[] v, double[] Xv) { + for (int i = 0; i < sizeI; i++) { + Xv[i] = 0; + + for (int j = 0; j < prob.x[I[i]].length; j++) { + Xv[i] += v[j] * prob.x[I[i]][j]; + } + } + } + + protected void Xv(double[] v, double[] Xv) { + for (int i = 0; i < prob.l; i++) { + Xv[i] = 0; + for (int j = 0; j < prob.x[i].length; j++) { + Xv[i] += v[j] * 
prob.x[i][j]; + } + } + } +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvrFunction.java b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvrFunction.java new file mode 100644 index 0000000..5ad6ce1 --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_L2_SvrFunction.java @@ -0,0 +1,67 @@ +package de.bwaldvogel.liblinear; + +/** + * @since 1.91 + */ +public class DenseL2R_L2_SvrFunction extends DenseL2R_L2_SvcFunction { + + private double p; + + public DenseL2R_L2_SvrFunction( DenseProblem prob, double[] C, double p ) { + super(prob, C); + this.p = p; + } + + @Override + public double fun(double[] w) { + double f = 0; + double[] y = prob.y; + int l = prob.l; + int w_size = get_nr_variable(); + double d; + + Xv(w, z); + + for (int i = 0; i < w_size; i++) + f += w[i] * w[i]; + f /= 2; + for (int i = 0; i < l; i++) { + d = z[i] - y[i]; + if (d < -p) + f += C[i] * (d + p) * (d + p); + else if (d > p) f += C[i] * (d - p) * (d - p); + } + + return f; + } + + @Override + public void grad(double[] w, double[] g) { + double[] y = prob.y; + int l = prob.l; + int w_size = get_nr_variable(); + + sizeI = 0; + for (int i = 0; i < l; i++) { + double d = z[i] - y[i]; + + // generate index set I + if (d < -p) { + z[sizeI] = C[i] * (d + p); + I[sizeI] = i; + sizeI++; + } else if (d > p) { + z[sizeI] = C[i] * (d - p); + I[sizeI] = i; + sizeI++; + } + + } + subXTv(z, g); + + for (int i = 0; i < w_size; i++) + g[i] = w[i] + 2 * g[i]; + + } + +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseL2R_LrFunction.java b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_LrFunction.java new file mode 100644 index 0000000..855d810 --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseL2R_LrFunction.java @@ -0,0 +1,108 @@ +package de.bwaldvogel.liblinear; + +class DenseL2R_LrFunction implements Function { + + private final double[] C; + private final double[] z; + private final double[] D; + private final DenseProblem prob; + + public 
DenseL2R_LrFunction(DenseProblem prob, double[] C) { + final int l = prob.l; + + this.prob = prob; + + z = new double[l]; + D = new double[l]; + this.C = C; + } + + private void Xv(double[] v, double[] Xv) { + for (int i = 0; i < prob.l; i++) { + Xv[i] = 0; + for (int j = 0; j < prob.x[i].length; j++) { + Xv[i] += v[j] * prob.x[i][j]; + } + } + } + + private void XTv(double[] v, double[] XTv) { + final int l = prob.l; + final int w_size = get_nr_variable(); + final double[][] x = prob.x; + + for (int i = 0; i < w_size; i++) + XTv[i] = 0; + + for (int i = 0; i < l; i++) { + for (int j = 0; j < prob.x[i].length; j++) { + XTv[j] += v[i] * x[i][j]; + } + } + } + + @Override + public double fun(double[] w) { + int i; + double f = 0; + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + Xv(w, z); + + for (i = 0; i < w_size; i++) + f += w[i] * w[i]; + f /= 2.0; + for (i = 0; i < l; i++) { + final double yz = y[i] * z[i]; + if (yz >= 0) + f += C[i] * Math.log(1 + Math.exp(-yz)); + else + f += C[i] * (-yz + Math.log(1 + Math.exp(yz))); + } + + return (f); + } + + @Override + public void grad(double[] w, double[] g) { + int i; + final double[] y = prob.y; + final int l = prob.l; + final int w_size = get_nr_variable(); + + for (i = 0; i < l; i++) { + z[i] = 1 / (1 + Math.exp(-y[i] * z[i])); + D[i] = z[i] * (1 - z[i]); + z[i] = C[i] * (z[i] - 1) * y[i]; + } + XTv(z, g); + + for (i = 0; i < w_size; i++) + g[i] = w[i] + g[i]; + } + + @Override + public void Hv(double[] s, double[] Hs) { + int i; + final int l = prob.l; + final int w_size = get_nr_variable(); + final double[] wa = new double[l]; + + Xv(s, wa); + for (i = 0; i < l; i++) + wa[i] = C[i] * D[i] * wa[i]; + + XTv(wa, Hs); + for (i = 0; i < w_size; i++) + Hs[i] = s[i] + Hs[i]; + // delete[] wa; + } + + @Override + public int get_nr_variable() { + return prob.n; + } + +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseLinear.java 
b/src/main/java/de/bwaldvogel/liblinear/DenseLinear.java new file mode 100644 index 0000000..f13c5da --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseLinear.java @@ -0,0 +1,1912 @@ +package de.bwaldvogel.liblinear; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.Closeable; +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintStream; +import java.io.Reader; +import java.io.Writer; +import java.nio.charset.Charset; +import java.util.Formatter; +import java.util.Locale; +import java.util.Random; +import java.util.regex.Pattern; + +/** + *

<h2>Java port of liblinear</h2>
+ *
+ * <p>
+ * The usage should be pretty similar to the C version of liblinear.
+ * </p>
+ * <p>
+ * Please consider reading the README file of liblinear.
+ * </p>
+ *
+ * <p>
+ * The port was done by Benedikt Waldvogel (mail at bwaldvogel.de)
+ * </p>

+ * + * @version 1.92 + */ +public class DenseLinear { + + static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); + + static final Locale DEFAULT_LOCALE = Locale.ENGLISH; + + private static Object OUTPUT_MUTEX = new Object(); + private static PrintStream DEBUG_OUTPUT = System.out; + + private static final long DEFAULT_RANDOM_SEED = 0L; + static Random random = new Random(DEFAULT_RANDOM_SEED); + + /** + * @param target + * predicted classes + */ + public static void crossValidation(DenseProblem prob, Parameter param, int nr_fold, double[] target) { + int i; + final int[] fold_start = new int[nr_fold + 1]; + final int l = prob.l; + final int[] perm = new int[l]; + + for (i = 0; i < l; i++) + perm[i] = i; + for (i = 0; i < l; i++) { + final int j = i + random.nextInt(l - i); + swap(perm, i, j); + } + for (i = 0; i <= nr_fold; i++) + fold_start[i] = i * l / nr_fold; + + for (i = 0; i < nr_fold; i++) { + final int begin = fold_start[i]; + final int end = fold_start[i + 1]; + int j, k; + final DenseProblem subprob = new DenseProblem(); + + subprob.bias = prob.bias; + subprob.n = prob.n; + subprob.l = l - (end - begin); + subprob.x = new double[subprob.l][]; + subprob.y = new double[subprob.l]; + + k = 0; + for (j = 0; j < begin; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + for (j = end; j < l; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + final Model submodel = train(subprob, param); + for (j = begin; j < end; j++) + target[perm[j]] = predict(submodel, prob.x[perm[j]]); + } + } + + /** used as complex return type */ + private static class GroupClassesReturn { + + final int[] count; + final int[] label; + final int nr_class; + final int[] start; + + GroupClassesReturn(int nr_class, int[] label, int[] start, int[] count) { + this.nr_class = nr_class; + this.label = label; + this.start = start; + this.count = count; + } + } + + private static GroupClassesReturn 
groupClasses(DenseProblem prob, int[] perm) { + final int l = prob.l; + int max_nr_class = 16; + int nr_class = 0; + + int[] label = new int[max_nr_class]; + int[] count = new int[max_nr_class]; + final int[] data_label = new int[l]; + int i; + + for (i = 0; i < l; i++) { + final int this_label = (int) prob.y[i]; + int j; + for (j = 0; j < nr_class; j++) { + if (this_label == label[j]) { + ++count[j]; + break; + } + } + data_label[i] = j; + if (j == nr_class) { + if (nr_class == max_nr_class) { + max_nr_class *= 2; + label = copyOf(label, max_nr_class); + count = copyOf(count, max_nr_class); + } + label[nr_class] = this_label; + count[nr_class] = 1; + ++nr_class; + } + } + + final int[] start = new int[nr_class]; + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + for (i = 0; i < l; i++) { + perm[start[data_label[i]]] = i; + ++start[data_label[i]]; + } + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + + return new GroupClassesReturn(nr_class, label, start, count); + } + + static void info(String message) { + synchronized (OUTPUT_MUTEX) { + if (DEBUG_OUTPUT == null) + return; + DEBUG_OUTPUT.printf(message); + DEBUG_OUTPUT.flush(); + } + } + + static void info(String format, Object... 
args) { + synchronized (OUTPUT_MUTEX) { + if (DEBUG_OUTPUT == null) + return; + DEBUG_OUTPUT.printf(format, args); + DEBUG_OUTPUT.flush(); + } + } + + /** + * @param s + * the string to parse for the double value + * @throws IllegalArgumentException + * if s is empty or represents NaN or Infinity + * @throws NumberFormatException + * see {@link Double#parseDouble(String)} + */ + static double atof(String s) { + if (s == null || s.length() < 1) + throw new IllegalArgumentException("Can't convert empty string to integer"); + final double d = Double.parseDouble(s); + if (Double.isNaN(d) || Double.isInfinite(d)) { + throw new IllegalArgumentException("NaN or Infinity in input: " + s); + } + return (d); + } + + /** + * @param s + * the string to parse for the integer value + * @throws IllegalArgumentException + * if s is empty + * @throws NumberFormatException + * see {@link Integer#parseInt(String)} + */ + static int atoi(String s) throws NumberFormatException { + if (s == null || s.length() < 1) + throw new IllegalArgumentException("Can't convert empty string to integer"); + // Integer.parseInt doesn't accept '+' prefixed strings + if (s.charAt(0) == '+') + s = s.substring(1); + return Integer.parseInt(s); + } + + /** + * Java5 'backport' of Arrays.copyOf + */ + public static double[] copyOf(double[] original, int newLength) { + final double[] copy = new double[newLength]; + System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); + return copy; + } + + /** + * Java5 'backport' of Arrays.copyOf + */ + public static int[] copyOf(int[] original, int newLength) { + final int[] copy = new int[newLength]; + System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); + return copy; + } + + /** + * Loads the model from inputReader. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + * + *

<p>
+ * <b>Note:</b> The inputReader is NOT closed after reading or in case of an
+ * exception.
+ * </p>

+ */ + public static Model loadModel(Reader inputReader) throws IOException { + final Model model = new Model(); + + model.label = null; + + final Pattern whitespace = Pattern.compile("\\s+"); + + BufferedReader reader = null; + if (inputReader instanceof BufferedReader) { + reader = (BufferedReader) inputReader; + } else { + reader = new BufferedReader(inputReader); + } + + String line = null; + while ((line = reader.readLine()) != null) { + final String[] split = whitespace.split(line); + if (split[0].equals("solver_type")) { + final SolverType solver = SolverType.valueOf(split[1]); + if (solver == null) { + throw new RuntimeException("unknown solver type"); + } + model.solverType = solver; + } else if (split[0].equals("nr_class")) { + model.nr_class = atoi(split[1]); + Integer.parseInt(split[1]); + } else if (split[0].equals("nr_feature")) { + model.nr_feature = atoi(split[1]); + } else if (split[0].equals("bias")) { + model.bias = atof(split[1]); + } else if (split[0].equals("w")) { + break; + } else if (split[0].equals("label")) { + model.label = new int[model.nr_class]; + for (int i = 0; i < model.nr_class; i++) { + model.label[i] = atoi(split[i + 1]); + } + } else { + throw new RuntimeException("unknown text in model file: [" + line + "]"); + } + } + + int w_size = model.nr_feature; + if (model.bias >= 0) + w_size++; + + int nr_w = model.nr_class; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + + model.w = new double[w_size * nr_w]; + final int[] buffer = new int[128]; + + for (int i = 0; i < w_size; i++) { + for (int j = 0; j < nr_w; j++) { + int b = 0; + while (true) { + final int ch = reader.read(); + if (ch == -1) { + throw new EOFException("unexpected EOF"); + } + if (ch == ' ') { + model.w[i * nr_w + j] = atof(new String(buffer, 0, b)); + break; + } else { + buffer[b++] = ch; + } + } + } + } + + return model; + } + + /** + * Loads the model from the file with ISO-8859-1 charset. 
It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + */ + public static Model loadModel(File modelFile) throws IOException { + final BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), + FILE_CHARSET)); + try { + return loadModel(inputReader); + } finally { + inputReader.close(); + } + } + + static void closeQuietly(Closeable c) { + if (c == null) + return; + try { + c.close(); + } catch (final Throwable t) { + } + } + + public static double predict(Model model, double[] x) { + final double[] dec_values = new double[model.nr_class]; + return predictValues(model, x, dec_values); + } + + /** + * @throws IllegalArgumentException + * if model is not probabilistic (see + * {@link Model#isProbabilityModel()}) + */ + public static double predictProbability(Model model, double[] x, double[] prob_estimates) + throws IllegalArgumentException + { + if (!model.isProbabilityModel()) { + final StringBuilder sb = new StringBuilder("probability output is only supported for logistic regression"); + sb.append(". This is currently only supported by the following solvers: "); + int i = 0; + for (final SolverType solverType : SolverType.values()) { + if (solverType.isLogisticRegressionSolver()) { + if (i++ > 0) { + sb.append(", "); + } + sb.append(solverType.name()); + } + } + throw new IllegalArgumentException(sb.toString()); + } + final int nr_class = model.nr_class; + int nr_w; + if (nr_class == 2) + nr_w = 1; + else + nr_w = nr_class; + + final double label = predictValues(model, x, prob_estimates); + for (int i = 0; i < nr_w; i++) + prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i])); + + if (nr_class == 2) // for binary classification + prob_estimates[1] = 1. 
- prob_estimates[0]; + else { + double sum = 0; + for (int i = 0; i < nr_class; i++) + sum += prob_estimates[i]; + + for (int i = 0; i < nr_class; i++) + prob_estimates[i] = prob_estimates[i] / sum; + } + + return label; + } + + public static double predictValues(Model model, double[] x, double[] dec_values) { + int n; + if (model.bias >= 0) + n = model.nr_feature + 1; + else + n = model.nr_feature; + + final double[] w = model.w; + + int nr_w; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + else + nr_w = model.nr_class; + + for (int i = 0; i < nr_w; i++) + dec_values[i] = 0; + + for (int idx = 0; idx < n; idx++) { + // the dimension of testing data may exceed that of training + for (int i = 0; i < nr_w; i++) { + dec_values[i] += w[idx * nr_w + i] * x[idx]; + } + } + + if (model.nr_class == 2) { + if (model.solverType.isSupportVectorRegression()) + return dec_values[0]; + else + return (dec_values[0] > 0) ? model.label[0] : model.label[1]; + } else { + int dec_max_idx = 0; + for (int i = 1; i < model.nr_class; i++) { + if (dec_values[i] > dec_values[dec_max_idx]) + dec_max_idx = i; + } + return model.label[dec_max_idx]; + } + } + + static void printf(Formatter formatter, String format, Object... args) throws IOException { + formatter.format(format, args); + final IOException ioException = formatter.ioException(); + if (ioException != null) + throw ioException; + } + + /** + * Writes the model to the modelOutput. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. + * + *

<p>
+ * <b>Note:</b> The modelOutput is closed after writing or in case of an
+ * exception.
+ * </p>

+ */ + public static void saveModel(Writer modelOutput, Model model) throws IOException { + final int nr_feature = model.nr_feature; + int w_size = nr_feature; + if (model.bias >= 0) + w_size++; + + int nr_w = model.nr_class; + if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) + nr_w = 1; + + final Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE); + try { + printf(formatter, "solver_type %s\n", model.solverType.name()); + printf(formatter, "nr_class %d\n", model.nr_class); + + if (model.label != null) { + printf(formatter, "label"); + for (int i = 0; i < model.nr_class; i++) { + printf(formatter, " %d", model.label[i]); + } + printf(formatter, "\n"); + } + + printf(formatter, "nr_feature %d\n", nr_feature); + printf(formatter, "bias %.16g\n", model.bias); + + printf(formatter, "w\n"); + for (int i = 0; i < w_size; i++) { + for (int j = 0; j < nr_w; j++) { + final double value = model.w[i * nr_w + j]; + + /** + * this optimization is the reason for + * {@link Model#equals(double[], double[])} + */ + if (value == 0.0) { + printf(formatter, "%d ", 0); + } else { + printf(formatter, "%.16g ", value); + } + } + printf(formatter, "\n"); + } + + formatter.flush(); + final IOException ioException = formatter.ioException(); + if (ioException != null) + throw ioException; + } finally { + formatter.close(); + } + } + + /** + * Writes the model to the file with ISO-8859-1 charset. It uses + * {@link java.util.Locale#ENGLISH} for number formatting. 
+ */ + public static void saveModel(File modelFile, Model model) throws IOException { + final BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), + FILE_CHARSET)); + saveModel(modelOutput, model); + } + + /* + * this method corresponds to the following define in the C version: #define + * GETI(i) (y[i]+1) + */ + private static int GETI(byte[] y, int i) { + return y[i] + 1; + } + + /** + * A coordinate descent algorithm for L1-loss and L2-loss SVM dual problems + * + *
+	 *  min_\alpha  0.5(\alpha^T (Q + D)\alpha) - e^T \alpha,
+	 *    s.t.      0 <= \alpha_i <= upper_bound_i,
+	 * 
+	 *  where Qij = yi yj xi^T xj and
+	 *  D is a diagonal matrix
+	 * 
+	 * In L1-SVM case:
+	 *     upper_bound_i = Cp if y_i = 1
+	 *      upper_bound_i = Cn if y_i = -1
+	 *      D_ii = 0
+	 * In L2-SVM case:
+	 *      upper_bound_i = INF
+	 *      D_ii = 1/(2*Cp) if y_i = 1
+	 *      D_ii = 1/(2*Cn) if y_i = -1
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Algorithm 3 of Hsieh et al., ICML 2008
+	 * 
+ */ + private static void solve_l2r_l1l2_svc(DenseProblem prob, double[] w, double eps, double Cp, double Cn, + SolverType solver_type) + { + final int l = prob.l; + final int w_size = prob.n; + int i, s, iter = 0; + double C, d, G; + final double[] QD = new double[l]; + final int max_iter = 1000; + final int[] index = new int[l]; + final double[] alpha = new double[l]; + final byte[] y = new byte[l]; + int active_size = l; + + // PG: projected gradient, for shrinking and stopping + double PG; + double PGmax_old = Double.POSITIVE_INFINITY; + double PGmin_old = Double.NEGATIVE_INFINITY; + double PGmax_new, PGmin_new; + + // default solver_type: L2R_L2LOSS_SVC_DUAL + final double diag[] = new double[] { 0.5 / Cn, 0, 0.5 / Cp }; + final double upper_bound[] = new double[] { Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY }; + if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) { + diag[0] = 0; + diag[2] = 0; + upper_bound[0] = Cn; + upper_bound[2] = Cp; + } + + for (i = 0; i < l; i++) { + if (prob.y[i] > 0) { + y[i] = +1; + } else { + y[i] = -1; + } + } + + // Initial alpha can be set here. 
Note that + // 0 <= alpha[i] <= upper_bound[GETI(i)] + for (i = 0; i < l; i++) + alpha[i] = 0; + + for (i = 0; i < w_size; i++) + w[i] = 0; + for (i = 0; i < l; i++) { + QD[i] = diag[GETI(y, i)]; + + for (int j = 0; j < w_size; j++) { + final double val = prob.x[i][j]; + QD[i] += val * val; + w[j] += y[i] * alpha[i] * val; + } + index[i] = i; + } + + while (iter < max_iter) { + PGmax_new = Double.NEGATIVE_INFINITY; + PGmin_new = Double.POSITIVE_INFINITY; + + for (i = 0; i < active_size; i++) { + final int j = i + random.nextInt(active_size - i); + swap(index, i, j); + } + + for (s = 0; s < active_size; s++) { + i = index[s]; + G = 0; + final byte yi = y[i]; + + for (int j = 0; j < w_size; j++) { + G += w[j] * prob.x[i][j]; + } + G = G * yi - 1; + + C = upper_bound[GETI(y, i)]; + G += alpha[i] * diag[GETI(y, i)]; + + PG = 0; + if (alpha[i] == 0) { + if (G > PGmax_old) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } else if (G < 0) { + PG = G; + } + } else if (alpha[i] == C) { + if (G < PGmin_old) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } else if (G > 0) { + PG = G; + } + } else { + PG = G; + } + + PGmax_new = Math.max(PGmax_new, PG); + PGmin_new = Math.min(PGmin_new, PG); + + if (Math.abs(PG) > 1.0e-12) { + final double alpha_old = alpha[i]; + alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C); + d = (alpha[i] - alpha_old) * yi; + + for (int j = 0; j < w_size; j++) { + w[j] += d * prob.x[i][j]; + } + } + } + + iter++; + if (iter % 10 == 0) + info("."); + + if (PGmax_new - PGmin_new <= eps) { + if (active_size == l) + break; + else { + active_size = l; + info("*"); + PGmax_old = Double.POSITIVE_INFINITY; + PGmin_old = Double.NEGATIVE_INFINITY; + continue; + } + } + PGmax_old = PGmax_new; + PGmin_old = PGmin_new; + if (PGmax_old <= 0) + PGmax_old = Double.POSITIVE_INFINITY; + if (PGmin_old >= 0) + PGmin_old = Double.NEGATIVE_INFINITY; + } + + info("%noptimization finished, #iter = %d%n", iter); + if 
(iter >= max_iter) + info("%nWARNING: reaching max number of iterations%nUsing -s 2 may be faster (also see FAQ)%n%n"); + + // calculate objective value + + double v = 0; + int nSV = 0; + for (i = 0; i < w_size; i++) + v += w[i] * w[i]; + for (i = 0; i < l; i++) { + v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2); + if (alpha[i] > 0) + ++nSV; + } + info("Objective value = %g%n", v / 2); + info("nSV = %d%n", nSV); + } + + // To support weights for instances, use GETI(i) (i) + private static int GETI_SVR(int i) { + return 0; + } + + /** + * A coordinate descent algorithm for L1-loss and L2-loss epsilon-SVR dual + * problem + * + * min_\beta 0.5\beta^T (Q + diag(lambda)) \beta - p \sum_{i=1}^l|\beta_i| + + * \sum_{i=1}^l yi\beta_i, s.t. -upper_bound_i <= \beta_i <= upper_bound_i, + * + * where Qij = xi^T xj and D is a diagonal matrix + * + * In L1-SVM case: upper_bound_i = C lambda_i = 0 In L2-SVM case: + * upper_bound_i = INF lambda_i = 1/(2*C) + * + * Given: x, y, p, C eps is the stopping tolerance + * + * solution will be put in w + * + * See Algorithm 4 of Ho and Lin, 2012 + */ + private static void solve_l2r_l1l2_svr(DenseProblem prob, double[] w, Parameter param) { + final int l = prob.l; + final double C = param.C; + final double p = param.p; + final int w_size = prob.n; + final double eps = param.eps; + int i, s, iter = 0; + final int max_iter = 1000; + int active_size = l; + final int[] index = new int[l]; + + double d, G, H; + double Gmax_old = Double.POSITIVE_INFINITY; + double Gmax_new, Gnorm1_new; + double Gnorm1_init = 0; // initialize to 0 to get rid of Eclipse + // warning/error + final double[] beta = new double[l]; + final double[] QD = new double[l]; + final double[] y = prob.y; + + // L2R_L2LOSS_SVR_DUAL + final double[] lambda = new double[] { 0.5 / C }; + final double[] upper_bound = new double[] { Double.POSITIVE_INFINITY }; + + if (param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL) { + lambda[0] = 0; + upper_bound[0] = C; + } + + // Initial 
beta can be set here. Note that + // -upper_bound <= beta[i] <= upper_bound + for (i = 0; i < l; i++) + beta[i] = 0; + + for (i = 0; i < w_size; i++) + w[i] = 0; + for (i = 0; i < l; i++) { + QD[i] = 0; + for (int j = 0; j < w_size; j++) { + final double val = prob.x[i][j]; + QD[i] += val * val; + w[j] += beta[i] * val; + } + + index[i] = i; + } + + while (iter < max_iter) { + Gmax_new = 0; + Gnorm1_new = 0; + + for (i = 0; i < active_size; i++) { + final int j = i + random.nextInt(active_size - i); + swap(index, i, j); + } + + for (s = 0; s < active_size; s++) { + i = index[s]; + G = -y[i] + lambda[GETI_SVR(i)] * beta[i]; + H = QD[i] + lambda[GETI_SVR(i)]; + + for (int ind = 0; ind < w_size; ind++) { + final double val = prob.x[i][ind]; + G += val * w[ind]; + } + + final double Gp = G + p; + final double Gn = G - p; + double violation = 0; + if (beta[i] == 0) { + if (Gp < 0) + violation = -Gp; + else if (Gn > 0) + violation = Gn; + else if (Gp > Gmax_old && Gn < -Gmax_old) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + } else if (beta[i] >= upper_bound[GETI_SVR(i)]) { + if (Gp > 0) + violation = Gp; + else if (Gp < -Gmax_old) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + } else if (beta[i] <= -upper_bound[GETI_SVR(i)]) { + if (Gn < 0) + violation = -Gn; + else if (Gn > Gmax_old) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + } else if (beta[i] > 0) + violation = Math.abs(Gp); + else + violation = Math.abs(Gn); + + Gmax_new = Math.max(Gmax_new, violation); + Gnorm1_new += violation; + + // obtain Newton direction d + if (Gp < H * beta[i]) + d = -Gp / H; + else if (Gn > H * beta[i]) + d = -Gn / H; + else + d = -beta[i]; + + if (Math.abs(d) < 1.0e-12) + continue; + + final double beta_old = beta[i]; + beta[i] = Math.min(Math.max(beta[i] + d, -upper_bound[GETI_SVR(i)]), upper_bound[GETI_SVR(i)]); + d = beta[i] - beta_old; + + if (d != 0) { + for (int j = 0; j < w_size; j++) { + w[j] 
+= d * prob.x[i][j]; + } + } + } + + if (iter == 0) + Gnorm1_init = Gnorm1_new; + iter++; + if (iter % 10 == 0) + info("."); + + if (Gnorm1_new <= eps * Gnorm1_init) { + if (active_size == l) + break; + else { + active_size = l; + info("*"); + Gmax_old = Double.POSITIVE_INFINITY; + continue; + } + } + + Gmax_old = Gmax_new; + } + + info("%noptimization finished, #iter = %d%n", iter); + if (iter >= max_iter) + info("%nWARNING: reaching max number of iterations%nUsing -s 11 may be faster%n%n"); + + // calculate objective value + double v = 0; + int nSV = 0; + for (i = 0; i < w_size; i++) + v += w[i] * w[i]; + v = 0.5 * v; + for (i = 0; i < l; i++) { + v += p * Math.abs(beta[i]) - y[i] * beta[i] + 0.5 * lambda[GETI_SVR(i)] * beta[i] * beta[i]; + if (beta[i] != 0) + nSV++; + } + + info("Objective value = %g%n", v); + info("nSV = %d%n", nSV); + } + + /** + * A coordinate descent algorithm for the dual of L2-regularized logistic + * regression problems + * + *
+	 *  min_\alpha  0.5(\alpha^T Q \alpha) + \sum \alpha_i log (\alpha_i) + (upper_bound_i - \alpha_i) log (upper_bound_i - \alpha_i) ,
+	 *     s.t.      0 <= \alpha_i <= upper_bound_i,
+	 * 
+	 *  where Qij = yi yj xi^T xj and
+	 *  upper_bound_i = Cp if y_i = 1
+	 *  upper_bound_i = Cn if y_i = -1
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Algorithm 5 of Yu et al., MLJ 2010
+	 * 
+ * + * @since 1.7 + */ + private static void solve_l2r_lr_dual(DenseProblem prob, double w[], double eps, double Cp, double Cn) { + final int l = prob.l; + final int w_size = prob.n; + int i, s, iter = 0; + final double xTx[] = new double[l]; + final int max_iter = 1000; + final int index[] = new int[l]; + final double alpha[] = new double[2 * l]; // store alpha and C - alpha + final byte y[] = new byte[l]; + final int max_inner_iter = 100; // for inner Newton + double innereps = 1e-2; + final double innereps_min = Math.min(1e-8, eps); + final double upper_bound[] = new double[] { Cn, 0, Cp }; + + for (i = 0; i < l; i++) { + if (prob.y[i] > 0) { + y[i] = +1; + } else { + y[i] = -1; + } + } + + // Initial alpha can be set here. Note that + // 0 < alpha[i] < upper_bound[GETI(i)] + // alpha[2*i] + alpha[2*i+1] = upper_bound[GETI(i)] + for (i = 0; i < l; i++) { + alpha[2 * i] = Math.min(0.001 * upper_bound[GETI(y, i)], 1e-8); + alpha[2 * i + 1] = upper_bound[GETI(y, i)] - alpha[2 * i]; + } + + for (i = 0; i < w_size; i++) + w[i] = 0; + for (i = 0; i < l; i++) { + xTx[i] = 0; + for (int j = 0; j < w_size; j++) { + final double val = prob.x[i][j]; + xTx[i] += val * val; + w[j] += y[i] * alpha[2 * i] * val; + } + index[i] = i; + } + + while (iter < max_iter) { + for (i = 0; i < l; i++) { + final int j = i + random.nextInt(l - i); + swap(index, i, j); + } + int newton_iter = 0; + double Gmax = 0; + for (s = 0; s < l; s++) { + i = index[s]; + final byte yi = y[i]; + final double C = upper_bound[GETI(y, i)]; + double ywTx = 0; + final double xisq = xTx[i]; + for (int j = 0; j < w_size; j++) { + ywTx += w[j] * prob.x[i][j]; + } + ywTx *= y[i]; + final double a = xisq, b = ywTx; + + // Decide to minimize g_1(z) or g_2(z) + int ind1 = 2 * i, ind2 = 2 * i + 1, sign = 1; + if (0.5 * a * (alpha[ind2] - alpha[ind1]) + b < 0) { + ind1 = 2 * i + 1; + ind2 = 2 * i; + sign = -1; + } + + // g_t(z) = z*log(z) + (C-z)*log(C-z) + 0.5a(z-alpha_old)^2 + + // sign*b(z-alpha_old) + final 
double alpha_old = alpha[ind1]; + double z = alpha_old; + if (C - z < 0.5 * C) + z = 0.1 * z; + double gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z)); + Gmax = Math.max(Gmax, Math.abs(gp)); + + // Newton method on the sub-problem + final double eta = 0.1; // xi in the paper + int inner_iter = 0; + while (inner_iter <= max_inner_iter) { + if (Math.abs(gp) < innereps) + break; + final double gpp = a + C / (C - z) / z; + final double tmpz = z - gp / gpp; + if (tmpz <= 0) + z *= eta; + else + // tmpz in (0, C) + z = tmpz; + gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z)); + newton_iter++; + inner_iter++; + } + + if (inner_iter > 0) // update w + { + alpha[ind1] = z; + alpha[ind2] = C - z; + for (int j = 0; j < w_size; j++) { + w[j] += sign * (z - alpha_old) * yi * prob.x[i][j]; + } + } + } + + iter++; + if (iter % 10 == 0) + info("."); + + if (Gmax < eps) + break; + + if (newton_iter <= l / 10) { + innereps = Math.max(innereps_min, 0.1 * innereps); + } + + } + + info("%noptimization finished, #iter = %d%n", iter); + if (iter >= max_iter) + info("%nWARNING: reaching max number of iterations%nUsing -s 0 may be faster (also see FAQ)%n%n"); + + // calculate objective value + + double v = 0; + for (i = 0; i < w_size; i++) + v += w[i] * w[i]; + v *= 0.5; + for (i = 0; i < l; i++) + v += alpha[2 * i] * Math.log(alpha[2 * i]) + alpha[2 * i + 1] * Math.log(alpha[2 * i + 1]) + - upper_bound[GETI(y, i)] + * Math.log(upper_bound[GETI(y, i)]); + info("Objective value = %g%n", v); + } + + /** + * A coordinate descent algorithm for L1-regularized L2-loss support vector + * classification + * + *
+	 *  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Yuan et al. (2010) and appendix of LIBLINEAR paper, Fan et al. (2008)
+	 * 
+ * + * @since 1.5 + */ + private static void solve_l1r_l2_svc(DenseProblem prob_col, double[] w, double eps, double Cp, double Cn) { + final int l = prob_col.l; + final int w_size = prob_col.n; + int j, s, iter = 0; + final int max_iter = 1000; + int active_size = w_size; + final int max_num_linesearch = 20; + + final double sigma = 0.01; + double d, G_loss, G, H; + double Gmax_old = Double.POSITIVE_INFINITY; + double Gmax_new, Gnorm1_new; + double Gnorm1_init = 0; // eclipse moans this variable might not be + // initialized + double d_old, d_diff; + double loss_old = 0; // eclipse moans this variable might not be + // initialized + double loss_new; + double appxcond, cond; + + final int[] index = new int[w_size]; + final byte[] y = new byte[l]; + final double[] b = new double[l]; // b = 1-ywTx + final double[] xj_sq = new double[w_size]; + + final double[] C = new double[] { Cn, 0, Cp }; + + // Initial w can be set here. + for (j = 0; j < w_size; j++) + w[j] = 0; + + for (j = 0; j < l; j++) { + b[j] = 1; + if (prob_col.y[j] > 0) + y[j] = 1; + else + y[j] = -1; + } + for (j = 0; j < w_size; j++) { + index[j] = j; + xj_sq[j] = 0; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { // column j of the transposed matrix has one entry per instance (l), not per feature + prob_col.x[j][ind] = prob_col.x[j][ind] * y[ind]; // x->value + // stores + // yi*xij + final double val = prob_col.x[j][ind]; + b[ind] -= w[j] * val; + + xj_sq[j] += C[GETI(y, ind)] * val * val; + } + } + + while (iter < max_iter) { + Gmax_new = 0; + Gnorm1_new = 0; + + for (j = 0; j < active_size; j++) { + final int i = j + random.nextInt(active_size - j); + swap(index, i, j); + } + + for (s = 0; s < active_size; s++) { + j = index[s]; + G_loss = 0; + H = 0; + + for (int ind = 0; ind < prob_col.x[j].length; ind++) { // iterate over instances: b and y are indexed by instance and have length l + if (b[ind] > 0) { + final double val = prob_col.x[j][ind]; + final double tmp = C[GETI(y, ind)] * val; + G_loss -= tmp * b[ind]; + H += tmp * val; + } + } + G_loss *= 2; + + G = G_loss; + H *= 2; + H = Math.max(H, 1e-12); + + final double Gp = G + 1; + final double Gn = G - 1; + double 
violation = 0; + if (w[j] == 0) { + if (Gp < 0) + violation = -Gp; + else if (Gn > 0) + violation = Gn; + else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + } else if (w[j] > 0) + violation = Math.abs(Gp); + else + violation = Math.abs(Gn); + + Gmax_new = Math.max(Gmax_new, violation); + Gnorm1_new += violation; + + // obtain Newton direction d + if (Gp < H * w[j]) + d = -Gp / H; + else if (Gn > H * w[j]) + d = -Gn / H; + else + d = -w[j]; + + if (Math.abs(d) < 1.0e-12) + continue; + + double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d; + d_old = 0; + int num_linesearch; + for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) { + d_diff = d_old - d; + cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta; + + appxcond = xj_sq[j] * d * d + G_loss * d + cond; + if (appxcond <= 0) { + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + b[ind] += d_diff * prob_col.x[j][ind]; + } + break; + } + + if (num_linesearch == 0) { + loss_old = 0; + loss_new = 0; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + if (b[ind] > 0) { + loss_old += C[GETI(y, ind)] * b[ind] * b[ind]; + } + final double b_new = b[ind] + d_diff * prob_col.x[j][ind]; + b[ind] = b_new; + if (b_new > 0) { + loss_new += C[GETI(y, ind)] * b_new * b_new; + } + } + } else { + loss_new = 0; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + final double b_new = b[ind] + d_diff * prob_col.x[j][ind]; + b[ind] = b_new; + if (b_new > 0) { + loss_new += C[GETI(y, ind)] * b_new * b_new; + } + } + } + + cond = cond + loss_new - loss_old; + if (cond <= 0) + break; + else { + d_old = d; + d *= 0.5; + delta *= 0.5; + } + } + + w[j] += d; + + // recompute b[] if line search takes too many steps + if (num_linesearch >= max_num_linesearch) { + info("#"); + for (int i = 0; i < l; i++) + b[i] = 1; + + for (int i = 0; i < w_size; i++) { + if (w[i] == 0) + continue; + for (int ind = 0; ind < 
prob_col.x[i].length; ind++) { // rebuild b from the column of feature i (the weight being applied), not the stale outer index j + b[ind] -= w[i] * prob_col.x[i][ind]; + } + } + } + } + + if (iter == 0) { + Gnorm1_init = Gnorm1_new; + } + iter++; + if (iter % 10 == 0) + info("."); + + if (Gmax_new <= eps * Gnorm1_init) { + if (active_size == w_size) + break; + else { + active_size = w_size; + info("*"); + Gmax_old = Double.POSITIVE_INFINITY; + continue; + } + } + + Gmax_old = Gmax_new; + } + + info("%noptimization finished, #iter = %d%n", iter); + if (iter >= max_iter) + info("%nWARNING: reaching max number of iterations%n"); + + // calculate objective value + + double v = 0; + int nnz = 0; + for (j = 0; j < w_size; j++) { + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + prob_col.x[j][ind] = prob_col.x[j][ind] * prob_col.y[ind]; // restore + // x->value + } + if (w[j] != 0) { + v += Math.abs(w[j]); + nnz++; + } + } + for (j = 0; j < l; j++) + if (b[j] > 0) + v += C[GETI(y, j)] * b[j] * b[j]; + + info("Objective value = %g%n", v); + info("#nonzeros/#features = %d/%d%n", nnz, w_size); + } + + /** + * A coordinate descent algorithm for L1-regularized logistic regression + * problems + * + *
+	 *  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
+	 * 
+	 * Given:
+	 * x, y, Cp, Cn
+	 * eps is the stopping tolerance
+	 * 
+	 * solution will be put in w
+	 * 
+	 * See Yuan et al. (2011) and appendix of LIBLINEAR paper, Fan et al. (2008)
+	 * 
+ * + * @since 1.5 + */ + private static void solve_l1r_lr(DenseProblem prob_col, double[] w, double eps, double Cp, double Cn) { + final int l = prob_col.l; + final int w_size = prob_col.n; + int j, s, newton_iter = 0, iter = 0; + final int max_newton_iter = 100; + final int max_iter = 1000; + final int max_num_linesearch = 20; + int active_size; + int QP_active_size; + + final double nu = 1e-12; + double inner_eps = 1; + final double sigma = 0.01; + double w_norm, w_norm_new; + double z, G, H; + double Gnorm1_init = 0; // eclipse moans this variable might not be + // initialized + double Gmax_old = Double.POSITIVE_INFINITY; + double Gmax_new, Gnorm1_new; + double QP_Gmax_old = Double.POSITIVE_INFINITY; + double QP_Gmax_new, QP_Gnorm1_new; + double delta, negsum_xTd, cond; + + final int[] index = new int[w_size]; + final byte[] y = new byte[l]; + final double[] Hdiag = new double[w_size]; + final double[] Grad = new double[w_size]; + final double[] wpd = new double[w_size]; + final double[] xjneg_sum = new double[w_size]; + final double[] xTd = new double[l]; + final double[] exp_wTx = new double[l]; + final double[] exp_wTx_new = new double[l]; + final double[] tau = new double[l]; + final double[] D = new double[l]; + + final double[] C = { Cn, 0, Cp }; + + // Initial w can be set here. 
+ for (j = 0; j < w_size; j++) + w[j] = 0; + + for (j = 0; j < l; j++) { + if (prob_col.y[j] > 0) + y[j] = 1; + else + y[j] = -1; + + exp_wTx[j] = 0; + } + + w_norm = 0; + for (j = 0; j < w_size; j++) { + w_norm += Math.abs(w[j]); + wpd[j] = w[j]; + index[j] = j; + xjneg_sum[j] = 0; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + final double val = prob_col.x[j][ind]; + exp_wTx[ind] += w[j] * val; + if (y[ind] == -1) { + xjneg_sum[j] += C[GETI(y, ind)] * val; + } + } + } + for (j = 0; j < l; j++) { + exp_wTx[j] = Math.exp(exp_wTx[j]); + final double tau_tmp = 1 / (1 + exp_wTx[j]); + tau[j] = C[GETI(y, j)] * tau_tmp; + D[j] = C[GETI(y, j)] * exp_wTx[j] * tau_tmp * tau_tmp; + } + + while (newton_iter < max_newton_iter) { + Gmax_new = 0; + Gnorm1_new = 0; + active_size = w_size; + + for (s = 0; s < active_size; s++) { + j = index[s]; + Hdiag[j] = nu; + Grad[j] = 0; + + double tmp = 0; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + Hdiag[j] += prob_col.x[j][ind] * prob_col.x[j][ind] * D[ind]; + tmp += prob_col.x[j][ind] * tau[ind]; + } + Grad[j] = -tmp + xjneg_sum[j]; + + final double Gp = Grad[j] + 1; + final double Gn = Grad[j] - 1; + double violation = 0; + if (w[j] == 0) { + if (Gp < 0) + violation = -Gp; + else if (Gn > 0) + violation = Gn; + // outer-level shrinking + else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + } else if (w[j] > 0) + violation = Math.abs(Gp); + else + violation = Math.abs(Gn); + + Gmax_new = Math.max(Gmax_new, violation); + Gnorm1_new += violation; + } + + if (newton_iter == 0) + Gnorm1_init = Gnorm1_new; + + if (Gnorm1_new <= eps * Gnorm1_init) + break; + + iter = 0; + QP_Gmax_old = Double.POSITIVE_INFINITY; + QP_active_size = active_size; + + for (int i = 0; i < l; i++) + xTd[i] = 0; + + // optimize QP over wpd + while (iter < max_iter) { + QP_Gmax_new = 0; + QP_Gnorm1_new = 0; + + for (j = 0; j < QP_active_size; j++) { + final int i = 
random.nextInt(QP_active_size - j); + swap(index, i, j); + } + + for (s = 0; s < QP_active_size; s++) { + j = index[s]; + H = Hdiag[j]; + + G = Grad[j] + (wpd[j] - w[j]) * nu; + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + G += prob_col.x[j][ind] * D[ind] * xTd[ind]; + } + + final double Gp = G + 1; + final double Gn = G - 1; + double violation = 0; + if (wpd[j] == 0) { + if (Gp < 0) + violation = -Gp; + else if (Gn > 0) + violation = Gn; + // inner-level shrinking + else if (Gp > QP_Gmax_old / l && Gn < -QP_Gmax_old / l) { + QP_active_size--; + swap(index, s, QP_active_size); + s--; + continue; + } + } else if (wpd[j] > 0) + violation = Math.abs(Gp); + else + violation = Math.abs(Gn); + + QP_Gmax_new = Math.max(QP_Gmax_new, violation); + QP_Gnorm1_new += violation; + + // obtain solution of one-variable problem + if (Gp < H * wpd[j]) + z = -Gp / H; + else if (Gn > H * wpd[j]) + z = -Gn / H; + else + z = -wpd[j]; + + if (Math.abs(z) < 1.0e-12) + continue; + z = Math.min(Math.max(z, -10.0), 10.0); + + wpd[j] += z; + + for (int ind = 0; ind < prob_col.x[j].length; ind++) { + xTd[ind] += prob_col.x[j][ind] * z; + } + } + + iter++; + + if (QP_Gnorm1_new <= inner_eps * Gnorm1_init) { + // inner stopping + if (QP_active_size == active_size) + break; + // active set reactivation + else { + QP_active_size = active_size; + QP_Gmax_old = Double.POSITIVE_INFINITY; + continue; + } + } + + QP_Gmax_old = QP_Gmax_new; + } + + if (iter >= max_iter) + info("WARNING: reaching max number of inner iterations%n"); + + delta = 0; + w_norm_new = 0; + for (j = 0; j < w_size; j++) { + delta += Grad[j] * (wpd[j] - w[j]); + if (wpd[j] != 0) + w_norm_new += Math.abs(wpd[j]); + } + delta += (w_norm_new - w_norm); + + negsum_xTd = 0; + for (int i = 0; i < l; i++) + if (y[i] == -1) + negsum_xTd += C[GETI(y, i)] * xTd[i]; + + int num_linesearch; + for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) { + cond = w_norm_new - w_norm + negsum_xTd - sigma * 
delta; + + for (int i = 0; i < l; i++) { + final double exp_xTd = Math.exp(xTd[i]); + exp_wTx_new[i] = exp_wTx[i] * exp_xTd; + cond += C[GETI(y, i)] * Math.log((1 + exp_wTx_new[i]) / (exp_xTd + exp_wTx_new[i])); + } + + if (cond <= 0) { + w_norm = w_norm_new; + for (j = 0; j < w_size; j++) + w[j] = wpd[j]; + for (int i = 0; i < l; i++) { + exp_wTx[i] = exp_wTx_new[i]; + final double tau_tmp = 1 / (1 + exp_wTx[i]); + tau[i] = C[GETI(y, i)] * tau_tmp; + D[i] = C[GETI(y, i)] * exp_wTx[i] * tau_tmp * tau_tmp; + } + break; + } else { + w_norm_new = 0; + for (j = 0; j < w_size; j++) { + wpd[j] = (w[j] + wpd[j]) * 0.5; + if (wpd[j] != 0) + w_norm_new += Math.abs(wpd[j]); + } + delta *= 0.5; + negsum_xTd *= 0.5; + for (int i = 0; i < l; i++) + xTd[i] *= 0.5; + } + } + + // Recompute some info due to too many line search steps + if (num_linesearch >= max_num_linesearch) { + for (int i = 0; i < l; i++) + exp_wTx[i] = 0; + + for (int i = 0; i < w_size; i++) { + if (w[i] == 0) + continue; + for (int ind = 0; ind < prob_col.x[i].length; ind++) { // index by i: j equals w_size here after the back-off loop, so x[j] would be out of bounds + exp_wTx[ind] += w[i] * prob_col.x[i][ind]; + } + } + + for (int i = 0; i < l; i++) + exp_wTx[i] = Math.exp(exp_wTx[i]); + } + + if (iter == 1) + inner_eps *= 0.25; + + newton_iter++; + Gmax_old = Gmax_new; + + info("iter %3d #CD cycles %d%n", newton_iter, iter); + } + + info("=========================%n"); + info("optimization finished, #iter = %d%n", newton_iter); + if (newton_iter >= max_newton_iter) + info("WARNING: reaching max number of iterations%n"); + + // calculate objective value + + double v = 0; + int nnz = 0; + for (j = 0; j < w_size; j++) + if (w[j] != 0) { + v += Math.abs(w[j]); + nnz++; + } + for (j = 0; j < l; j++) + if (y[j] == 1) + v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]); + else + v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]); + + info("Objective value = %g%n", v); + info("#nonzeros/#features = %d/%d%n", nnz, w_size); + } + + // transpose matrix X from row format to column format + static DenseProblem 
transpose(DenseProblem prob) { + final int l = prob.l; + final int n = prob.n; + final DenseProblem prob_col = new DenseProblem(); + prob_col.l = l; + prob_col.n = n; + prob_col.y = new double[l]; + prob_col.x = new double[n][]; + + for (int i = 0; i < l; i++) + prob_col.y[i] = prob.y[i]; + + for (int i = 0; i < n; i++) { + prob_col.x[i] = new double[l]; + } + + for (int i = 0; i < l; i++) { + for (int j = 0; j < n; j++) { + prob_col.x[j][i] = prob.x[i][j]; + } + } + + return prob_col; + } + + static void swap(double[] array, int idxA, int idxB) { + final double temp = array[idxA]; + array[idxA] = array[idxB]; + array[idxB] = temp; + } + + static void swap(int[] array, int idxA, int idxB) { + final int temp = array[idxA]; + array[idxA] = array[idxB]; + array[idxB] = temp; + } + + static void swap(IntArrayPointer array, int idxA, int idxB) { + final int temp = array.get(idxA); + array.set(idxA, array.get(idxB)); + array.set(idxB, temp); + } + + /** + * @throws IllegalArgumentException + * if the feature nodes of prob are not sorted in ascending + * order + */ + public static Model train(DenseProblem prob, Parameter param) { + + if (prob == null) + throw new IllegalArgumentException("problem must not be null"); + if (param == null) + throw new IllegalArgumentException("parameter must not be null"); + + if (prob.n == 0) + throw new IllegalArgumentException("problem has zero features"); + if (prob.l == 0) + throw new IllegalArgumentException("problem has zero instances"); + + final int l = prob.l; + final int n = prob.n; + final int w_size = prob.n; + final Model model = new Model(); + + if (prob.bias >= 0) + model.nr_feature = n - 1; + else + model.nr_feature = n; + + model.solverType = param.solverType; + model.bias = prob.bias; + + if (param.solverType == SolverType.L2R_L2LOSS_SVR || // + param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL || // + param.solverType == SolverType.L2R_L2LOSS_SVR_DUAL) + { + model.w = new double[w_size]; + model.nr_class = 2; + 
model.label = null; + + checkProblemSize(n, model.nr_class); + + train_one(prob, param, model.w, 0, 0); + } else { + final int[] perm = new int[l]; + + // group training data of the same class + final GroupClassesReturn rv = groupClasses(prob, perm); + final int nr_class = rv.nr_class; + final int[] label = rv.label; + final int[] start = rv.start; + final int[] count = rv.count; + + checkProblemSize(n, nr_class); + + model.nr_class = nr_class; + model.label = new int[nr_class]; + for (int i = 0; i < nr_class; i++) + model.label[i] = label[i]; + + // calculate weighted C + final double[] weighted_C = new double[nr_class]; + for (int i = 0; i < nr_class; i++) + weighted_C[i] = param.C; + for (int i = 0; i < param.getNumWeights(); i++) { + int j; + for (j = 0; j < nr_class; j++) + if (param.weightLabel[i] == label[j]) + break; + + if (j == nr_class) + throw new IllegalArgumentException("class label " + param.weightLabel[i] + + " specified in weight is not found"); + weighted_C[j] *= param.weight[i]; + } + + // constructing the subproblem + final double[][] x = new double[l][]; + for (int i = 0; i < l; i++) + x[i] = prob.x[perm[i]]; + + final DenseProblem sub_prob = new DenseProblem(); + sub_prob.l = l; + sub_prob.n = n; + sub_prob.x = new double[sub_prob.l][]; + sub_prob.y = new double[sub_prob.l]; + + for (int k = 0; k < sub_prob.l; k++) + sub_prob.x[k] = x[k]; + + // multi-class svm by Crammer and Singer + if (param.solverType == SolverType.MCSVM_CS) { + model.w = new double[n * nr_class]; + for (int i = 0; i < nr_class; i++) { + for (int j = start[i]; j < start[i] + count[i]; j++) { + sub_prob.y[j] = i; + } + } + + final DenseSolverMCSVM_CS solver = new DenseSolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps); + solver.solve(model.w); + } else { + if (nr_class == 2) { + model.w = new double[w_size]; + + final int e0 = start[0] + count[0]; + int k = 0; + for (; k < e0; k++) + sub_prob.y[k] = +1; + for (; k < sub_prob.l; k++) + sub_prob.y[k] = -1; + + 
train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]); + } else { + model.w = new double[w_size * nr_class]; + final double[] w = new double[w_size]; + for (int i = 0; i < nr_class; i++) { + final int si = start[i]; + final int ei = si + count[i]; + + int k = 0; + for (; k < si; k++) + sub_prob.y[k] = -1; + for (; k < ei; k++) + sub_prob.y[k] = +1; + for (; k < sub_prob.l; k++) + sub_prob.y[k] = -1; + + train_one(sub_prob, param, w, weighted_C[i], param.C); + + for (int j = 0; j < n; j++) + model.w[j * nr_class + i] = w[j]; + } + } + } + } + return model; + } + + /** + * verify the size and throw an exception early if the problem is too large + */ + private static void checkProblemSize(int n, int nr_class) { + if (n >= Integer.MAX_VALUE / nr_class || n * nr_class < 0) { + throw new IllegalArgumentException("'number of classes' * 'number of instances' is too large: " + nr_class + + "*" + n); + } + } + + private static void train_one(DenseProblem prob, Parameter param, double[] w, double Cp, double Cn) { + final double eps = param.eps; + int pos = 0; + for (int i = 0; i < prob.l; i++) + if (prob.y[i] > 0) { + pos++; + } + final int neg = prob.l - pos; + + final double primal_solver_tol = eps * Math.max(Math.min(pos, neg), 1) / prob.l; + + Function fun_obj = null; + switch (param.solverType) { + case L2R_LR: { + final double[] C = new double[prob.l]; + for (int i = 0; i < prob.l; i++) { + if (prob.y[i] > 0) + C[i] = Cp; + else + C[i] = Cn; + } + fun_obj = new DenseL2R_LrFunction(prob, C); + final Tron tron_obj = new Tron(fun_obj, primal_solver_tol); + tron_obj.tron(w); + break; + } + case L2R_L2LOSS_SVC: { + final double[] C = new double[prob.l]; + for (int i = 0; i < prob.l; i++) { + if (prob.y[i] > 0) + C[i] = Cp; + else + C[i] = Cn; + } + fun_obj = new DenseL2R_L2_SvcFunction(prob, C); + final Tron tron_obj = new Tron(fun_obj, primal_solver_tol); + tron_obj.tron(w); + break; + } + case L2R_L2LOSS_SVC_DUAL: + solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, 
SolverType.L2R_L2LOSS_SVC_DUAL); + break; + case L2R_L1LOSS_SVC_DUAL: + solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL); + break; + case L1R_L2LOSS_SVC: { + final DenseProblem prob_col = transpose(prob); + solve_l1r_l2_svc(prob_col, w, primal_solver_tol, Cp, Cn); + break; + } + case L1R_LR: { + final DenseProblem prob_col = transpose(prob); + solve_l1r_lr(prob_col, w, primal_solver_tol, Cp, Cn); + break; + } + case L2R_LR_DUAL: + solve_l2r_lr_dual(prob, w, eps, Cp, Cn); + break; + case L2R_L2LOSS_SVR: { + final double[] C = new double[prob.l]; + for (int i = 0; i < prob.l; i++) + C[i] = param.C; + + fun_obj = new DenseL2R_L2_SvrFunction(prob, C, param.p); + final Tron tron_obj = new Tron(fun_obj, param.eps); + tron_obj.tron(w); + break; + } + case L2R_L1LOSS_SVR_DUAL: + case L2R_L2LOSS_SVR_DUAL: + solve_l2r_l1l2_svr(prob, w, param); + break; + + default: + throw new IllegalStateException("unknown solver type: " + param.solverType); + } + } + + public static void disableDebugOutput() { + setDebugOutput(null); + } + + public static void enableDebugOutput() { + setDebugOutput(System.out); + } + + public static void setDebugOutput(PrintStream debugOutput) { + synchronized (OUTPUT_MUTEX) { + DEBUG_OUTPUT = debugOutput; + } + } + + /** + * resets the PRNG + * + * this is i.a. needed for regression testing (eg. 
the Weka wrapper) + */ + public static void resetRandom() { + random = new Random(DEFAULT_RANDOM_SEED); + } +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DensePredict.java b/src/main/java/de/bwaldvogel/liblinear/DensePredict.java new file mode 100644 index 0000000..201d404 --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DensePredict.java @@ -0,0 +1,194 @@ +package de.bwaldvogel.liblinear; + +import static de.bwaldvogel.liblinear.DenseLinear.atof; +import static de.bwaldvogel.liblinear.DenseLinear.atoi; +import static de.bwaldvogel.liblinear.DenseLinear.closeQuietly; +import static de.bwaldvogel.liblinear.DenseLinear.info; +import static de.bwaldvogel.liblinear.DenseLinear.printf; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.util.Formatter; +import java.util.NoSuchElementException; +import java.util.StringTokenizer; +import java.util.regex.Pattern; + +public class DensePredict { + + private static boolean flag_predict_probability = false; + + private static final Pattern COLON = Pattern.compile(":"); + + /** + *

+ * Note: The streams are NOT closed + *

+ */ + static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException { + int correct = 0; + int total = 0; + double error = 0; + double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0; + + final int nr_class = model.getNrClass(); + double[] prob_estimates = null; + int n; + final int nr_feature = model.getNrFeature(); + if (model.bias >= 0) + n = nr_feature + 1; + else + n = nr_feature; + + if (flag_predict_probability && !model.isProbabilityModel()) { + throw new IllegalArgumentException("probability output is only supported for logistic regression"); + } + + final Formatter out = new Formatter(writer); + + if (flag_predict_probability) { + final int[] labels = model.getLabels(); + prob_estimates = new double[nr_class]; + + printf(out, "labels"); + for (int j = 0; j < nr_class; j++) + printf(out, " %d", labels[j]); + printf(out, "\n"); + } + + String line = null; + while ((line = reader.readLine()) != null) { + final double[] nodes = new double[n]; + final StringTokenizer st = new StringTokenizer(line, " \t\n"); + double target_label; + try { + final String label = st.nextToken(); + target_label = atof(label); + } catch (final NoSuchElementException e) { + throw new RuntimeException("Wrong input format at line " + (total + 1), e); + } + + while (st.hasMoreTokens()) { + final String[] split = COLON.split(st.nextToken(), 2); + if (split == null || split.length < 2) { + throw new RuntimeException("Wrong input format at line " + (total + 1)); + } + + try { + final int idx = atoi(split[0]); + final double val = atof(split[1]); + + // feature indices larger than those in training are not + // used + if (idx <= nr_feature) { + nodes[idx - 1] = val; + } + } catch (final NumberFormatException e) { + throw new RuntimeException("Wrong input format at line " + (total + 1), e); + } + } + + if (model.bias >= 0) { + nodes[n - 1] = model.bias; + } + + double predict_label; + + if (flag_predict_probability) { + assert prob_estimates != null; + 
predict_label = DenseLinear.predictProbability(model, nodes, prob_estimates); + printf(out, "%g", predict_label); + for (int j = 0; j < model.nr_class; j++) + printf(out, " %g", prob_estimates[j]); + printf(out, "\n"); + } else { + predict_label = DenseLinear.predict(model, nodes); + printf(out, "%g\n", predict_label); + } + + if (predict_label == target_label) { + ++correct; + } + + error += (predict_label - target_label) * (predict_label - target_label); + sump += predict_label; + sumt += target_label; + sumpp += predict_label * predict_label; + sumtt += target_label * target_label; + sumpt += predict_label * target_label; + ++total; + } + + if (model.solverType.isSupportVectorRegression()) // + { + info("Mean squared error = %g (regression)%n", error / total); + info("Squared correlation coefficient = %g (regression)%n", // + ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt)) + / ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt))); + } else { + info("Accuracy = %g%% (%d/%d)%n", (double) correct / total * 100, correct, total); + } + } + + private static void exit_with_help() { + System.out + .printf("Usage: predict [options] test_file model_file output_file%n" // + + "options:%n" // + + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only%n" // + + "-q quiet mode (no outputs)%n"); + System.exit(1); + } + + public static void main(String[] argv) throws IOException { + int i; + + // parse options + for (i = 0; i < argv.length; i++) { + if (argv[i].charAt(0) != '-') + break; + ++i; + switch (argv[i - 1].charAt(1)) { + case 'b': + try { + flag_predict_probability = (atoi(argv[i]) != 0); + } catch (final NumberFormatException e) { + exit_with_help(); + } + break; + + case 'q': + i--; + DenseLinear.disableDebugOutput(); + break; + + default: + System.err.printf("unknown option: -%d%n", argv[i - 1].charAt(1)); + exit_with_help(); + break; + } + } + if (i >= argv.length || 
argv.length <= i + 2) { + exit_with_help(); + } + + BufferedReader reader = null; + Writer writer = null; + try { + reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), DenseLinear.FILE_CHARSET)); + writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), + DenseLinear.FILE_CHARSET)); + + final Model model = DenseLinear.loadModel(new File(argv[i + 1])); + doPredict(reader, writer, model); + } finally { + closeQuietly(reader); + closeQuietly(writer); + } + } +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseProblem.java b/src/main/java/de/bwaldvogel/liblinear/DenseProblem.java new file mode 100644 index 0000000..1267eab --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseProblem.java @@ -0,0 +1,62 @@ +package de.bwaldvogel.liblinear; + +import java.io.File; +import java.io.IOException; + +/** + *

+ * Describes the problem + *

+ * + * For example, if we have the following training data: + * + *
+ *  LABEL       ATTR1   ATTR2   ATTR3   ATTR4   ATTR5
+ *  -----       -----   -----   -----   -----   -----
+ *  1           0       0.1     0.2     0       0
+ *  2           0       0.1     0.3    -1.2     0
+ *  1           0.4     0       0       0       0
+ *  2           0       0.1     0       1.4     0.5
+ *  3          -0.1    -0.2     0.1     1.1     0.1
+ * 
+ *  and bias = 1, then the components of problem are:
+ * 
+ *  l = 5
+ *  n = 6
+ * 
+ *  y -> 1 2 1 2 3
+ * 
+ *  x -> [ ] -> {  0,    0.1,  0.2,  0,    0,    1 }
+ *       [ ] -> {  0,    0.1,  0.3, -1.2,  0,    1 }
+ *       [ ] -> {  0.4,  0,    0,    0,    0,    1 }
+ *       [ ] -> {  0,    0.1,  0,    1.4,  0.5,  1 }
+ *       [ ] -> { -0.1, -0.2,  0.1,  1.1,  0.1,  1 }
+ * 
+ */ +public class DenseProblem { + + /** the number of training instances (the number of rows of {@link #x} and of entries of {@link #y}) */ + public int l; + + /** the number of features (including the bias feature if bias >= 0) */ + public int n; + + /** an array containing the target values, one per training instance */ + public double[] y; + + /** dense feature matrix: x[i] is the feature vector of instance i, and x[i][j] is the value of the feature with 1-based index j + 1 */ + public double[][] x; + + /** + * If bias >= 0, we assume that one additional feature is added to the + * end of each data instance + */ + public double bias; + + /** + * reads a problem from a file in LibSVM format; see {@link DenseTrain#readProblem(File, double)} + */ + public static DenseProblem readFromFile(File file, double bias) throws IOException, InvalidInputDataException { + return DenseTrain.readProblem(file, bias); + } +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseSolverMCSVM_CS.java b/src/main/java/de/bwaldvogel/liblinear/DenseSolverMCSVM_CS.java new file mode 100644 index 0000000..b368e07 --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseSolverMCSVM_CS.java @@ -0,0 +1,293 @@ +package de.bwaldvogel.liblinear; + +import static de.bwaldvogel.liblinear.DenseLinear.copyOf; +import static de.bwaldvogel.liblinear.DenseLinear.info; +import static de.bwaldvogel.liblinear.DenseLinear.swap; + +/** + * A coordinate descent algorithm for multi-class support vector machines by + * Crammer and Singer + * + *
+ * min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
+ * s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
+ * 
+ * where e^m_i = 0 if y_i = m,
+ * e^m_i = 1 if y_i != m,
+ * C^m_i = C if m = y_i,
+ * C^m_i = 0 if m != y_i,
+ * and w_m(\alpha) = \sum_i \alpha^m_i x_i
+ * 
+ * Given:
+ * x, y, C
+ * eps is the stopping tolerance
+ * 
+ * solution will be put in w
+ * 
+ * See Appendix of LIBLINEAR paper, Fan et al. (2008)
+ * 
+ */ +class DenseSolverMCSVM_CS { + + private final double[] B; + private final double[] C; + private final double eps; + private final double[] G; + private final int max_iter; + private final int w_size, l; + private final int nr_class; + private final DenseProblem prob; + + public DenseSolverMCSVM_CS(DenseProblem prob, int nr_class, double[] C) { + this(prob, nr_class, C, 0.1); + } + + public DenseSolverMCSVM_CS(DenseProblem prob, int nr_class, double[] C, double eps) { + this(prob, nr_class, C, eps, 100000); + } + + public DenseSolverMCSVM_CS(DenseProblem prob, int nr_class, double[] weighted_C, double eps, int max_iter) { + this.w_size = prob.n; + this.l = prob.l; + this.nr_class = nr_class; + this.eps = eps; + this.max_iter = max_iter; + this.prob = prob; + this.C = weighted_C; + this.B = new double[nr_class]; + this.G = new double[nr_class]; + } + + private int GETI(int i) { + return (int) prob.y[i]; + } + + private boolean be_shrunk(int i, int m, int yi, double alpha_i, double minG) { + double bound = 0; + if (m == yi) + bound = C[GETI(i)]; + if (alpha_i == bound && G[m] < minG) + return true; + return false; + } + + public void solve(double[] w) { + int i, m, s; + int iter = 0; + final double[] alpha = new double[l * nr_class]; + final double[] alpha_new = new double[nr_class]; + final int[] index = new int[l]; + final double[] QD = new double[l]; + final int[] d_ind = new int[nr_class]; + final double[] d_val = new double[nr_class]; + final int[] alpha_index = new int[nr_class * l]; + final int[] y_index = new int[l]; + int active_size = l; + final int[] active_size_i = new int[l]; + double eps_shrink = Math.max(10.0 * eps, 1.0); // stopping tolerance for + // shrinking + boolean start_from_all = true; + + // Initial alpha can be set here. 
Note that + // sum_m alpha[i*nr_class+m] = 0, for all i=1,...,l-1 + // alpha[i*nr_class+m] <= C[GETI(i)] if prob->y[i] == m + // alpha[i*nr_class+m] <= 0 if prob->y[i] != m + // If initial alpha isn't zero, uncomment the for loop below to + // initialize w + for (i = 0; i < l * nr_class; i++) + alpha[i] = 0; + + for (i = 0; i < w_size * nr_class; i++) + w[i] = 0; + for (i = 0; i < l; i++) { + for (m = 0; m < nr_class; m++) + alpha_index[i * nr_class + m] = m; + QD[i] = 0; + for (final double val : prob.x[i]) { + QD[i] += val * val; + + // Uncomment the for loop if initial alpha isn't zero + // for(m=0; m<nr_class; m++) + // w[(xi->index-1)*nr_class+m] += alpha[i*nr_class+m]*val; + } + active_size_i[i] = nr_class; + y_index[i] = (int) prob.y[i]; + index[i] = i; + } + + final DoubleArrayPointer alpha_i = new DoubleArrayPointer(alpha, 0); + final IntArrayPointer alpha_index_i = new IntArrayPointer(alpha_index, 0); + + while (iter < max_iter) { + double stopping = Double.NEGATIVE_INFINITY; + + for (i = 0; i < active_size; i++) { + // int j = i+rand()%(active_size-i); + final int j = i + DenseLinear.random.nextInt(active_size - i); + swap(index, i, j); + } + for (s = 0; s < active_size; s++) { + + i = index[s]; + final double Ai = QD[i]; + // double *alpha_i = &alpha[i*nr_class]; + alpha_i.setOffset(i * nr_class); + + // int *alpha_index_i = &alpha_index[i*nr_class]; + alpha_index_i.setOffset(i * nr_class); + + if (Ai > 0) { + for (m = 0; m < active_size_i[i]; m++) + G[m] = 1; + if (y_index[i] < active_size_i[i]) + G[y_index[i]] = 0; + + for (int ind = 0; ind < prob.x[i].length; ind++) { + // double *w_i = &w[ind*nr_class]; + final int w_offset = ind * nr_class; + for (m = 0; m < active_size_i[i]; m++) + // G[m] += w_i[alpha_index_i[m]]*(prob.x[i][ind); + G[m] += w[w_offset + alpha_index_i.get(m)] * prob.x[i][ind]; + + } + + double minG = Double.POSITIVE_INFINITY; + double maxG = Double.NEGATIVE_INFINITY; + for (m = 0; m < active_size_i[i]; m++) { + if (alpha_i.get(alpha_index_i.get(m)) < 0 && 
G[m] < minG) + minG = G[m]; + if (G[m] > maxG) + maxG = G[m]; + } + if (y_index[i] < active_size_i[i]) { + if (alpha_i.get((int) prob.y[i]) < C[GETI(i)] && G[y_index[i]] < minG) { + minG = G[y_index[i]]; + } + } + + for (m = 0; m < active_size_i[i]; m++) { + if (be_shrunk(i, m, y_index[i], alpha_i.get(alpha_index_i.get(m)), minG)) { + active_size_i[i]--; + while (active_size_i[i] > m) { + if (!be_shrunk(i, active_size_i[i], y_index[i], + alpha_i.get(alpha_index_i.get(active_size_i[i])), minG)) + { + swap(alpha_index_i, m, active_size_i[i]); + swap(G, m, active_size_i[i]); + if (y_index[i] == active_size_i[i]) + y_index[i] = m; + else if (y_index[i] == m) + y_index[i] = active_size_i[i]; + break; + } + active_size_i[i]--; + } + } + } + + if (active_size_i[i] <= 1) { + active_size--; + swap(index, s, active_size); + s--; + continue; + } + + if (maxG - minG <= 1e-12) + continue; + else + stopping = Math.max(maxG - minG, stopping); + + for (m = 0; m < active_size_i[i]; m++) + B[m] = G[m] - Ai * alpha_i.get(alpha_index_i.get(m)); + + solve_sub_problem(Ai, y_index[i], C[GETI(i)], active_size_i[i], alpha_new); + int nz_d = 0; + for (m = 0; m < active_size_i[i]; m++) { + final double d = alpha_new[m] - alpha_i.get(alpha_index_i.get(m)); + alpha_i.set(alpha_index_i.get(m), alpha_new[m]); + if (Math.abs(d) >= 1e-12) { + d_ind[nz_d] = alpha_index_i.get(m); + d_val[nz_d] = d; + nz_d++; + } + } + + for (int ind = 0; ind < prob.x[i].length; ind++) { + // double *w_i = &w[ind*nr_class]; + final int w_offset = ind * nr_class; + for (m = 0; m < nz_d; m++) { + w[w_offset + d_ind[m]] += d_val[m] * prob.x[i][ind]; + } + } + } + } + + iter++; + + if (iter % 10 == 0) { + info("."); + } + + if (stopping < eps_shrink) { + if (stopping < eps && start_from_all == true) + break; + else { + active_size = l; + for (i = 0; i < l; i++) + active_size_i[i] = nr_class; + info("*"); + eps_shrink = Math.max(eps_shrink / 2, eps); + start_from_all = true; + } + } else + start_from_all = false; + } + + 
info("%noptimization finished, #iter = %d%n", iter); + if (iter >= max_iter) + info("%nWARNING: reaching max number of iterations%n"); + + // calculate objective value + double v = 0; + int nSV = 0; + for (i = 0; i < w_size * nr_class; i++) + v += w[i] * w[i]; + v = 0.5 * v; + for (i = 0; i < l * nr_class; i++) { + v += alpha[i]; + if (Math.abs(alpha[i]) > 0) + nSV++; + } + for (i = 0; i < l; i++) + v -= alpha[i * nr_class + (int) prob.y[i]]; + info("Objective value = %f%n", v); + info("nSV = %d%n", nSV); + + } + + private void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double[] alpha_new) { + + int r; + assert active_i <= B.length; // no padding + final double[] D = copyOf(B, active_i); + // clone(D, B, active_i); + + if (yi < active_i) + D[yi] += A_i * C_yi; + + // qsort(D, active_i, sizeof(double), compare_double); + ArraySorter.reversedMergesort(D); + + double beta = D[0] - A_i * C_yi; + for (r = 1; r < active_i && beta < r * D[r]; r++) + beta += D[r]; + beta /= r; + + for (r = 0; r < active_i; r++) { + if (r == yi) + alpha_new[r] = Math.min(C_yi, (beta - B[r]) / A_i); + else + alpha_new[r] = Math.min(0.0, (beta - B[r]) / A_i); + } + } +} diff --git a/src/main/java/de/bwaldvogel/liblinear/DenseTrain.java b/src/main/java/de/bwaldvogel/liblinear/DenseTrain.java new file mode 100644 index 0000000..4e3da6b --- /dev/null +++ b/src/main/java/de/bwaldvogel/liblinear/DenseTrain.java @@ -0,0 +1,420 @@ +package de.bwaldvogel.liblinear; + +import static de.bwaldvogel.liblinear.DenseLinear.atof; +import static de.bwaldvogel.liblinear.DenseLinear.atoi; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.StringTokenizer; + +public class DenseTrain { + + public static void main(String[] args) throws IOException, InvalidInputDataException { + new DenseTrain().run(args); + } + + 
private double bias = 1; + private boolean cross_validation = false; + private String inputFilename; + private String modelFilename; + private int nr_fold; + private Parameter param = null; + private DenseProblem prob = null; + + private void do_cross_validation() { + + double total_error = 0; + double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; + final double[] target = new double[prob.l]; + + long start, stop; + start = System.currentTimeMillis(); + DenseLinear.crossValidation(prob, param, nr_fold, target); + stop = System.currentTimeMillis(); + System.out.println("time: " + (stop - start) + " ms"); + + if (param.solverType.isSupportVectorRegression()) { + for (int i = 0; i < prob.l; i++) { + final double y = prob.y[i]; + final double v = target[i]; + total_error += (v - y) * (v - y); + sumv += v; + sumy += y; + sumvv += v * v; + sumyy += y * y; + sumvy += v * y; + } + System.out.printf("Cross Validation Mean squared error = %g%n", total_error / prob.l); + System.out.printf("Cross Validation Squared correlation coefficient = %g%n", // + ((prob.l * sumvy - sumv * sumy) * (prob.l * sumvy - sumv * sumy)) + / ((prob.l * sumvv - sumv * sumv) * (prob.l * sumyy - sumy * sumy))); + } else { + int total_correct = 0; + for (int i = 0; i < prob.l; i++) + if (target[i] == prob.y[i]) + ++total_correct; + + System.out.printf("correct: %d%n", total_correct); + System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l); + } + } + + private void exit_with_help() { + System.out.printf("Usage: train [options] training_set_file [model_file]%n" // + + "options:%n" + + "-s type : set type of solver (default 1)%n" + + " for multi-class classification%n" + + " 0 -- L2-regularized logistic regression (primal)%n" + + " 1 -- L2-regularized L2-loss support vector classification (dual)%n" + + " 2 -- L2-regularized L2-loss support vector classification (primal)%n" + + " 3 -- L2-regularized L1-loss support vector classification (dual)%n" + + " 4 -- support 
vector classification by Crammer and Singer%n" + + " 5 -- L1-regularized L2-loss support vector classification%n" + + " 6 -- L1-regularized logistic regression%n" + + " 7 -- L2-regularized logistic regression (dual)%n" + + " for regression%n" + + " 11 -- L2-regularized L2-loss support vector regression (primal)%n" + + " 12 -- L2-regularized L2-loss support vector regression (dual)%n" + + " 13 -- L2-regularized L1-loss support vector regression (dual)%n" + + "-c cost : set the parameter C (default 1)%n" + + "-p epsilon : set the epsilon in loss function of SVR (default 0.1)%n" + + "-e epsilon : set tolerance of termination criterion%n" + + " -s 0 and 2%n" + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n" + + " where f is the primal function and pos/neg are # of%n" + + " positive/negative data (default 0.01)%n" + " -s 11%n" + + " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)%n" + + " -s 1, 3, 4 and 7%n" + " Dual maximal violation <= eps; similar to libsvm (default 0.1)%n" + + " -s 5 and 6%n" + + " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,%n" + + " where f is the primal function (default 0.01)%n" + + " -s 12 and 13\n" + + " |f'(alpha)|_1 <= eps |f'(alpha0)|,\n" + + " where f is the dual function (default 0.1)\n" + + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n" + + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n" + + "-v n: n-fold cross validation mode%n" + + "-q : quiet mode (no outputs)%n"); + System.exit(1); + } + + DenseProblem getProblem() { + return prob; + } + + double getBias() { + return bias; + } + + Parameter getParameter() { + return param; + } + + void parse_command_line(String argv[]) { + int i; + + // eps: see setting below + param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY, 0.1); + // default values + bias = -1; + cross_validation = false; + + // parse options + for (i = 0; i < argv.length; i++) { + if (argv[i].charAt(0) != 
'-') + break; + if (++i >= argv.length) + exit_with_help(); + switch (argv[i - 1].charAt(1)) { + case 's': + param.solverType = SolverType.getById(atoi(argv[i])); + break; + case 'c': + param.setC(atof(argv[i])); + break; + case 'p': + param.setP(atof(argv[i])); + break; + case 'e': + param.setEps(atof(argv[i])); + break; + case 'B': + bias = atof(argv[i]); + break; + case 'w': + final int weightLabel = atoi(argv[i - 1].substring(2)); + final double weight = atof(argv[i]); + param.weightLabel = addToArray(param.weightLabel, weightLabel); + param.weight = addToArray(param.weight, weight); + break; + case 'v': + cross_validation = true; + nr_fold = atoi(argv[i]); + if (nr_fold < 2) { + System.err.println("n-fold cross validation: n must >= 2"); + exit_with_help(); + } + break; + case 'q': + i--; + DenseLinear.disableDebugOutput(); + break; + default: + System.err.println("unknown option"); + exit_with_help(); + } + } + + // determine filenames + + if (i >= argv.length) + exit_with_help(); + + inputFilename = argv[i]; + + if (i < argv.length - 1) + modelFilename = argv[i + 1]; + else { + int p = argv[i].lastIndexOf('/'); + ++p; // whew... 
+ modelFilename = argv[i].substring(p) + ".model"; + } + + if (param.eps == Double.POSITIVE_INFINITY) { + switch (param.solverType) { + case L2R_LR: + case L2R_L2LOSS_SVC: + param.setEps(0.01); + break; + case L2R_L2LOSS_SVR: + param.setEps(0.001); + break; + case L2R_L2LOSS_SVC_DUAL: + case L2R_L1LOSS_SVC_DUAL: + case MCSVM_CS: + case L2R_LR_DUAL: + param.setEps(0.1); + break; + case L1R_L2LOSS_SVC: + case L1R_LR: + param.setEps(0.01); + break; + case L2R_L1LOSS_SVR_DUAL: + case L2R_L2LOSS_SVR_DUAL: + param.setEps(0.1); + break; + default: + throw new IllegalStateException("unknown solver type: " + param.solverType); + } + } + } + + /** + * reads a problem from LibSVM format + * + * @param file + * the SVM file + * @throws IOException + * obviously in case of any I/O exception ;) + * @throws InvalidInputDataException + * if the input file is not correctly formatted + */ + static int readProblemFeatureDim(File file) throws IOException, InvalidInputDataException { + final BufferedReader fp = new BufferedReader(new FileReader(file)); + int max_index = 0; + int lineNr = 0; + + try { + while (true) { + final String line = fp.readLine(); + if (line == null) + break; + lineNr++; + + final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); + String token; + try { + token = st.nextToken(); + } catch (final NoSuchElementException e) { + throw new InvalidInputDataException("empty line", file, lineNr, e); + } + + final int m = st.countTokens() / 2; + + int indexBefore = 0; + for (int j = 0; j < m; j++) { + token = st.nextToken(); + int index; + try { + index = atoi(token); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); + } + + // assert that indices are valid and sorted + if (index < 0) + throw new InvalidInputDataException("invalid index: " + index, file, lineNr); + if (index <= indexBefore) + throw new InvalidInputDataException("indices must be sorted in ascending order", file, 
lineNr); + indexBefore = index; + + token = st.nextToken(); + + if (index > max_index) { + max_index = index; + } + } + } + + return max_index; + } finally { + fp.close(); + } + } + + /** + * reads a problem from LibSVM format into dense feature vectors; the file is read twice (first to find the largest feature index, then to fill the vectors) + * + * @param file the SVM file + * @param bias if >= 0, it is stored as an additional last feature of every instance + * @return the problem read from the file + * @throws IOException in case of any I/O exception + * @throws InvalidInputDataException + * if the input file is not correctly formatted + */ + public static DenseProblem readProblem(File file, double bias) throws IOException, InvalidInputDataException { + // pre-scan for the feature dimension before opening the reader, so that a failure here cannot leak it + final int w_size = readProblemFeatureDim(file); + final List<Double> vy = new ArrayList<Double>(); + final List<double[]> vx = new ArrayList<double[]>(); + int lineNr = 0; + + final BufferedReader fp = new BufferedReader(new FileReader(file)); + + try { + while (true) { + final String line = fp.readLine(); + if (line == null) + break; + lineNr++; + + final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); + String token; + try { + token = st.nextToken(); + } catch (final NoSuchElementException e) { + throw new InvalidInputDataException("empty line", file, lineNr, e); + } + + try { + vy.add(atof(token)); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e); + } + + final int m = st.countTokens() / 2; + double[] x; + if (bias >= 0) { + x = new double[w_size + 1]; + } else { + x = new double[w_size]; + } + int indexBefore = 0; + for (int j = 0; j < m; j++) { + + token = st.nextToken(); + int index; + try { + index = atoi(token); + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); + } + + // assert that indices are valid and sorted + if (index < 0) + throw new InvalidInputDataException("invalid index: " + index, file, lineNr); + if (index <= indexBefore) + throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr); + indexBefore = index; + + token = st.nextToken(); + try { + 
final double value = atof(token); + x[index - 1] = value; + } catch (final NumberFormatException e) { + throw new InvalidInputDataException("invalid value: " + token, file, lineNr, e); + } + } + + vx.add(x); + } + + return constructProblem(vy, vx, w_size, bias); + } finally { + fp.close(); + } + } + + void readProblem(String filename) throws IOException, InvalidInputDataException { + prob = DenseTrain.readProblem(new File(filename), bias); + } + + private static int[] addToArray(int[] array, int newElement) { + final int length = array != null ? array.length : 0; + final int[] newArray = new int[length + 1]; + if (array != null && length > 0) { + System.arraycopy(array, 0, newArray, 0, length); + } + newArray[length] = newElement; + return newArray; + } + + private static double[] addToArray(double[] array, double newElement) { + final int length = array != null ? array.length : 0; + final double[] newArray = new double[length + 1]; + if (array != null && length > 0) { + System.arraycopy(array, 0, newArray, 0, length); + } + newArray[length] = newElement; + return newArray; + } + + private static DenseProblem constructProblem(List<Double> vy, List<double[]> vx, int max_index, double bias) { + final DenseProblem prob = new DenseProblem(); + prob.bias = bias; + prob.l = vy.size(); + prob.n = max_index; + if (bias >= 0) { + prob.n++; + } + prob.x = new double[prob.l][]; + for (int i = 0; i < prob.l; i++) { + prob.x[i] = vx.get(i); + + if (bias >= 0) { + prob.x[i][max_index] = bias; + } + } + + prob.y = new double[prob.l]; + for (int i = 0; i < prob.l; i++) + prob.y[i] = vy.get(i).doubleValue(); + + return prob; + } + + private void run(String[] args) throws IOException, InvalidInputDataException { + parse_command_line(args); + readProblem(inputFilename); + if (cross_validation) + do_cross_validation(); + else { + final Model model = DenseLinear.train(prob, param); + DenseLinear.saveModel(new File(modelFilename), model); + } + } +} diff --git 
a/src/test/java/de/bwaldvogel/liblinear/DenseLinearTest.java b/src/test/java/de/bwaldvogel/liblinear/DenseLinearTest.java new file mode 100644 index 0000000..8d2e776 --- /dev/null +++ b/src/test/java/de/bwaldvogel/liblinear/DenseLinearTest.java @@ -0,0 +1,517 @@ +package de.bwaldvogel.liblinear; + +import static org.fest.assertions.Assertions.assertThat; +import static org.fest.assertions.Fail.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.File; +import java.io.IOException; +import java.io.Writer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.TreeSet; + +import org.fest.assertions.Delta; +import org.junit.BeforeClass; +import org.junit.Test; +import org.powermock.api.mockito.PowerMockito; + +public class DenseLinearTest { + + private static Random random = new Random(12345); + + @BeforeClass + public static void disableDebugOutput() { + DenseLinear.disableDebugOutput(); // silence the dense implementation exercised by these tests, not the sparse Linear + } + + public static Model createRandomModel() { + final Model model = new Model(); + model.solverType = SolverType.L2R_LR; + model.bias = 2; + model.label = new int[] { 1, Integer.MAX_VALUE, 2 }; + model.w = new double[model.label.length * 300]; + for (int i = 0; i < model.w.length; i++) { + // precision should be at least 1e-4 + model.w[i] = Math.round(random.nextDouble() * 100000.0) / 10000.0; + } + + // force at least one value to be zero + model.w[random.nextInt(model.w.length)] = 0.0; + model.w[random.nextInt(model.w.length)] = -0.0; + + model.nr_feature = model.w.length / model.label.length - 1; + model.nr_class = model.label.length; + return model; + } + + public static DenseProblem createRandomProblem(int numClasses) { + final DenseProblem prob = new DenseProblem(); + prob.bias = -1; + prob.l = random.nextInt(100) + 1; + prob.n = random.nextInt(100) + 1; + prob.x = new double[prob.l][]; 
+ prob.y = new double[prob.l]; + + for (int i = 0; i < prob.l; i++) { + + prob.y[i] = random.nextInt(numClasses); + + final Set<Integer> randomNumbers = new TreeSet<Integer>(); // distinct 0-based column indices that receive a random value + final int num = random.nextInt(prob.n); + for (int j = 0; j < num; j++) { + randomNumbers.add(random.nextInt(prob.n)); + } + final List<Integer> randomIndices = new ArrayList<Integer>(randomNumbers); + Collections.sort(randomIndices); + + prob.x[i] = new double[prob.n]; + for (int j = 0; j < randomIndices.size(); j++) { + prob.x[i][randomIndices.get(j)] = random.nextDouble(); + } + } + return prob; + } + + /** + * create a very simple problem and check if the clearly separated examples + * are recognized as such + */ + @Test + public void testTrainPredict() { + final DenseProblem prob = new DenseProblem(); + prob.bias = -1; + prob.l = 4; + prob.n = 4; + prob.x = new double[4][4]; + + prob.x[0][0] = 1; + prob.x[0][1] = 1; + + prob.x[1][2] = 1; + prob.x[2][2] = 1; + + prob.x[3][0] = 2; + prob.x[3][1] = 1; + prob.x[3][3] = 1; + + prob.y = new double[4]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + + for (final SolverType solver : SolverType.values()) { + for (double C = 0.1; C <= 100.; C *= 1.2) { + // compared the behavior with the C version + if (C < 0.2) + if (solver == SolverType.L1R_L2LOSS_SVC) + continue; + if (C < 0.7) + if (solver == SolverType.L1R_LR) + continue; + + if (solver.isSupportVectorRegression()) { + continue; + } + + final Parameter param = new Parameter(solver, C, 0.1, 0.1); + final Model model = DenseLinear.train(prob, param); + + final double[] featureWeights = model.getFeatureWeights(); + if (solver == SolverType.MCSVM_CS) { + assertThat(featureWeights.length).isEqualTo(8); + } else { + assertThat(featureWeights.length).isEqualTo(4); + } + + int i = 0; + for (final double value : prob.y) { + final double prediction = DenseLinear.predict(model, prob.x[i]); + assertThat(prediction).as("prediction with solver " + solver).isEqualTo(value); + if (model.isProbabilityModel()) { + final 
double[] estimates = new double[model.getNrClass()]; + final double probabilityPrediction = DenseLinear.predictProbability(model, prob.x[i], estimates); + assertThat(probabilityPrediction).isEqualTo(prediction); + assertThat(estimates[(int) probabilityPrediction]).isGreaterThanOrEqualTo( + 1.0 / model.getNrClass()); + double estimationSum = 0; + for (final double estimate : estimates) { + estimationSum += estimate; + } + assertThat(estimationSum).isEqualTo(1.0, Delta.delta(0.001)); + } + i++; + } + } + } + } + + @Test + public void testCrossValidation() throws Exception { + + final int numClasses = random.nextInt(10) + 1; + + final DenseProblem prob = createRandomProblem(numClasses); + + final Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.01); + final int nr_fold = 10; + final double[] target = new double[prob.l]; + DenseLinear.crossValidation(prob, param, nr_fold, target); + + for (final double clazz : target) { + assertThat(clazz).isGreaterThanOrEqualTo(0).isLessThan(numClasses); + } + } + + @Test + public void testLoadSaveModel() throws Exception { + + Model model = null; + for (final SolverType solverType : SolverType.values()) { + model = createRandomModel(); + model.solverType = solverType; + + final File tempFile = File.createTempFile("liblinear", "modeltest"); + tempFile.deleteOnExit(); + DenseLinear.saveModel(tempFile, model); + + final Model loadedModel = DenseLinear.loadModel(tempFile); + assertThat(loadedModel).isEqualTo(model); + } + } + + @Test + public void testPredictProbabilityWrongSolver() throws Exception { + final DenseProblem prob = new DenseProblem(); + prob.l = 1; + prob.n = 1; + prob.x = new double[prob.l][prob.n]; + prob.y = new double[prob.l]; + for (int i = 0; i < prob.l; i++) { + prob.y[i] = i; + } + + final SolverType solverType = SolverType.L2R_L1LOSS_SVC_DUAL; + final Parameter param = new Parameter(solverType, 10, 0.1); + final Model model = DenseLinear.train(prob, param); + try { + DenseLinear.predictProbability(model, 
prob.x[0], new double[1]); + fail("IllegalArgumentException expected"); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage()).isEqualTo("probability output is only supported for logistic regression." // + + " This is currently only supported by the following solvers:" // + + " L2R_LR, L1R_LR, L2R_LR_DUAL"); + } + } + + @Test + public void testRealloc() { + + int[] f = new int[] { 1, 2, 3 }; + f = DenseLinear.copyOf(f, 5); + f[3] = 4; + f[4] = 5; + assertThat(f).isEqualTo(new int[] { 1, 2, 3, 4, 5 }); + } + + @Test + public void testAtoi() { + assertThat(DenseLinear.atoi("+25")).isEqualTo(25); + assertThat(DenseLinear.atoi("-345345")).isEqualTo(-345345); + assertThat(DenseLinear.atoi("+0")).isEqualTo(0); + assertThat(DenseLinear.atoi("0")).isEqualTo(0); + assertThat(DenseLinear.atoi("2147483647")).isEqualTo(Integer.MAX_VALUE); + assertThat(DenseLinear.atoi("-2147483648")).isEqualTo(Integer.MIN_VALUE); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData() { + DenseLinear.atoi("+"); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData2() { + DenseLinear.atoi("abc"); + } + + @Test(expected = NumberFormatException.class) + public void testAtoiInvalidData3() { + DenseLinear.atoi(" "); + } + + @Test + public void testAtof() { + assertThat(DenseLinear.atof("+25")).isEqualTo(25); + assertThat(DenseLinear.atof("-25.12345678")).isEqualTo(-25.12345678); + assertThat(DenseLinear.atof("0.345345299")).isEqualTo(0.345345299); + } + + @Test(expected = NumberFormatException.class) + public void testAtofInvalidData() { + DenseLinear.atof("0.5t"); + } + + @Test + public void testSaveModelWithIOException() throws Exception { + final Model model = createRandomModel(); + + final Writer out = PowerMockito.mock(Writer.class); + + final IOException ioException = new IOException("some reason"); + + doThrow(ioException).when(out).flush(); + + try { + DenseLinear.saveModel(out, model); + 
fail("IOException expected"); + } catch (final IOException e) { + assertThat(e).isEqualTo(ioException); + } + + verify(out).flush(); + verify(out, times(1)).close(); + } + + /** + * compared input/output values with the C version (1.51) + * + *
+	 * IN:
+	 * res prob.l = 4
+	 * res prob.n = 4
+	 * 0: (2,1) (4,1)
+	 * 1: (1,1)
+	 * 2: (3,1)
+	 * 3: (2,2) (3,1) (4,1)
+	 * 
+	 * TRANSPOSED:
+	 * 
+	 * res prob.l = 4
+	 * res prob.n = 4
+	 * 0: (2,1)
+	 * 1: (1,1) (4,2)
+	 * 2: (3,1) (4,1)
+	 * 3: (1,1) (4,1)
+	 * 
+ */ + @Test + public void testTranspose() throws Exception { + final DenseProblem prob = new DenseProblem(); + prob.bias = -1; + prob.l = 4; + prob.n = 4; + prob.x = new double[4][4]; + + prob.x[0][1] = 1; + prob.x[0][3] = 1; + + prob.x[1][0] = 1; + prob.x[2][2] = 1; + + prob.x[3][1] = 2; + prob.x[3][2] = 1; + prob.x[3][3] = 1; + + prob.y = new double[4]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + + final DenseProblem transposed = DenseLinear.transpose(prob); + + assertThat(transposed.x[0].length).isEqualTo(4); + assertThat(transposed.x[1].length).isEqualTo(4); + assertThat(transposed.x[2].length).isEqualTo(4); + assertThat(transposed.x[3].length).isEqualTo(4); + + assertThat(transposed.x[0][1]).isEqualTo(1); + + assertThat(transposed.x[1][0]).isEqualTo(1); + assertThat(transposed.x[1][3]).isEqualTo(2); + + assertThat(transposed.x[2][2]).isEqualTo(1); + assertThat(transposed.x[2][3]).isEqualTo(1); + + assertThat(transposed.x[3][0]).isEqualTo(1); + assertThat(transposed.x[3][3]).isEqualTo(1); + + assertThat(transposed.y).isEqualTo(prob.y); + } + + /** + * + * compared input/output values with the C version (1.51) + * + *
+	 * IN:
+	 * res prob.l = 5
+	 * res prob.n = 10
+	 * 0: (1,7) (3,3) (5,2)
+	 * 1: (2,1) (4,5) (5,3) (7,4) (8,2)
+	 * 2: (1,9) (3,1) (5,1) (10,7)
+	 * 3: (1,2) (2,2) (3,9) (4,7) (5,8) (6,1) (7,5) (8,4)
+	 * 4: (3,1) (10,3)
+	 * 
+	 * TRANSPOSED:
+	 * 
+	 * res prob.l = 5
+	 * res prob.n = 10
+	 * 0: (1,7) (3,9) (4,2)
+	 * 1: (2,1) (4,2)
+	 * 2: (1,3) (3,1) (4,9) (5,1)
+	 * 3: (2,5) (4,7)
+	 * 4: (1,2) (2,3) (3,1) (4,8)
+	 * 5: (4,1)
+	 * 6: (2,4) (4,5)
+	 * 7: (2,2) (4,4)
+	 * 8:
+	 * 9: (3,7) (5,3)
+	 * 
+ */ + @Test + public void testTranspose2() throws Exception { + final DenseProblem prob = new DenseProblem(); + prob.bias = -1; + prob.l = 5; + prob.n = 10; + prob.x = new double[5][10]; + + prob.x[0][0] = 7; + prob.x[0][2] = 3; + prob.x[0][4] = 2; + + prob.x[1][1] = 1; + prob.x[1][3] = 5; + prob.x[1][4] = 3; + prob.x[1][6] = 4; + prob.x[1][7] = 2; + + prob.x[2][0] = 9; + prob.x[2][2] = 1; + prob.x[2][4] = 1; + prob.x[2][9] = 7; + + prob.x[3][0] = 2; + prob.x[3][1] = 2; + prob.x[3][2] = 9; + prob.x[3][3] = 7; + prob.x[3][4] = 8; + prob.x[3][5] = 1; + prob.x[3][6] = 5; + prob.x[3][7] = 4; + + prob.x[4][2] = 1; + prob.x[4][9] = 3; + + prob.y = new double[5]; + prob.y[0] = 0; + prob.y[1] = 1; + prob.y[2] = 1; + prob.y[3] = 0; + prob.y[4] = 1; + + final DenseProblem transposed = DenseLinear.transpose(prob); + + assertThat(transposed.x[0]).hasSize(5); + assertThat(transposed.x[1]).hasSize(5); + assertThat(transposed.x[2]).hasSize(5); + assertThat(transposed.x[3]).hasSize(5); + assertThat(transposed.x[4]).hasSize(5); + assertThat(transposed.x[5]).hasSize(5); + assertThat(transposed.x[7]).hasSize(5); + assertThat(transposed.x[7]).hasSize(5); + assertThat(transposed.x[8]).hasSize(5); + assertThat(transposed.x[9]).hasSize(5); + + assertThat(transposed.x[0][0]).isEqualTo(7); + assertThat(transposed.x[0][2]).isEqualTo(9); + assertThat(transposed.x[0][3]).isEqualTo(2); + + assertThat(transposed.x[1][1]).isEqualTo(1); + assertThat(transposed.x[1][3]).isEqualTo(2); + + assertThat(transposed.x[2][0]).isEqualTo(3); + assertThat(transposed.x[2][2]).isEqualTo(1); + assertThat(transposed.x[2][3]).isEqualTo(9); + assertThat(transposed.x[2][4]).isEqualTo(1); + + assertThat(transposed.x[3][1]).isEqualTo(5); + assertThat(transposed.x[3][3]).isEqualTo(7); + + assertThat(transposed.x[4][0]).isEqualTo(2); + assertThat(transposed.x[4][1]).isEqualTo(3); + assertThat(transposed.x[4][2]).isEqualTo(1); + assertThat(transposed.x[4][3]).isEqualTo(8); + + 
assertThat(transposed.x[5][3]).isEqualTo(1); + + assertThat(transposed.x[6][1]).isEqualTo(4); + assertThat(transposed.x[6][3]).isEqualTo(5); + + assertThat(transposed.x[7][1]).isEqualTo(2); + assertThat(transposed.x[7][3]).isEqualTo(4); + + assertThat(transposed.x[9][2]).isEqualTo(7); + assertThat(transposed.x[9][4]).isEqualTo(3); + + assertThat(transposed.y).isEqualTo(prob.y); + } + + /** + * compared input/output values with the C version (1.51) + * + * IN: res prob.l = 3 res prob.n = 4 0: (1,2) (3,1) (4,3) 1: (1,9) (2,7) + * (3,3) (4,3) 2: (2,1) + * + * TRANSPOSED: + * + * res prob.l = 3 * res prob.n = 4 0: (1,2) (2,9) 1: (2,7) (3,1) 2: (1,1) + * (2,3) 3: (1,3) (2,3) + * + */ + @Test + public void testTranspose3() throws Exception { + + final DenseProblem prob = new DenseProblem(); + prob.l = 3; + prob.n = 4; + prob.y = new double[3]; + prob.x = new double[3][4]; + + prob.x[0][0] = 2; + prob.x[0][2] = 1; + prob.x[0][3] = 3; + prob.x[1][0] = 9; + prob.x[1][1] = 7; + prob.x[1][2] = 3; + prob.x[1][3] = 3; + + prob.x[2][1] = 1; + + final DenseProblem transposed = DenseLinear.transpose(prob); + assertThat(transposed.x).hasSize(4); + assertThat(transposed.x[0]).hasSize(3); + assertThat(transposed.x[1]).hasSize(3); + assertThat(transposed.x[2]).hasSize(3); + assertThat(transposed.x[3]).hasSize(3); + + assertThat(transposed.x[0][0]).isEqualTo(2); + assertThat(transposed.x[0][1]).isEqualTo(9); + + assertThat(transposed.x[1][1]).isEqualTo(7); + assertThat(transposed.x[1][2]).isEqualTo(1); + + assertThat(transposed.x[2][0]).isEqualTo(1); + assertThat(transposed.x[2][1]).isEqualTo(3); + + assertThat(transposed.x[3][0]).isEqualTo(3); + assertThat(transposed.x[3][1]).isEqualTo(3); + } +} diff --git a/src/test/java/de/bwaldvogel/liblinear/DenseTrainTest.java b/src/test/java/de/bwaldvogel/liblinear/DenseTrainTest.java new file mode 100644 index 0000000..0f94bae --- /dev/null +++ b/src/test/java/de/bwaldvogel/liblinear/DenseTrainTest.java @@ -0,0 +1,213 @@ +package 
de.bwaldvogel.liblinear; + +import static org.fest.assertions.Assertions.assertThat; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.util.ArrayList; +import java.util.Collection; + +import org.junit.Test; + +import de.bwaldvogel.liblinear.DenseProblem; +import de.bwaldvogel.liblinear.DenseTrain; + +public class DenseTrainTest { + + @Test + public void testParseCommandLine() { + final DenseTrain train = new DenseTrain(); + + for (final SolverType solver : SolverType.values()) { + train.parse_command_line(new String[] { "-B", "5.3", "-s", "" + solver.getId(), "-p", "0.01", "model-filename" }); + final Parameter param = train.getParameter(); + assertThat(param.solverType).isEqualTo(solver); + // check default eps + if (solver.getId() == 0 || solver.getId() == 2 // + || solver.getId() == 5 || solver.getId() == 6) + { + assertThat(param.eps).isEqualTo(0.01); + } else if (solver.getId() == 7) { + assertThat(param.eps).isEqualTo(0.1); + } else if (solver.getId() == 11) { + assertThat(param.eps).isEqualTo(0.001); + } else { + assertThat(param.eps).isEqualTo(0.1); + } + // check if bias is set + assertThat(train.getBias()).isEqualTo(5.3); + assertThat(param.p).isEqualTo(0.01); + } + } + + @Test + // https://github.com/bwaldvogel/liblinear-java/issues/4 + public void + testParseWeights() throws Exception + { + final DenseTrain train = new DenseTrain(); + train.parse_command_line(new String[] { "-v", "10", "-c", "10", "-w1", "1.234", "model-filename" }); + Parameter parameter = train.getParameter(); + assertThat(parameter.weightLabel).isEqualTo(new int[] { 1 }); + assertThat(parameter.weight).isEqualTo(new double[] { 1.234 }); + + train.parse_command_line(new String[] { "-w1", "1.234", "-w2", "0.12", "-w3", "7", "model-filename" }); + parameter = train.getParameter(); + assertThat(parameter.weightLabel).isEqualTo(new int[] { 1, 2, 3 }); + assertThat(parameter.weight).isEqualTo(new double[] { 1.234, 0.12, 7 }); + } + + @Test + 
public void testReadProblem() throws Exception { + + final File file = File.createTempFile("svm", "test"); + file.deleteOnExit(); + + final Collection lines = new ArrayList(); + lines.add("1 1:1 3:1 4:1 6:1"); + lines.add("2 2:1 3:1 5:1 7:1"); + lines.add("1 3:1 5:1"); + lines.add("1 1:1 4:1 7:1"); + lines.add("2 4:1 5:1 7:1"); + final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); + try { + for (final String line : lines) + writer.append(line).append("\n"); + } finally { + writer.close(); + } + + final DenseTrain train = new DenseTrain(); + train.readProblem(file.getAbsolutePath()); + + final DenseProblem prob = train.getProblem(); + assertThat(prob.bias).isEqualTo(1); + assertThat(prob.y).hasSize(lines.size()); + assertThat(prob.y).isEqualTo(new double[] { 1, 2, 1, 1, 2 }); + assertThat(prob.n).isEqualTo(8); + assertThat(prob.l).isEqualTo(prob.y.length); + assertThat(prob.x).hasSize(prob.y.length); + + for (final double[] nodes : prob.x) { + + assertThat(nodes.length).isLessThanOrEqualTo(prob.n); + for (int ind = 0; ind < prob.n; ind++) { + // bias term + if (prob.bias >= 0 && ind == prob.n - 1) { + // assertThat(ind).isEqualTo(prob.n); + assertThat(nodes[ind]).isEqualTo(prob.bias); + } else { + assertThat(ind).isLessThan(prob.n); + } + } + } + } + + /** + * unit-test for Issue #1 + * (http://github.com/bwaldvogel/liblinear-java/issues#issue/1) + */ + @Test + public void testReadProblemEmptyLine() throws Exception { + + final File file = File.createTempFile("svm", "test"); + file.deleteOnExit(); + + final Collection lines = new ArrayList(); + lines.add("1 1:1 3:1 4:1 6:1"); + lines.add("2 "); + final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); + try { + for (final String line : lines) + writer.append(line).append("\n"); + } finally { + writer.close(); + } + + final DenseProblem prob = DenseTrain.readProblem(file, -1.0); + assertThat(prob.bias).isEqualTo(-1); + assertThat(prob.y).hasSize(lines.size()); + 
assertThat(prob.y).isEqualTo(new double[] { 1, 2 }); + assertThat(prob.n).isEqualTo(6); + assertThat(prob.l).isEqualTo(prob.y.length); + assertThat(prob.x).hasSize(prob.y.length); + + assertThat(prob.x[0]).hasSize(6); + assertThat(prob.x[1]).hasSize(6); + } + + @Test(expected = InvalidInputDataException.class) + public void testReadUnsortedProblem() throws Exception { + final File file = File.createTempFile("svm", "test"); + file.deleteOnExit(); + + final Collection lines = new ArrayList(); + lines.add("1 1:1 3:1 4:1 6:1"); + lines.add("2 2:1 3:1 5:1 7:1"); + lines.add("1 3:1 5:1 4:1"); // here's the mistake: not correctly + // sorted + + final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); + try { + for (final String line : lines) + writer.append(line).append("\n"); + } finally { + writer.close(); + } + + final DenseTrain train = new DenseTrain(); + train.readProblem(file.getAbsolutePath()); + } + + @Test(expected = InvalidInputDataException.class) + public void testReadProblemWithInvalidIndex() throws Exception { + final File file = File.createTempFile("svm", "test"); + file.deleteOnExit(); + + final Collection lines = new ArrayList(); + lines.add("1 1:1 3:1 4:1 6:1"); + lines.add("2 2:1 3:1 5:1 -4:1"); + + final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); + try { + for (final String line : lines) + writer.append(line).append("\n"); + } finally { + writer.close(); + } + + final DenseTrain train = new DenseTrain(); + try { + train.readProblem(file.getAbsolutePath()); + } catch (final InvalidInputDataException e) { + throw e; + } + } + + @Test(expected = InvalidInputDataException.class) + public void testReadWrongProblem() throws Exception { + final File file = File.createTempFile("svm", "test"); + file.deleteOnExit(); + + final Collection lines = new ArrayList(); + lines.add("1 1:1 3:1 4:1 6:1"); + lines.add("2 2:1 3:1 5:1 7:1"); + lines.add("1 3:1 5:a"); // here's the mistake: incomplete line + + final BufferedWriter 
writer = new BufferedWriter(new FileWriter(file)); + try { + for (final String line : lines) + writer.append(line).append("\n"); + } finally { + writer.close(); + } + + final DenseTrain train = new DenseTrain(); + try { + train.readProblem(file.getAbsolutePath()); + } catch (final InvalidInputDataException e) { + throw e; + } + } +} From 5825fbffba53792b284d46cca1f25ed3f4812041 Mon Sep 17 00:00:00 2001 From: Jonathon Hare Date: Wed, 17 Jul 2013 16:20:36 +0100 Subject: [PATCH 3/4] refactoring to reduce code duplication --- .../denseliblinear/ArraySorter.java | 91 - .../denseliblinear/DoubleArrayPointer.java | 27 - .../bwaldvogel/denseliblinear/Function.java | 13 - .../denseliblinear/IntArrayPointer.java | 27 - .../InvalidInputDataException.java | 57 - .../denseliblinear/L2R_L2_SvcFunction.java | 117 - .../denseliblinear/L2R_L2_SvrFunction.java | 67 - .../denseliblinear/L2R_LrFunction.java | 108 - .../de/bwaldvogel/denseliblinear/Linear.java | 1912 ----------------- .../de/bwaldvogel/denseliblinear/Model.java | 178 -- .../bwaldvogel/denseliblinear/Parameter.java | 120 -- .../de/bwaldvogel/denseliblinear/Predict.java | 193 -- .../de/bwaldvogel/denseliblinear/Problem.java | 62 - .../denseliblinear/SolverMCSVM_CS.java | 293 --- .../bwaldvogel/denseliblinear/SolverType.java | 129 -- .../de/bwaldvogel/denseliblinear/Train.java | 420 ---- .../de/bwaldvogel/denseliblinear/Tron.java | 260 --- .../denseliblinear/ArrayPointerTest.java | 63 - .../denseliblinear/ArraySorterTest.java | 58 - .../bwaldvogel/denseliblinear/LinearTest.java | 517 ----- .../denseliblinear/ParameterTest.java | 127 -- .../denseliblinear/PredictTest.java | 57 - .../bwaldvogel/denseliblinear/TrainTest.java | 210 -- 23 files changed, 5106 deletions(-) delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Function.java delete mode 100644 
src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/InvalidInputDataException.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvcFunction.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Linear.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Model.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Parameter.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Predict.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Problem.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/SolverType.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Train.java delete mode 100644 src/main/java/de/bwaldvogel/denseliblinear/Tron.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java delete mode 100644 src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java diff --git a/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java b/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java deleted file mode 100644 index 06e8e50..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/ArraySorter.java +++ /dev/null @@ -1,91 +0,0 @@ -package de.bwaldvogel.denseliblinear; - - -final class ArraySorter { - - /** - *

Sorts the specified array of doubles into descending order.

- * - * This code is borrowed from Sun's JDK 1.6.0.07 - */ - public static void reversedMergesort(double[] a) { - reversedMergesort(a, 0, a.length); - } - - private static void reversedMergesort(double x[], int off, int len) { - // Insertion sort on smallest arrays - if (len < 7) { - for (int i = off; i < len + off; i++) - for (int j = i; j > off && x[j - 1] < x[j]; j--) - swap(x, j, j - 1); - return; - } - - // Choose a partition element, v - int m = off + (len >> 1); // Small arrays, middle element - if (len > 7) { - int l = off; - int n = off + len - 1; - if (len > 40) { // Big arrays, pseudomedian of 9 - int s = len / 8; - l = med3(x, l, l + s, l + 2 * s); - m = med3(x, m - s, m, m + s); - n = med3(x, n - 2 * s, n - s, n); - } - m = med3(x, l, m, n); // Mid-size, med of 3 - } - double v = x[m]; - - // Establish Invariant: v* (v)* v* - int a = off, b = a, c = off + len - 1, d = c; - while (true) { - while (b <= c && x[b] >= v) { - if (x[b] == v) swap(x, a++, b); - b++; - } - while (c >= b && x[c] <= v) { - if (x[c] == v) swap(x, c, d--); - c--; - } - if (b > c) break; - swap(x, b++, c--); - } - - // Swap partition elements back to middle - int s, n = off + len; - s = Math.min(a - off, b - a); - vecswap(x, off, b - s, s); - s = Math.min(d - c, n - d - 1); - vecswap(x, b, n - s, s); - - // Recursively sort non-partition-elements - if ((s = b - a) > 1) reversedMergesort(x, off, s); - if ((s = d - c) > 1) reversedMergesort(x, n - s, s); - } - - /** - * Swaps x[a] with x[b]. - */ - private static void swap(double x[], int a, int b) { - double t = x[a]; - x[a] = x[b]; - x[b] = t; - } - - /** - * Swaps x[a .. (a+n-1)] with x[b .. (b+n-1)]. - */ - private static void vecswap(double x[], int a, int b, int n) { - for (int i = 0; i < n; i++, a++, b++) - swap(x, a, b); - } - - /** - * Returns the index of the median of the three indexed doubles. - */ - private static int med3(double x[], int a, int b, int c) { - return (x[a] < x[b] ? (x[b] < x[c] ? b : x[a] < x[c] ? 
c : a) : (x[b] > x[c] ? b : x[a] > x[c] ? c : a)); - } - - -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java b/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java deleted file mode 100644 index 1f6e1aa..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/DoubleArrayPointer.java +++ /dev/null @@ -1,27 +0,0 @@ -package de.bwaldvogel.denseliblinear; - - -final class DoubleArrayPointer { - - private final double[] _array; - private int _offset; - - - public void setOffset(int offset) { - if (offset < 0 || offset >= _array.length) throw new IllegalArgumentException("offset must be between 0 and the length of the array"); - _offset = offset; - } - - public DoubleArrayPointer( final double[] array, final int offset ) { - _array = array; - setOffset(offset); - } - - public double get(final int index) { - return _array[_offset + index]; - } - - public void set(final int index, final double value) { - _array[_offset + index] = value; - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Function.java b/src/main/java/de/bwaldvogel/denseliblinear/Function.java deleted file mode 100644 index 9a15c27..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Function.java +++ /dev/null @@ -1,13 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -// origin: tron.h -interface Function { - - double fun(double[] w); - - void grad(double[] w, double[] g); - - void Hv(double[] s, double[] Hs); - - int get_nr_variable(); -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java b/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java deleted file mode 100644 index f8635fd..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/IntArrayPointer.java +++ /dev/null @@ -1,27 +0,0 @@ -package de.bwaldvogel.denseliblinear; - - -final class IntArrayPointer { - - private final int[] _array; - private int _offset; - - - public void setOffset(int offset) { - if (offset < 0 || offset >= _array.length) throw 
new IllegalArgumentException("offset must be between 0 and the length of the array"); - _offset = offset; - } - - public IntArrayPointer( final int[] array, final int offset ) { - _array = array; - setOffset(offset); - } - - public int get(final int index) { - return _array[_offset + index]; - } - - public void set(final int index, final int value) { - _array[_offset + index] = value; - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/InvalidInputDataException.java b/src/main/java/de/bwaldvogel/denseliblinear/InvalidInputDataException.java deleted file mode 100644 index 5991a64..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/InvalidInputDataException.java +++ /dev/null @@ -1,57 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import java.io.File; - - -public class InvalidInputDataException extends Exception { - - private static final long serialVersionUID = 2945131732407207308L; - - private final int _line; - - private File _file; - - public InvalidInputDataException( String message, File file, int line ) { - super(message); - _file = file; - _line = line; - } - - public InvalidInputDataException( String message, String filename, int line ) { - this(message, new File(filename), line); - } - - public InvalidInputDataException( String message, File file, int lineNr, Exception cause ) { - super(message, cause); - _file = file; - _line = lineNr; - } - - public InvalidInputDataException( String message, String filename, int lineNr, Exception cause ) { - this(message, new File(filename), lineNr, cause); - } - - public File getFile() { - return _file; - } - - /** - * This methods returns the path of the file. - * The method name might be misleading. 
- * - * @deprecated use {@link #getFile()} instead - */ - public String getFilename() { - return _file.getPath(); - } - - public int getLine() { - return _line; - } - - @Override - public String toString() { - return super.toString() + " (" + _file + ":" + _line + ")"; - } - -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvcFunction.java b/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvcFunction.java deleted file mode 100644 index 2a13238..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvcFunction.java +++ /dev/null @@ -1,117 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -class L2R_L2_SvcFunction implements Function { - - protected final Problem prob; - protected final double[] C; - protected final int[] I; - protected final double[] z; - - protected int sizeI; - - public L2R_L2_SvcFunction(Problem prob, double[] C) { - final int l = prob.l; - - this.prob = prob; - - z = new double[l]; - I = new int[l]; - this.C = C; - } - - @Override - public double fun(double[] w) { - int i; - double f = 0; - final double[] y = prob.y; - final int l = prob.l; - final int w_size = get_nr_variable(); - - Xv(w, z); - - for (i = 0; i < w_size; i++) - f += w[i] * w[i]; - f /= 2.0; - for (i = 0; i < l; i++) { - z[i] = y[i] * z[i]; - final double d = 1 - z[i]; - if (d > 0) - f += C[i] * d * d; - } - - return (f); - } - - @Override - public int get_nr_variable() { - return prob.n; - } - - @Override - public void grad(double[] w, double[] g) { - final double[] y = prob.y; - final int l = prob.l; - final int w_size = get_nr_variable(); - - sizeI = 0; - for (int i = 0; i < l; i++) { - if (z[i] < 1) { - z[sizeI] = C[i] * y[i] * (z[i] - 1); - I[sizeI] = i; - sizeI++; - } - } - subXTv(z, g); - - for (int i = 0; i < w_size; i++) - g[i] = w[i] + 2 * g[i]; - } - - @Override - public void Hv(double[] s, double[] Hs) { - int i; - final int w_size = get_nr_variable(); - final double[] wa = new double[sizeI]; - - subXv(s, wa); - for (i = 0; i < sizeI; i++) 
- wa[i] = C[I[i]] * wa[i]; - - subXTv(wa, Hs); - for (i = 0; i < w_size; i++) - Hs[i] = s[i] + 2 * Hs[i]; - } - - protected void subXTv(double[] v, double[] XTv) { - int i; - final int w_size = get_nr_variable(); - - for (i = 0; i < w_size; i++) - XTv[i] = 0; - - for (i = 0; i < sizeI; i++) { - for (int j = 0; j < prob.x[I[i]].length; j++) { - XTv[j] += v[i] * prob.x[I[i]][j]; - } - } - } - - private void subXv(double[] v, double[] Xv) { - for (int i = 0; i < sizeI; i++) { - Xv[i] = 0; - - for (int j = 0; j < prob.x[I[i]].length; j++) { - Xv[i] += v[j] * prob.x[I[i]][j]; - } - } - } - - protected void Xv(double[] v, double[] Xv) { - for (int i = 0; i < prob.l; i++) { - Xv[i] = 0; - for (int j = 0; j < prob.x[i].length; j++) { - Xv[i] += v[j] * prob.x[i][j]; - } - } - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java b/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java deleted file mode 100644 index d4de914..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/L2R_L2_SvrFunction.java +++ /dev/null @@ -1,67 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -/** - * @since 1.91 - */ -public class L2R_L2_SvrFunction extends L2R_L2_SvcFunction { - - private double p; - - public L2R_L2_SvrFunction( Problem prob, double[] C, double p ) { - super(prob, C); - this.p = p; - } - - @Override - public double fun(double[] w) { - double f = 0; - double[] y = prob.y; - int l = prob.l; - int w_size = get_nr_variable(); - double d; - - Xv(w, z); - - for (int i = 0; i < w_size; i++) - f += w[i] * w[i]; - f /= 2; - for (int i = 0; i < l; i++) { - d = z[i] - y[i]; - if (d < -p) - f += C[i] * (d + p) * (d + p); - else if (d > p) f += C[i] * (d - p) * (d - p); - } - - return f; - } - - @Override - public void grad(double[] w, double[] g) { - double[] y = prob.y; - int l = prob.l; - int w_size = get_nr_variable(); - - sizeI = 0; - for (int i = 0; i < l; i++) { - double d = z[i] - y[i]; - - // generate index set I - if (d < -p) { - 
z[sizeI] = C[i] * (d + p); - I[sizeI] = i; - sizeI++; - } else if (d > p) { - z[sizeI] = C[i] * (d - p); - I[sizeI] = i; - sizeI++; - } - - } - subXTv(z, g); - - for (int i = 0; i < w_size; i++) - g[i] = w[i] + 2 * g[i]; - - } - -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java b/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java deleted file mode 100644 index faf68ce..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/L2R_LrFunction.java +++ /dev/null @@ -1,108 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -class L2R_LrFunction implements Function { - - private final double[] C; - private final double[] z; - private final double[] D; - private final Problem prob; - - public L2R_LrFunction(Problem prob, double[] C) { - final int l = prob.l; - - this.prob = prob; - - z = new double[l]; - D = new double[l]; - this.C = C; - } - - private void Xv(double[] v, double[] Xv) { - for (int i = 0; i < prob.l; i++) { - Xv[i] = 0; - for (int j = 0; j < prob.x[i].length; j++) { - Xv[i] += v[j] * prob.x[i][j]; - } - } - } - - private void XTv(double[] v, double[] XTv) { - final int l = prob.l; - final int w_size = get_nr_variable(); - final double[][] x = prob.x; - - for (int i = 0; i < w_size; i++) - XTv[i] = 0; - - for (int i = 0; i < l; i++) { - for (int j = 0; j < prob.x[i].length; j++) { - XTv[j] += v[i] * x[i][j]; - } - } - } - - @Override - public double fun(double[] w) { - int i; - double f = 0; - final double[] y = prob.y; - final int l = prob.l; - final int w_size = get_nr_variable(); - - Xv(w, z); - - for (i = 0; i < w_size; i++) - f += w[i] * w[i]; - f /= 2.0; - for (i = 0; i < l; i++) { - final double yz = y[i] * z[i]; - if (yz >= 0) - f += C[i] * Math.log(1 + Math.exp(-yz)); - else - f += C[i] * (-yz + Math.log(1 + Math.exp(yz))); - } - - return (f); - } - - @Override - public void grad(double[] w, double[] g) { - int i; - final double[] y = prob.y; - final int l = prob.l; - final int w_size = get_nr_variable(); - - 
for (i = 0; i < l; i++) { - z[i] = 1 / (1 + Math.exp(-y[i] * z[i])); - D[i] = z[i] * (1 - z[i]); - z[i] = C[i] * (z[i] - 1) * y[i]; - } - XTv(z, g); - - for (i = 0; i < w_size; i++) - g[i] = w[i] + g[i]; - } - - @Override - public void Hv(double[] s, double[] Hs) { - int i; - final int l = prob.l; - final int w_size = get_nr_variable(); - final double[] wa = new double[l]; - - Xv(s, wa); - for (i = 0; i < l; i++) - wa[i] = C[i] * D[i] * wa[i]; - - XTv(wa, Hs); - for (i = 0; i < w_size; i++) - Hs[i] = s[i] + Hs[i]; - // delete[] wa; - } - - @Override - public int get_nr_variable() { - return prob.n; - } - -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Linear.java b/src/main/java/de/bwaldvogel/denseliblinear/Linear.java deleted file mode 100644 index f2f8029..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Linear.java +++ /dev/null @@ -1,1912 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.io.PrintStream; -import java.io.Reader; -import java.io.Writer; -import java.nio.charset.Charset; -import java.util.Formatter; -import java.util.Locale; -import java.util.Random; -import java.util.regex.Pattern; - -/** - *

Java port of liblinear

- * - *

- * The usage should be pretty similar to the C version of liblinear. - *

- *

- * Please consider reading the README file of liblinear. - *

- * - *

- * The port was done by Benedikt Waldvogel (mail at bwaldvogel.de) - *

- * - * @version 1.92 - */ -public class Linear { - - static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); - - static final Locale DEFAULT_LOCALE = Locale.ENGLISH; - - private static Object OUTPUT_MUTEX = new Object(); - private static PrintStream DEBUG_OUTPUT = System.out; - - private static final long DEFAULT_RANDOM_SEED = 0L; - static Random random = new Random(DEFAULT_RANDOM_SEED); - - /** - * @param target - * predicted classes - */ - public static void crossValidation(Problem prob, Parameter param, int nr_fold, double[] target) { - int i; - final int[] fold_start = new int[nr_fold + 1]; - final int l = prob.l; - final int[] perm = new int[l]; - - for (i = 0; i < l; i++) - perm[i] = i; - for (i = 0; i < l; i++) { - final int j = i + random.nextInt(l - i); - swap(perm, i, j); - } - for (i = 0; i <= nr_fold; i++) - fold_start[i] = i * l / nr_fold; - - for (i = 0; i < nr_fold; i++) { - final int begin = fold_start[i]; - final int end = fold_start[i + 1]; - int j, k; - final Problem subprob = new Problem(); - - subprob.bias = prob.bias; - subprob.n = prob.n; - subprob.l = l - (end - begin); - subprob.x = new double[subprob.l][]; - subprob.y = new double[subprob.l]; - - k = 0; - for (j = 0; j < begin; j++) { - subprob.x[k] = prob.x[perm[j]]; - subprob.y[k] = prob.y[perm[j]]; - ++k; - } - for (j = end; j < l; j++) { - subprob.x[k] = prob.x[perm[j]]; - subprob.y[k] = prob.y[perm[j]]; - ++k; - } - final Model submodel = train(subprob, param); - for (j = begin; j < end; j++) - target[perm[j]] = predict(submodel, prob.x[perm[j]]); - } - } - - /** used as complex return type */ - private static class GroupClassesReturn { - - final int[] count; - final int[] label; - final int nr_class; - final int[] start; - - GroupClassesReturn(int nr_class, int[] label, int[] start, int[] count) { - this.nr_class = nr_class; - this.label = label; - this.start = start; - this.count = count; - } - } - - private static GroupClassesReturn groupClasses(Problem prob, int[] 
perm) { - final int l = prob.l; - int max_nr_class = 16; - int nr_class = 0; - - int[] label = new int[max_nr_class]; - int[] count = new int[max_nr_class]; - final int[] data_label = new int[l]; - int i; - - for (i = 0; i < l; i++) { - final int this_label = (int) prob.y[i]; - int j; - for (j = 0; j < nr_class; j++) { - if (this_label == label[j]) { - ++count[j]; - break; - } - } - data_label[i] = j; - if (j == nr_class) { - if (nr_class == max_nr_class) { - max_nr_class *= 2; - label = copyOf(label, max_nr_class); - count = copyOf(count, max_nr_class); - } - label[nr_class] = this_label; - count[nr_class] = 1; - ++nr_class; - } - } - - final int[] start = new int[nr_class]; - start[0] = 0; - for (i = 1; i < nr_class; i++) - start[i] = start[i - 1] + count[i - 1]; - for (i = 0; i < l; i++) { - perm[start[data_label[i]]] = i; - ++start[data_label[i]]; - } - start[0] = 0; - for (i = 1; i < nr_class; i++) - start[i] = start[i - 1] + count[i - 1]; - - return new GroupClassesReturn(nr_class, label, start, count); - } - - static void info(String message) { - synchronized (OUTPUT_MUTEX) { - if (DEBUG_OUTPUT == null) - return; - DEBUG_OUTPUT.printf(message); - DEBUG_OUTPUT.flush(); - } - } - - static void info(String format, Object... 
args) { - synchronized (OUTPUT_MUTEX) { - if (DEBUG_OUTPUT == null) - return; - DEBUG_OUTPUT.printf(format, args); - DEBUG_OUTPUT.flush(); - } - } - - /** - * @param s - * the string to parse for the double value - * @throws IllegalArgumentException - * if s is empty or represents NaN or Infinity - * @throws NumberFormatException - * see {@link Double#parseDouble(String)} - */ - static double atof(String s) { - if (s == null || s.length() < 1) - throw new IllegalArgumentException("Can't convert empty string to integer"); - final double d = Double.parseDouble(s); - if (Double.isNaN(d) || Double.isInfinite(d)) { - throw new IllegalArgumentException("NaN or Infinity in input: " + s); - } - return (d); - } - - /** - * @param s - * the string to parse for the integer value - * @throws IllegalArgumentException - * if s is empty - * @throws NumberFormatException - * see {@link Integer#parseInt(String)} - */ - static int atoi(String s) throws NumberFormatException { - if (s == null || s.length() < 1) - throw new IllegalArgumentException("Can't convert empty string to integer"); - // Integer.parseInt doesn't accept '+' prefixed strings - if (s.charAt(0) == '+') - s = s.substring(1); - return Integer.parseInt(s); - } - - /** - * Java5 'backport' of Arrays.copyOf - */ - public static double[] copyOf(double[] original, int newLength) { - final double[] copy = new double[newLength]; - System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); - return copy; - } - - /** - * Java5 'backport' of Arrays.copyOf - */ - public static int[] copyOf(int[] original, int newLength) { - final int[] copy = new int[newLength]; - System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); - return copy; - } - - /** - * Loads the model from inputReader. It uses - * {@link java.util.Locale#ENGLISH} for number formatting. - * - *

- * Note: The inputReader is NOT closed after reading or in case of an - * exception. - *

- */ - public static Model loadModel(Reader inputReader) throws IOException { - final Model model = new Model(); - - model.label = null; - - final Pattern whitespace = Pattern.compile("\\s+"); - - BufferedReader reader = null; - if (inputReader instanceof BufferedReader) { - reader = (BufferedReader) inputReader; - } else { - reader = new BufferedReader(inputReader); - } - - String line = null; - while ((line = reader.readLine()) != null) { - final String[] split = whitespace.split(line); - if (split[0].equals("solver_type")) { - final SolverType solver = SolverType.valueOf(split[1]); - if (solver == null) { - throw new RuntimeException("unknown solver type"); - } - model.solverType = solver; - } else if (split[0].equals("nr_class")) { - model.nr_class = atoi(split[1]); - Integer.parseInt(split[1]); - } else if (split[0].equals("nr_feature")) { - model.nr_feature = atoi(split[1]); - } else if (split[0].equals("bias")) { - model.bias = atof(split[1]); - } else if (split[0].equals("w")) { - break; - } else if (split[0].equals("label")) { - model.label = new int[model.nr_class]; - for (int i = 0; i < model.nr_class; i++) { - model.label[i] = atoi(split[i + 1]); - } - } else { - throw new RuntimeException("unknown text in model file: [" + line + "]"); - } - } - - int w_size = model.nr_feature; - if (model.bias >= 0) - w_size++; - - int nr_w = model.nr_class; - if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) - nr_w = 1; - - model.w = new double[w_size * nr_w]; - final int[] buffer = new int[128]; - - for (int i = 0; i < w_size; i++) { - for (int j = 0; j < nr_w; j++) { - int b = 0; - while (true) { - final int ch = reader.read(); - if (ch == -1) { - throw new EOFException("unexpected EOF"); - } - if (ch == ' ') { - model.w[i * nr_w + j] = atof(new String(buffer, 0, b)); - break; - } else { - buffer[b++] = ch; - } - } - } - } - - return model; - } - - /** - * Loads the model from the file with ISO-8859-1 charset. 
It uses - * {@link java.util.Locale#ENGLISH} for number formatting. - */ - public static Model loadModel(File modelFile) throws IOException { - final BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), - FILE_CHARSET)); - try { - return loadModel(inputReader); - } finally { - inputReader.close(); - } - } - - static void closeQuietly(Closeable c) { - if (c == null) - return; - try { - c.close(); - } catch (final Throwable t) { - } - } - - public static double predict(Model model, double[] x) { - final double[] dec_values = new double[model.nr_class]; - return predictValues(model, x, dec_values); - } - - /** - * @throws IllegalArgumentException - * if model is not probabilistic (see - * {@link Model#isProbabilityModel()}) - */ - public static double predictProbability(Model model, double[] x, double[] prob_estimates) - throws IllegalArgumentException - { - if (!model.isProbabilityModel()) { - final StringBuilder sb = new StringBuilder("probability output is only supported for logistic regression"); - sb.append(". This is currently only supported by the following solvers: "); - int i = 0; - for (final SolverType solverType : SolverType.values()) { - if (solverType.isLogisticRegressionSolver()) { - if (i++ > 0) { - sb.append(", "); - } - sb.append(solverType.name()); - } - } - throw new IllegalArgumentException(sb.toString()); - } - final int nr_class = model.nr_class; - int nr_w; - if (nr_class == 2) - nr_w = 1; - else - nr_w = nr_class; - - final double label = predictValues(model, x, prob_estimates); - for (int i = 0; i < nr_w; i++) - prob_estimates[i] = 1 / (1 + Math.exp(-prob_estimates[i])); - - if (nr_class == 2) // for binary classification - prob_estimates[1] = 1. 
- prob_estimates[0]; - else { - double sum = 0; - for (int i = 0; i < nr_class; i++) - sum += prob_estimates[i]; - - for (int i = 0; i < nr_class; i++) - prob_estimates[i] = prob_estimates[i] / sum; - } - - return label; - } - - public static double predictValues(Model model, double[] x, double[] dec_values) { - int n; - if (model.bias >= 0) - n = model.nr_feature + 1; - else - n = model.nr_feature; - - final double[] w = model.w; - - int nr_w; - if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) - nr_w = 1; - else - nr_w = model.nr_class; - - for (int i = 0; i < nr_w; i++) - dec_values[i] = 0; - - for (int idx = 0; idx < n; idx++) { - // the dimension of testing data may exceed that of training - for (int i = 0; i < nr_w; i++) { - dec_values[i] += w[idx * nr_w + i] * x[idx]; - } - } - - if (model.nr_class == 2) { - if (model.solverType.isSupportVectorRegression()) - return dec_values[0]; - else - return (dec_values[0] > 0) ? model.label[0] : model.label[1]; - } else { - int dec_max_idx = 0; - for (int i = 1; i < model.nr_class; i++) { - if (dec_values[i] > dec_values[dec_max_idx]) - dec_max_idx = i; - } - return model.label[dec_max_idx]; - } - } - - static void printf(Formatter formatter, String format, Object... args) throws IOException { - formatter.format(format, args); - final IOException ioException = formatter.ioException(); - if (ioException != null) - throw ioException; - } - - /** - * Writes the model to the modelOutput. It uses - * {@link java.util.Locale#ENGLISH} for number formatting. - * - *

- * Note: The modelOutput is closed after writing or in case of an - * exception. - *

- */ - public static void saveModel(Writer modelOutput, Model model) throws IOException { - final int nr_feature = model.nr_feature; - int w_size = nr_feature; - if (model.bias >= 0) - w_size++; - - int nr_w = model.nr_class; - if (model.nr_class == 2 && model.solverType != SolverType.MCSVM_CS) - nr_w = 1; - - final Formatter formatter = new Formatter(modelOutput, DEFAULT_LOCALE); - try { - printf(formatter, "solver_type %s\n", model.solverType.name()); - printf(formatter, "nr_class %d\n", model.nr_class); - - if (model.label != null) { - printf(formatter, "label"); - for (int i = 0; i < model.nr_class; i++) { - printf(formatter, " %d", model.label[i]); - } - printf(formatter, "\n"); - } - - printf(formatter, "nr_feature %d\n", nr_feature); - printf(formatter, "bias %.16g\n", model.bias); - - printf(formatter, "w\n"); - for (int i = 0; i < w_size; i++) { - for (int j = 0; j < nr_w; j++) { - final double value = model.w[i * nr_w + j]; - - /** - * this optimization is the reason for - * {@link Model#equals(double[], double[])} - */ - if (value == 0.0) { - printf(formatter, "%d ", 0); - } else { - printf(formatter, "%.16g ", value); - } - } - printf(formatter, "\n"); - } - - formatter.flush(); - final IOException ioException = formatter.ioException(); - if (ioException != null) - throw ioException; - } finally { - formatter.close(); - } - } - - /** - * Writes the model to the file with ISO-8859-1 charset. It uses - * {@link java.util.Locale#ENGLISH} for number formatting. 
- */ - public static void saveModel(File modelFile, Model model) throws IOException { - final BufferedWriter modelOutput = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), - FILE_CHARSET)); - saveModel(modelOutput, model); - } - - /* - * this method corresponds to the following define in the C version: #define - * GETI(i) (y[i]+1) - */ - private static int GETI(byte[] y, int i) { - return y[i] + 1; - } - - /** - * A coordinate descent algorithm for L1-loss and L2-loss SVM dual problems - * - *
-	 *  min_\alpha  0.5(\alpha^T (Q + D)\alpha) - e^T \alpha,
-	 *    s.t.      0 <= \alpha_i <= upper_bound_i,
-	 * 
-	 *  where Qij = yi yj xi^T xj and
-	 *  D is a diagonal matrix
-	 * 
-	 * In L1-SVM case:
-	 *     upper_bound_i = Cp if y_i = 1
-	 *      upper_bound_i = Cn if y_i = -1
-	 *      D_ii = 0
-	 * In L2-SVM case:
-	 *      upper_bound_i = INF
-	 *      D_ii = 1/(2*Cp) if y_i = 1
-	 *      D_ii = 1/(2*Cn) if y_i = -1
-	 * 
-	 * Given:
-	 * x, y, Cp, Cn
-	 * eps is the stopping tolerance
-	 * 
-	 * solution will be put in w
-	 * 
-	 * See Algorithm 3 of Hsieh et al., ICML 2008
-	 * 
*/
    private static void solve_l2r_l1l2_svc(Problem prob, double[] w, double eps, double Cp, double Cn,
        SolverType solver_type)
    {
        final int l = prob.l;
        final int w_size = prob.n;
        int i, s, iter = 0;
        double C, d, G;
        final double[] QD = new double[l];
        final int max_iter = 1000;
        final int[] index = new int[l];
        final double[] alpha = new double[l];
        final byte[] y = new byte[l];
        int active_size = l;

        // PG: projected gradient, for shrinking and stopping
        double PG;
        double PGmax_old = Double.POSITIVE_INFINITY;
        double PGmin_old = Double.NEGATIVE_INFINITY;
        double PGmax_new, PGmin_new;

        // default solver_type: L2R_L2LOSS_SVC_DUAL
        // (both arrays are indexed by GETI, i.e. 0 for y = -1 and 2 for y = +1)
        final double diag[] = new double[] { 0.5 / Cn, 0, 0.5 / Cp };
        final double upper_bound[] = new double[] { Double.POSITIVE_INFINITY, 0, Double.POSITIVE_INFINITY };
        if (solver_type == SolverType.L2R_L1LOSS_SVC_DUAL) {
            diag[0] = 0;
            diag[2] = 0;
            upper_bound[0] = Cn;
            upper_bound[2] = Cp;
        }

        // reduce the labels to their sign
        for (i = 0; i < l; i++) {
            if (prob.y[i] > 0) {
                y[i] = +1;
            } else {
                y[i] = -1;
            }
        }

        // Initial alpha can be set here. Note that
        // 0 <= alpha[i] <= upper_bound[GETI(i)]
        for (i = 0; i < l; i++)
            alpha[i] = 0;

        for (i = 0; i < w_size; i++)
            w[i] = 0;
        for (i = 0; i < l; i++) {
            // QD[i] = D_ii + xi^T xi
            QD[i] = diag[GETI(y, i)];

            for (int j = 0; j < w_size; j++) {
                final double val = prob.x[i][j];
                QD[i] += val * val;
                w[j] += y[i] * alpha[i] * val;
            }
            index[i] = i;
        }

        while (iter < max_iter) {
            PGmax_new = Double.NEGATIVE_INFINITY;
            PGmin_new = Double.POSITIVE_INFINITY;

            // random permutation of the active instances
            for (i = 0; i < active_size; i++) {
                final int j = i + random.nextInt(active_size - i);
                swap(index, i, j);
            }

            for (s = 0; s < active_size; s++) {
                i = index[s];
                G = 0;
                final byte yi = y[i];

                for (int j = 0; j < w_size; j++) {
                    G += w[j] * prob.x[i][j];
                }
                G = G * yi - 1;

                C = upper_bound[GETI(y, i)];
                G += alpha[i] * diag[GETI(y, i)];

                PG = 0;
                if (alpha[i] == 0) {
                    if (G > PGmax_old) {
                        // shrink: move instance i out of the active set and
                        // re-examine the element swapped into position s
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    } else if (G < 0) {
                        PG = G;
                    }
                } else if (alpha[i] == C) {
                    if (G < PGmin_old) {
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    } else if (G > 0) {
                        PG = G;
                    }
                } else {
                    PG = G;
                }

                PGmax_new = Math.max(PGmax_new, PG);
                PGmin_new = Math.min(PGmin_new, PG);

                if (Math.abs(PG) > 1.0e-12) {
                    // one-variable update, clipped to [0, C], then propagate
                    // the change of alpha[i] into w
                    final double alpha_old = alpha[i];
                    alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0.0), C);
                    d = (alpha[i] - alpha_old) * yi;

                    for (int j = 0; j < w_size; j++) {
                        w[j] += d * prob.x[i][j];
                    }
                }
            }

            iter++;
            if (iter % 10 == 0)
                info(".");

            if (PGmax_new - PGmin_new <= eps) {
                if (active_size == l)
                    break;
                else {
                    // converged on the shrunken set only: reactivate all
                    // instances and run another pass
                    active_size = l;
                    info("*");
                    PGmax_old = Double.POSITIVE_INFINITY;
                    PGmin_old = Double.NEGATIVE_INFINITY;
                    continue;
                }
            }
            PGmax_old = PGmax_new;
            PGmin_old = PGmin_new;
            if (PGmax_old <= 0)
                PGmax_old = Double.POSITIVE_INFINITY;
            if (PGmin_old >= 0)
                PGmin_old = Double.NEGATIVE_INFINITY;
        }

        info("%noptimization finished, #iter = %d%n", iter);
        if (iter >= max_iter)
            info("%nWARNING: reaching max number of iterations%nUsing -s 2 may be faster (also see FAQ)%n%n");

        // calculate objective value

        double v = 0;
        int nSV = 0;
        for (i = 0; i < w_size; i++)
            v += w[i] * w[i];
        for (i = 0; i < l; i++) {
            v += alpha[i] * (alpha[i] * diag[GETI(y, i)] - 2);
            if (alpha[i] > 0)
                ++nSV;
        }
        info("Objective value = %g%n", v / 2);
        info("nSV = %d%n", nSV);
    }

    // To support weights for instances, use GETI(i) (i)
    // (as written, every instance maps to index 0 of lambda / upper_bound)
    private static int GETI_SVR(int i) {
        return 0;
    }

    /**
     * A coordinate descent algorithm for L1-loss and L2-loss epsilon-SVR dual
     * problem
     *
     * <pre>
     *  min_\beta  0.5\beta^T (Q + diag(lambda)) \beta - p \sum_{i=1}^l|\beta_i| + \sum_{i=1}^l yi\beta_i,
     *    s.t.     -upper_bound_i <= \beta_i <= upper_bound_i,
     *
     *  where Qij = xi^T xj and
     *  D is a diagonal matrix
     *
     * In L1-SVM case:
     *     upper_bound_i = C
     *     lambda_i = 0
     * In L2-SVM case:
     *     upper_bound_i = INF
     *     lambda_i = 1/(2*C)
     *
     * Given:
     * x, y, p, C
     * eps is the stopping tolerance
     *
     * solution will be put in w
     *
     * See Algorithm 4 of Ho and Lin, 2012
     * </pre>
     */
    private static void solve_l2r_l1l2_svr(Problem prob, double[] w, Parameter param) {
        final int l = prob.l;
        final double C = param.C;
        final double p = param.p;
        final int w_size = prob.n;
        final double eps = param.eps;
        int i, s, iter = 0;
        final int max_iter = 1000;
        int active_size = l;
        final int[] index = new int[l];

        double d, G, H;
        double Gmax_old = Double.POSITIVE_INFINITY;
        double Gmax_new, Gnorm1_new;
        double Gnorm1_init = 0; // initialize to 0 to get rid of Eclipse
                                // warning/error
        final double[] beta = new double[l];
        final double[] QD = new double[l];
        final double[] y = prob.y;

        // L2R_L2LOSS_SVR_DUAL
        final double[] lambda = new double[] { 0.5 / C };
        final double[] upper_bound = new double[] { Double.POSITIVE_INFINITY };

        if (param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL) {
            lambda[0] = 0;
            upper_bound[0] = C;
        }

        // Initial beta can be set here. Note that
        // -upper_bound <= beta[i] <= upper_bound
        for (i = 0; i < l; i++)
            beta[i] = 0;

        for (i = 0; i < w_size; i++)
            w[i] = 0;
        for (i = 0; i < l; i++) {
            // QD[i] = xi^T xi
            QD[i] = 0;
            for (int j = 0; j < w_size; j++) {
                final double val = prob.x[i][j];
                QD[i] += val * val;
                w[j] += beta[i] * val;
            }

            index[i] = i;
        }

        while (iter < max_iter) {
            Gmax_new = 0;
            Gnorm1_new = 0;

            // random permutation of the active instances
            for (i = 0; i < active_size; i++) {
                final int j = i + random.nextInt(active_size - i);
                swap(index, i, j);
            }

            for (s = 0; s < active_size; s++) {
                i = index[s];
                G = -y[i] + lambda[GETI_SVR(i)] * beta[i];
                H = QD[i] + lambda[GETI_SVR(i)];

                for (int ind = 0; ind < w_size; ind++) {
                    final double val = prob.x[i][ind];
                    G += val * w[ind];
                }

                final double Gp = G + p;
                final double Gn = G - p;
                double violation = 0;
                if (beta[i] == 0) {
                    if (Gp < 0)
                        violation = -Gp;
                    else if (Gn > 0)
                        violation = Gn;
                    else if (Gp > Gmax_old && Gn < -Gmax_old) {
                        // shrink: drop instance i from the active set and
                        // re-examine the element swapped into position s
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    }
                } else if (beta[i] >= upper_bound[GETI_SVR(i)]) {
                    if (Gp > 0)
                        violation = Gp;
                    else if (Gp < -Gmax_old) {
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    }
                } else if (beta[i] <= -upper_bound[GETI_SVR(i)]) {
                    if (Gn < 0)
                        violation = -Gn;
                    else if (Gn > Gmax_old) {
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    }
                } else if (beta[i] > 0)
                    violation = Math.abs(Gp);
                else
                    violation = Math.abs(Gn);

                Gmax_new = Math.max(Gmax_new, violation);
                Gnorm1_new += violation;

                // obtain Newton direction d
                if (Gp < H * beta[i])
                    d = -Gp / H;
                else if (Gn > H * beta[i])
                    d = -Gn / H;
                else
                    d = -beta[i];

                if (Math.abs(d) < 1.0e-12)
                    continue;

                // clip the step to the box and propagate the actual change
                // of beta[i] into w
                final double beta_old = beta[i];
                beta[i] = Math.min(Math.max(beta[i] + d, -upper_bound[GETI_SVR(i)]), upper_bound[GETI_SVR(i)]);
                d = beta[i] - beta_old;

                if (d != 0) {
                    for (int j = 0; j < w_size; j++) {
                        w[j] += d * prob.x[i][j];
                    }
                }
            }

            if (iter == 0)
                Gnorm1_init = Gnorm1_new;
            iter++;
            if (iter % 10 == 0)
                info(".");

            // stop relative to the violation measured in the first pass
            if (Gnorm1_new <= eps * Gnorm1_init) {
                if (active_size == l)
                    break;
                else {
                    // converged on the shrunken set only: reactivate all
                    // instances and run another pass
                    active_size = l;
                    info("*");
                    Gmax_old = Double.POSITIVE_INFINITY;
                    continue;
                }
            }

            Gmax_old = Gmax_new;
        }

        info("%noptimization finished, #iter = %d%n", iter);
        if (iter >= max_iter)
            info("%nWARNING: reaching max number of iterations%nUsing -s 11 may be faster%n%n");

        // calculate objective value
        double v = 0;
        int nSV = 0;
        for (i = 0; i < w_size; i++)
            v += w[i] * w[i];
        v = 0.5 * v;
        for (i = 0; i < l; i++) {
            v += p * Math.abs(beta[i]) - y[i] * beta[i] + 0.5 * lambda[GETI_SVR(i)] * beta[i] * beta[i];
            if (beta[i] != 0)
                nSV++;
        }

        info("Objective value = %g%n", v);
        info("nSV = %d%n", nSV);
    }

    /**
     * A coordinate descent algorithm for the dual of L2-regularized logistic
     * regression problems
     *
     *
-	 *  min_\alpha  0.5(\alpha^T Q \alpha) + \sum \alpha_i log (\alpha_i) + (upper_bound_i - \alpha_i) log (upper_bound_i - \alpha_i) ,
-	 *     s.t.      0 <= \alpha_i <= upper_bound_i,
-	 * 
-	 *  where Qij = yi yj xi^T xj and
-	 *  upper_bound_i = Cp if y_i = 1
-	 *  upper_bound_i = Cn if y_i = -1
-	 * 
-	 * Given:
-	 * x, y, Cp, Cn
-	 * eps is the stopping tolerance
-	 * 
-	 * solution will be put in w
-	 * 
-	 * See Algorithm 5 of Yu et al., MLJ 2010
-	 * 
*
     * @since 1.7
     */
    private static void solve_l2r_lr_dual(Problem prob, double w[], double eps, double Cp, double Cn) {
        final int l = prob.l;
        final int w_size = prob.n;
        int i, s, iter = 0;
        final double xTx[] = new double[l];
        final int max_iter = 1000;
        final int index[] = new int[l];
        final double alpha[] = new double[2 * l]; // store alpha and C - alpha
        final byte y[] = new byte[l];
        final int max_inner_iter = 100; // for inner Newton
        double innereps = 1e-2;
        final double innereps_min = Math.min(1e-8, eps);
        // indexed by GETI, i.e. 0 for y = -1 and 2 for y = +1
        final double upper_bound[] = new double[] { Cn, 0, Cp };

        // reduce the labels to their sign
        for (i = 0; i < l; i++) {
            if (prob.y[i] > 0) {
                y[i] = +1;
            } else {
                y[i] = -1;
            }
        }

        // Initial alpha can be set here. Note that
        // 0 < alpha[i] < upper_bound[GETI(i)]
        // alpha[2*i] + alpha[2*i+1] = upper_bound[GETI(i)]
        for (i = 0; i < l; i++) {
            alpha[2 * i] = Math.min(0.001 * upper_bound[GETI(y, i)], 1e-8);
            alpha[2 * i + 1] = upper_bound[GETI(y, i)] - alpha[2 * i];
        }

        for (i = 0; i < w_size; i++)
            w[i] = 0;
        for (i = 0; i < l; i++) {
            // xTx[i] = xi^T xi; w starts as sum_i y_i alpha_i x_i
            xTx[i] = 0;
            for (int j = 0; j < w_size; j++) {
                final double val = prob.x[i][j];
                xTx[i] += val * val;
                w[j] += y[i] * alpha[2 * i] * val;
            }
            index[i] = i;
        }

        while (iter < max_iter) {
            // random permutation of all instances (this solver does not shrink)
            for (i = 0; i < l; i++) {
                final int j = i + random.nextInt(l - i);
                swap(index, i, j);
            }
            int newton_iter = 0;
            double Gmax = 0;
            for (s = 0; s < l; s++) {
                i = index[s];
                final byte yi = y[i];
                final double C = upper_bound[GETI(y, i)];
                double ywTx = 0;
                final double xisq = xTx[i];
                for (int j = 0; j < w_size; j++) {
                    ywTx += w[j] * prob.x[i][j];
                }
                ywTx *= y[i];
                final double a = xisq, b = ywTx;

                // Decide to minimize g_1(z) or g_2(z)
                int ind1 = 2 * i, ind2 = 2 * i + 1, sign = 1;
                if (0.5 * a * (alpha[ind2] - alpha[ind1]) + b < 0) {
                    ind1 = 2 * i + 1;
                    ind2 = 2 * i;
                    sign = -1;
                }

                // g_t(z) = z*log(z) + (C-z)*log(C-z) + 0.5a(z-alpha_old)^2 +
                // sign*b(z-alpha_old)
                final double alpha_old = alpha[ind1];
                double z = alpha_old;
                if (C - z < 0.5 * C)
                    z = 0.1 * z;
                double gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
                Gmax = Math.max(Gmax, Math.abs(gp));

                // Newton method on the sub-problem
                final double eta = 0.1; // xi in the paper
                int inner_iter = 0;
                while (inner_iter <= max_inner_iter) {
                    if (Math.abs(gp) < innereps)
                        break;
                    final double gpp = a + C / (C - z) / z;
                    final double tmpz = z - gp / gpp;
                    if (tmpz <= 0)
                        z *= eta;
                    else
                        // tmpz in (0, C)
                        z = tmpz;
                    gp = a * (z - alpha_old) + sign * b + Math.log(z / (C - z));
                    newton_iter++;
                    inner_iter++;
                }

                if (inner_iter > 0) // update w
                {
                    alpha[ind1] = z;
                    alpha[ind2] = C - z;
                    for (int j = 0; j < w_size; j++) {
                        w[j] += sign * (z - alpha_old) * yi * prob.x[i][j];
                    }
                }
            }

            iter++;
            if (iter % 10 == 0)
                info(".");

            if (Gmax < eps)
                break;

            // tighten the inner tolerance when few Newton steps were needed
            if (newton_iter <= l / 10) {
                innereps = Math.max(innereps_min, 0.1 * innereps);
            }

        }

        info("%noptimization finished, #iter = %d%n", iter);
        if (iter >= max_iter)
            info("%nWARNING: reaching max number of iterations%nUsing -s 0 may be faster (also see FAQ)%n%n");

        // calculate objective value

        double v = 0;
        for (i = 0; i < w_size; i++)
            v += w[i] * w[i];
        v *= 0.5;
        for (i = 0; i < l; i++)
            v += alpha[2 * i] * Math.log(alpha[2 * i]) + alpha[2 * i + 1] * Math.log(alpha[2 * i + 1])
                - upper_bound[GETI(y, i)]
                * Math.log(upper_bound[GETI(y, i)]);
        info("Objective value = %g%n", v);
    }

    /**
     * A coordinate descent algorithm for L1-regularized L2-loss support vector
     * classification
     *
     *
-	 *  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
-	 * 
-	 * Given:
-	 * x, y, Cp, Cn
-	 * eps is the stopping tolerance
-	 * 
-	 * solution will be put in w
-	 * 
-	 * See Yuan et al. (2010) and appendix of LIBLINEAR paper, Fan et al. (2008)
-	 * 
- * - * @since 1.5 - */ - private static void solve_l1r_l2_svc(Problem prob_col, double[] w, double eps, double Cp, double Cn) { - final int l = prob_col.l; - final int w_size = prob_col.n; - int j, s, iter = 0; - final int max_iter = 1000; - int active_size = w_size; - final int max_num_linesearch = 20; - - final double sigma = 0.01; - double d, G_loss, G, H; - double Gmax_old = Double.POSITIVE_INFINITY; - double Gmax_new, Gnorm1_new; - double Gnorm1_init = 0; // eclipse moans this variable might not be - // initialized - double d_old, d_diff; - double loss_old = 0; // eclipse moans this variable might not be - // initialized - double loss_new; - double appxcond, cond; - - final int[] index = new int[w_size]; - final byte[] y = new byte[l]; - final double[] b = new double[l]; // b = 1-ywTx - final double[] xj_sq = new double[w_size]; - - final double[] C = new double[] { Cn, 0, Cp }; - - // Initial w can be set here. - for (j = 0; j < w_size; j++) - w[j] = 0; - - for (j = 0; j < l; j++) { - b[j] = 1; - if (prob_col.y[j] > 0) - y[j] = 1; - else - y[j] = -1; - } - for (j = 0; j < w_size; j++) { - index[j] = j; - xj_sq[j] = 0; - for (int ind = 0; ind < w_size; ind++) { - prob_col.x[j][ind] = prob_col.x[j][ind] * y[ind]; // x->value - // stores - // yi*xij - final double val = prob_col.x[j][ind]; - b[ind] -= w[j] * val; - - xj_sq[j] += C[GETI(y, ind)] * val * val; - } - } - - while (iter < max_iter) { - Gmax_new = 0; - Gnorm1_new = 0; - - for (j = 0; j < active_size; j++) { - final int i = j + random.nextInt(active_size - j); - swap(index, i, j); - } - - for (s = 0; s < active_size; s++) { - j = index[s]; - G_loss = 0; - H = 0; - - for (int ind = 0; ind < w_size; ind++) { - if (b[ind] > 0) { - final double val = prob_col.x[j][ind]; - final double tmp = C[GETI(y, ind)] * val; - G_loss -= tmp * b[ind]; - H += tmp * val; - } - } - G_loss *= 2; - - G = G_loss; - H *= 2; - H = Math.max(H, 1e-12); - - final double Gp = G + 1; - final double Gn = G - 1; - double violation = 
0; - if (w[j] == 0) { - if (Gp < 0) - violation = -Gp; - else if (Gn > 0) - violation = Gn; - else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) { - active_size--; - swap(index, s, active_size); - s--; - continue; - } - } else if (w[j] > 0) - violation = Math.abs(Gp); - else - violation = Math.abs(Gn); - - Gmax_new = Math.max(Gmax_new, violation); - Gnorm1_new += violation; - - // obtain Newton direction d - if (Gp < H * w[j]) - d = -Gp / H; - else if (Gn > H * w[j]) - d = -Gn / H; - else - d = -w[j]; - - if (Math.abs(d) < 1.0e-12) - continue; - - double delta = Math.abs(w[j] + d) - Math.abs(w[j]) + G * d; - d_old = 0; - int num_linesearch; - for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) { - d_diff = d_old - d; - cond = Math.abs(w[j] + d) - Math.abs(w[j]) - sigma * delta; - - appxcond = xj_sq[j] * d * d + G_loss * d + cond; - if (appxcond <= 0) { - for (int ind = 0; ind < prob_col.x[j].length; ind++) { - b[ind] += d_diff * prob_col.x[j][ind]; - } - break; - } - - if (num_linesearch == 0) { - loss_old = 0; - loss_new = 0; - for (int ind = 0; ind < prob_col.x[j].length; ind++) { - if (b[ind] > 0) { - loss_old += C[GETI(y, ind)] * b[ind] * b[ind]; - } - final double b_new = b[ind] + d_diff * prob_col.x[j][ind]; - b[ind] = b_new; - if (b_new > 0) { - loss_new += C[GETI(y, ind)] * b_new * b_new; - } - } - } else { - loss_new = 0; - for (int ind = 0; ind < prob_col.x[j].length; ind++) { - final double b_new = b[ind] + d_diff * prob_col.x[j][ind]; - b[ind] = b_new; - if (b_new > 0) { - loss_new += C[GETI(y, ind)] * b_new * b_new; - } - } - } - - cond = cond + loss_new - loss_old; - if (cond <= 0) - break; - else { - d_old = d; - d *= 0.5; - delta *= 0.5; - } - } - - w[j] += d; - - // recompute b[] if line search takes too many steps - if (num_linesearch >= max_num_linesearch) { - info("#"); - for (int i = 0; i < l; i++) - b[i] = 1; - - for (int i = 0; i < w_size; i++) { - if (w[i] == 0) - continue; - for (int ind = 0; ind < 
prob_col.x[j].length; ind++) { - b[ind] -= w[i] * prob_col.x[j][ind]; - } - } - } - } - - if (iter == 0) { - Gnorm1_init = Gnorm1_new; - } - iter++; - if (iter % 10 == 0) - info("."); - - if (Gmax_new <= eps * Gnorm1_init) { - if (active_size == w_size) - break; - else { - active_size = w_size; - info("*"); - Gmax_old = Double.POSITIVE_INFINITY; - continue; - } - } - - Gmax_old = Gmax_new; - } - - info("%noptimization finished, #iter = %d%n", iter); - if (iter >= max_iter) - info("%nWARNING: reaching max number of iterations%n"); - - // calculate objective value - - double v = 0; - int nnz = 0; - for (j = 0; j < w_size; j++) { - for (int ind = 0; ind < prob_col.x[j].length; ind++) { - prob_col.x[j][ind] = prob_col.x[j][ind] * prob_col.y[ind]; // restore - // x->value - } - if (w[j] != 0) { - v += Math.abs(w[j]); - nnz++; - } - } - for (j = 0; j < l; j++) - if (b[j] > 0) - v += C[GETI(y, j)] * b[j] * b[j]; - - info("Objective value = %g%n", v); - info("#nonzeros/#features = %d/%d%n", nnz, w_size); - } - - /** - * A coordinate descent algorithm for L1-regularized logistic regression - * problems - * - *
-	 *  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
-	 * 
-	 * Given:
-	 * x, y, Cp, Cn
-	 * eps is the stopping tolerance
-	 * 
-	 * solution will be put in w
-	 * 
-	 * See Yuan et al. (2011) and appendix of LIBLINEAR paper, Fan et al. (2008)
-	 * 
*
     * @since 1.5
     */
    private static void solve_l1r_lr(Problem prob_col, double[] w, double eps, double Cp, double Cn) {
        final int l = prob_col.l;
        final int w_size = prob_col.n;
        int j, s, newton_iter = 0, iter = 0;
        final int max_newton_iter = 100;
        final int max_iter = 1000;
        final int max_num_linesearch = 20;
        int active_size;
        int QP_active_size;

        final double nu = 1e-12;
        double inner_eps = 1;
        final double sigma = 0.01;
        double w_norm, w_norm_new;
        double z, G, H;
        double Gnorm1_init = 0; // eclipse moans this variable might not be
                                // initialized
        double Gmax_old = Double.POSITIVE_INFINITY;
        double Gmax_new, Gnorm1_new;
        double QP_Gmax_old = Double.POSITIVE_INFINITY;
        double QP_Gmax_new, QP_Gnorm1_new;
        double delta, negsum_xTd, cond;

        final int[] index = new int[w_size];
        final byte[] y = new byte[l];
        final double[] Hdiag = new double[w_size];
        final double[] Grad = new double[w_size];
        final double[] wpd = new double[w_size];
        final double[] xjneg_sum = new double[w_size];
        final double[] xTd = new double[l];
        final double[] exp_wTx = new double[l];
        final double[] exp_wTx_new = new double[l];
        final double[] tau = new double[l];
        final double[] D = new double[l];

        // indexed by GETI, i.e. 0 for y = -1 and 2 for y = +1
        final double[] C = { Cn, 0, Cp };

        // Initial w can be set here.
        for (j = 0; j < w_size; j++)
            w[j] = 0;

        for (j = 0; j < l; j++) {
            if (prob_col.y[j] > 0)
                y[j] = 1;
            else
                y[j] = -1;

            exp_wTx[j] = 0;
        }

        w_norm = 0;
        for (j = 0; j < w_size; j++) {
            w_norm += Math.abs(w[j]);
            wpd[j] = w[j];
            index[j] = j;
            xjneg_sum[j] = 0;
            // prob_col.x[j] is the column of feature j: one entry per instance
            for (int ind = 0; ind < prob_col.x[j].length; ind++) {
                final double val = prob_col.x[j][ind];
                exp_wTx[ind] += w[j] * val;
                if (y[ind] == -1) {
                    xjneg_sum[j] += C[GETI(y, ind)] * val;
                }
            }
        }
        for (j = 0; j < l; j++) {
            exp_wTx[j] = Math.exp(exp_wTx[j]);
            final double tau_tmp = 1 / (1 + exp_wTx[j]);
            tau[j] = C[GETI(y, j)] * tau_tmp;
            D[j] = C[GETI(y, j)] * exp_wTx[j] * tau_tmp * tau_tmp;
        }

        while (newton_iter < max_newton_iter) {
            Gmax_new = 0;
            Gnorm1_new = 0;
            active_size = w_size;

            // gradient and Hessian diagonal of the loss for every feature,
            // with outer-level shrinking of features that stay at zero
            for (s = 0; s < active_size; s++) {
                j = index[s];
                Hdiag[j] = nu;
                Grad[j] = 0;

                double tmp = 0;
                for (int ind = 0; ind < prob_col.x[j].length; ind++) {
                    Hdiag[j] += prob_col.x[j][ind] * prob_col.x[j][ind] * D[ind];
                    tmp += prob_col.x[j][ind] * tau[ind];
                }
                Grad[j] = -tmp + xjneg_sum[j];

                final double Gp = Grad[j] + 1;
                final double Gn = Grad[j] - 1;
                double violation = 0;
                if (w[j] == 0) {
                    if (Gp < 0)
                        violation = -Gp;
                    else if (Gn > 0)
                        violation = Gn;
                    // outer-level shrinking
                    else if (Gp > Gmax_old / l && Gn < -Gmax_old / l) {
                        active_size--;
                        swap(index, s, active_size);
                        s--;
                        continue;
                    }
                } else if (w[j] > 0)
                    violation = Math.abs(Gp);
                else
                    violation = Math.abs(Gn);

                Gmax_new = Math.max(Gmax_new, violation);
                Gnorm1_new += violation;
            }

            if (newton_iter == 0)
                Gnorm1_init = Gnorm1_new;

            if (Gnorm1_new <= eps * Gnorm1_init)
                break;

            iter = 0;
            QP_Gmax_old = Double.POSITIVE_INFINITY;
            QP_active_size = active_size;

            for (int i = 0; i < l; i++)
                xTd[i] = 0;

            // optimize QP over wpd
            while (iter < max_iter) {
                QP_Gmax_new = 0;
                QP_Gnorm1_new = 0;

                // random permutation of the active features; the partner is
                // drawn from [j, QP_active_size) like every other shuffle in
                // this class
                for (j = 0; j < QP_active_size; j++) {
                    final int i = j + random.nextInt(QP_active_size - j);
                    swap(index, i, j);
                }

                for (s = 0; s < QP_active_size; s++) {
                    j = index[s];
                    H = Hdiag[j];

                    G = Grad[j] + (wpd[j] - w[j]) * nu;
                    for (int ind = 0; ind < prob_col.x[j].length; ind++) {
                        G += prob_col.x[j][ind] * D[ind] * xTd[ind];
                    }

                    final double Gp = G + 1;
                    final double Gn = G - 1;
                    double violation = 0;
                    if (wpd[j] == 0) {
                        if (Gp < 0)
                            violation = -Gp;
                        else if (Gn > 0)
                            violation = Gn;
                        // inner-level shrinking
                        else if (Gp > QP_Gmax_old / l && Gn < -QP_Gmax_old / l) {
                            QP_active_size--;
                            swap(index, s, QP_active_size);
                            s--;
                            continue;
                        }
                    } else if (wpd[j] > 0)
                        violation = Math.abs(Gp);
                    else
                        violation = Math.abs(Gn);

                    QP_Gmax_new = Math.max(QP_Gmax_new, violation);
                    QP_Gnorm1_new += violation;

                    // obtain solution of one-variable problem
                    if (Gp < H * wpd[j])
                        z = -Gp / H;
                    else if (Gn > H * wpd[j])
                        z = -Gn / H;
                    else
                        z = -wpd[j];

                    if (Math.abs(z) < 1.0e-12)
                        continue;
                    z = Math.min(Math.max(z, -10.0), 10.0);

                    wpd[j] += z;

                    for (int ind = 0; ind < prob_col.x[j].length; ind++) {
                        xTd[ind] += prob_col.x[j][ind] * z;
                    }
                }

                iter++;

                if (QP_Gnorm1_new <= inner_eps * Gnorm1_init) {
                    // inner stopping
                    if (QP_active_size == active_size)
                        break;
                    // active set reactivation
                    else {
                        QP_active_size = active_size;
                        QP_Gmax_old = Double.POSITIVE_INFINITY;
                        continue;
                    }
                }

                QP_Gmax_old = QP_Gmax_new;
            }

            if (iter >= max_iter)
                info("WARNING: reaching max number of inner iterations%n");

            delta = 0;
            w_norm_new = 0;
            for (j = 0; j < w_size; j++) {
                delta += Grad[j] * (wpd[j] - w[j]);
                if (wpd[j] != 0)
                    w_norm_new += Math.abs(wpd[j]);
            }
            delta += (w_norm_new - w_norm);

            negsum_xTd = 0;
            for (int i = 0; i < l; i++)
                if (y[i] == -1)
                    negsum_xTd += C[GETI(y, i)] * xTd[i];

            // backtracking line search between w and wpd
            int num_linesearch;
            for (num_linesearch = 0; num_linesearch < max_num_linesearch; num_linesearch++) {
                cond = w_norm_new - w_norm + negsum_xTd - sigma * delta;

                for (int i = 0; i < l; i++) {
                    final double exp_xTd = Math.exp(xTd[i]);
                    exp_wTx_new[i] = exp_wTx[i] * exp_xTd;
                    cond += C[GETI(y, i)] * Math.log((1 + exp_wTx_new[i]) / (exp_xTd + exp_wTx_new[i]));
                }

                if (cond <= 0) {
                    // accept the step and refresh the cached per-instance terms
                    w_norm = w_norm_new;
                    for (j = 0; j < w_size; j++)
                        w[j] = wpd[j];
                    for (int i = 0; i < l; i++) {
                        exp_wTx[i] = exp_wTx_new[i];
                        final double tau_tmp = 1 / (1 + exp_wTx[i]);
                        tau[i] = C[GETI(y, i)] * tau_tmp;
                        D[i] = C[GETI(y, i)] * exp_wTx[i] * tau_tmp * tau_tmp;
                    }
                    break;
                } else {
                    // halve the step
                    w_norm_new = 0;
                    for (j = 0; j < w_size; j++) {
                        wpd[j] = (w[j] + wpd[j]) * 0.5;
                        if (wpd[j] != 0)
                            w_norm_new += Math.abs(wpd[j]);
                    }
                    delta *= 0.5;
                    negsum_xTd *= 0.5;
                    for (int i = 0; i < l; i++)
                        xTd[i] *= 0.5;
                }
            }

            // Recompute some info due to too many line search steps
            if (num_linesearch >= max_num_linesearch) {
                for (int i = 0; i < l; i++)
                    exp_wTx[i] = 0;

                for (int i = 0; i < w_size; i++) {
                    if (w[i] == 0)
                        continue;
                    // iterate over the column of feature i; j equals w_size
                    // at this point and must not be used as an index
                    for (int ind = 0; ind < prob_col.x[i].length; ind++) {
                        exp_wTx[ind] += w[i] * prob_col.x[i][ind];
                    }
                }

                for (int i = 0; i < l; i++)
                    exp_wTx[i] = Math.exp(exp_wTx[i]);
            }

            // the QP was solved in a single cycle: tighten the inner tolerance
            if (iter == 1)
                inner_eps *= 0.25;

            newton_iter++;
            Gmax_old = Gmax_new;

            info("iter %3d #CD cycles %d%n", newton_iter, iter);
        }

        info("=========================%n");
        info("optimization finished, #iter = %d%n", newton_iter);
        if (newton_iter >= max_newton_iter)
            info("WARNING: reaching max number of iterations%n");

        // calculate objective value

        double v = 0;
        int nnz = 0;
        for (j = 0; j < w_size; j++)
            if (w[j] != 0) {
                v += Math.abs(w[j]);
                nnz++;
            }
        for (j = 0; j < l; j++)
            if (y[j] == 1)
                v += C[GETI(y, j)] * Math.log(1 + 1 / exp_wTx[j]);
            else
                v += C[GETI(y, j)] * Math.log(1 + exp_wTx[j]);

        info("Objective value = %g%n", v);
        info("#nonzeros/#features = %d/%d%n", nnz, w_size);
    }

    // transpose matrix X from row format to column format
    static Problem
transpose(Problem prob) { - final int l = prob.l; - final int n = prob.n; - final Problem prob_col = new Problem(); - prob_col.l = l; - prob_col.n = n; - prob_col.y = new double[l]; - prob_col.x = new double[n][]; - - for (int i = 0; i < l; i++) - prob_col.y[i] = prob.y[i]; - - for (int i = 0; i < n; i++) { - prob_col.x[i] = new double[l]; - } - - for (int i = 0; i < l; i++) { - for (int j = 0; j < n; j++) { - prob_col.x[j][i] = prob.x[i][j]; - } - } - - return prob_col; - } - - static void swap(double[] array, int idxA, int idxB) { - final double temp = array[idxA]; - array[idxA] = array[idxB]; - array[idxB] = temp; - } - - static void swap(int[] array, int idxA, int idxB) { - final int temp = array[idxA]; - array[idxA] = array[idxB]; - array[idxB] = temp; - } - - static void swap(IntArrayPointer array, int idxA, int idxB) { - final int temp = array.get(idxA); - array.set(idxA, array.get(idxB)); - array.set(idxB, temp); - } - - /** - * @throws IllegalArgumentException - * if the feature nodes of prob are not sorted in ascending - * order - */ - public static Model train(Problem prob, Parameter param) { - - if (prob == null) - throw new IllegalArgumentException("problem must not be null"); - if (param == null) - throw new IllegalArgumentException("parameter must not be null"); - - if (prob.n == 0) - throw new IllegalArgumentException("problem has zero features"); - if (prob.l == 0) - throw new IllegalArgumentException("problem has zero instances"); - - final int l = prob.l; - final int n = prob.n; - final int w_size = prob.n; - final Model model = new Model(); - - if (prob.bias >= 0) - model.nr_feature = n - 1; - else - model.nr_feature = n; - - model.solverType = param.solverType; - model.bias = prob.bias; - - if (param.solverType == SolverType.L2R_L2LOSS_SVR || // - param.solverType == SolverType.L2R_L1LOSS_SVR_DUAL || // - param.solverType == SolverType.L2R_L2LOSS_SVR_DUAL) - { - model.w = new double[w_size]; - model.nr_class = 2; - model.label = null; - - 
checkProblemSize(n, model.nr_class); - - train_one(prob, param, model.w, 0, 0); - } else { - final int[] perm = new int[l]; - - // group training data of the same class - final GroupClassesReturn rv = groupClasses(prob, perm); - final int nr_class = rv.nr_class; - final int[] label = rv.label; - final int[] start = rv.start; - final int[] count = rv.count; - - checkProblemSize(n, nr_class); - - model.nr_class = nr_class; - model.label = new int[nr_class]; - for (int i = 0; i < nr_class; i++) - model.label[i] = label[i]; - - // calculate weighted C - final double[] weighted_C = new double[nr_class]; - for (int i = 0; i < nr_class; i++) - weighted_C[i] = param.C; - for (int i = 0; i < param.getNumWeights(); i++) { - int j; - for (j = 0; j < nr_class; j++) - if (param.weightLabel[i] == label[j]) - break; - - if (j == nr_class) - throw new IllegalArgumentException("class label " + param.weightLabel[i] - + " specified in weight is not found"); - weighted_C[j] *= param.weight[i]; - } - - // constructing the subproblem - final double[][] x = new double[l][]; - for (int i = 0; i < l; i++) - x[i] = prob.x[perm[i]]; - - final Problem sub_prob = new Problem(); - sub_prob.l = l; - sub_prob.n = n; - sub_prob.x = new double[sub_prob.l][]; - sub_prob.y = new double[sub_prob.l]; - - for (int k = 0; k < sub_prob.l; k++) - sub_prob.x[k] = x[k]; - - // multi-class svm by Crammer and Singer - if (param.solverType == SolverType.MCSVM_CS) { - model.w = new double[n * nr_class]; - for (int i = 0; i < nr_class; i++) { - for (int j = start[i]; j < start[i] + count[i]; j++) { - sub_prob.y[j] = i; - } - } - - final SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps); - solver.solve(model.w); - } else { - if (nr_class == 2) { - model.w = new double[w_size]; - - final int e0 = start[0] + count[0]; - int k = 0; - for (; k < e0; k++) - sub_prob.y[k] = +1; - for (; k < sub_prob.l; k++) - sub_prob.y[k] = -1; - - train_one(sub_prob, param, model.w, weighted_C[0], 
weighted_C[1]); - } else { - model.w = new double[w_size * nr_class]; - final double[] w = new double[w_size]; - for (int i = 0; i < nr_class; i++) { - final int si = start[i]; - final int ei = si + count[i]; - - int k = 0; - for (; k < si; k++) - sub_prob.y[k] = -1; - for (; k < ei; k++) - sub_prob.y[k] = +1; - for (; k < sub_prob.l; k++) - sub_prob.y[k] = -1; - - train_one(sub_prob, param, w, weighted_C[i], param.C); - - for (int j = 0; j < n; j++) - model.w[j * nr_class + i] = w[j]; - } - } - } - } - return model; - } - - /** - * verify the size and throw an exception early if the problem is too large - */ - private static void checkProblemSize(int n, int nr_class) { - if (n >= Integer.MAX_VALUE / nr_class || n * nr_class < 0) { - throw new IllegalArgumentException("'number of classes' * 'number of instances' is too large: " + nr_class - + "*" + n); - } - } - - private static void train_one(Problem prob, Parameter param, double[] w, double Cp, double Cn) { - final double eps = param.eps; - int pos = 0; - for (int i = 0; i < prob.l; i++) - if (prob.y[i] > 0) { - pos++; - } - final int neg = prob.l - pos; - - final double primal_solver_tol = eps * Math.max(Math.min(pos, neg), 1) / prob.l; - - Function fun_obj = null; - switch (param.solverType) { - case L2R_LR: { - final double[] C = new double[prob.l]; - for (int i = 0; i < prob.l; i++) { - if (prob.y[i] > 0) - C[i] = Cp; - else - C[i] = Cn; - } - fun_obj = new L2R_LrFunction(prob, C); - final Tron tron_obj = new Tron(fun_obj, primal_solver_tol); - tron_obj.tron(w); - break; - } - case L2R_L2LOSS_SVC: { - final double[] C = new double[prob.l]; - for (int i = 0; i < prob.l; i++) { - if (prob.y[i] > 0) - C[i] = Cp; - else - C[i] = Cn; - } - fun_obj = new L2R_L2_SvcFunction(prob, C); - final Tron tron_obj = new Tron(fun_obj, primal_solver_tol); - tron_obj.tron(w); - break; - } - case L2R_L2LOSS_SVC_DUAL: - solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L2LOSS_SVC_DUAL); - break; - case L2R_L1LOSS_SVC_DUAL: 
- solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, SolverType.L2R_L1LOSS_SVC_DUAL); - break; - case L1R_L2LOSS_SVC: { - final Problem prob_col = transpose(prob); - solve_l1r_l2_svc(prob_col, w, primal_solver_tol, Cp, Cn); - break; - } - case L1R_LR: { - final Problem prob_col = transpose(prob); - solve_l1r_lr(prob_col, w, primal_solver_tol, Cp, Cn); - break; - } - case L2R_LR_DUAL: - solve_l2r_lr_dual(prob, w, eps, Cp, Cn); - break; - case L2R_L2LOSS_SVR: { - final double[] C = new double[prob.l]; - for (int i = 0; i < prob.l; i++) - C[i] = param.C; - - fun_obj = new L2R_L2_SvrFunction(prob, C, param.p); - final Tron tron_obj = new Tron(fun_obj, param.eps); - tron_obj.tron(w); - break; - } - case L2R_L1LOSS_SVR_DUAL: - case L2R_L2LOSS_SVR_DUAL: - solve_l2r_l1l2_svr(prob, w, param); - break; - - default: - throw new IllegalStateException("unknown solver type: " + param.solverType); - } - } - - public static void disableDebugOutput() { - setDebugOutput(null); - } - - public static void enableDebugOutput() { - setDebugOutput(System.out); - } - - public static void setDebugOutput(PrintStream debugOutput) { - synchronized (OUTPUT_MUTEX) { - DEBUG_OUTPUT = debugOutput; - } - } - - /** - * resets the PRNG - * - * this is i.a. needed for regression testing (eg. the Weka wrapper) - */ - public static void resetRandom() { - random = new Random(DEFAULT_RANDOM_SEED); - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Model.java b/src/main/java/de/bwaldvogel/denseliblinear/Model.java deleted file mode 100644 index 670d858..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Model.java +++ /dev/null @@ -1,178 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.copyOf; - -import java.io.File; -import java.io.IOException; -import java.io.Reader; -import java.io.Serializable; -import java.io.Writer; -import java.util.Arrays; - - -/** - *

Model stores the model obtained from the training procedure

- * - *

use {@link Linear#loadModel(File)} and {@link Linear#saveModel(File, Model)} to load/save it

- */ -public final class Model implements Serializable { - - private static final long serialVersionUID = -6456047576741854834L; - - double bias; - - /** label of each class */ - int[] label; - - int nr_class; - - int nr_feature; - - SolverType solverType; - - /** feature weight array */ - double[] w; - - /** - * @return number of classes - */ - public int getNrClass() { - return nr_class; - } - - /** - * @return number of features - */ - public int getNrFeature() { - return nr_feature; - } - - public int[] getLabels() { - return copyOf(label, nr_class); - } - - /** - * The nr_feature*nr_class array w gives feature weights. We use one - * against the rest for multi-class classification, so each feature - * index corresponds to nr_class weight values. Weights are - * organized in the following way - * - *
-     * +------------------+------------------+------------+
-     * | nr_class weights | nr_class weights |  ...
-     * | for 1st feature  | for 2nd feature  |
-     * +------------------+------------------+------------+
-     * 
- * - * If bias >= 0, x becomes [x; bias]. The number of features is - * increased by one, so w is a (nr_feature+1)*nr_class array. The - * value of bias is stored in the variable bias. - * @see #getBias() - * @return a copy of the feature weight array as described - */ - public double[] getFeatureWeights() { - return Linear.copyOf(w, w.length); - } - - /** - * @return true for logistic regression solvers - */ - public boolean isProbabilityModel() { - return solverType.isLogisticRegressionSolver(); - } - - /** - * @see #getFeatureWeights() - */ - public double getBias() { - return bias; - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder("Model"); - sb.append(" bias=").append(bias); - sb.append(" nr_class=").append(nr_class); - sb.append(" nr_feature=").append(nr_feature); - sb.append(" solverType=").append(solverType); - return sb.toString(); - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - long temp; - temp = Double.doubleToLongBits(bias); - result = prime * result + (int)(temp ^ (temp >>> 32)); - result = prime * result + Arrays.hashCode(label); - result = prime * result + nr_class; - result = prime * result + nr_feature; - result = prime * result + ((solverType == null) ? 
0 : solverType.hashCode()); - result = prime * result + Arrays.hashCode(w); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null) return false; - if (getClass() != obj.getClass()) return false; - Model other = (Model)obj; - if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false; - if (!Arrays.equals(label, other.label)) return false; - if (nr_class != other.nr_class) return false; - if (nr_feature != other.nr_feature) return false; - if (solverType == null) { - if (other.solverType != null) return false; - } else if (!solverType.equals(other.solverType)) return false; - if (!equals(w, other.w)) return false; - return true; - } - - /** - * don't use {@link Arrays#equals(double[], double[])} here, cause 0.0 and -0.0 should be handled the same - * - * @see Linear#saveModel(java.io.Writer, Model) - */ - protected static boolean equals(double[] a, double[] a2) { - if (a == a2) return true; - if (a == null || a2 == null) return false; - - int length = a.length; - if (a2.length != length) return false; - - for (int i = 0; i < length; i++) - if (a[i] != a2[i]) return false; - - return true; - } - - /** - * see {@link Linear#saveModel(java.io.File, Model)} - */ - public void save(File file) throws IOException { - Linear.saveModel(file, this); - } - - /** - * see {@link Linear#saveModel(Writer, Model)} - */ - public void save(Writer writer) throws IOException { - Linear.saveModel(writer, this); - } - - /** - * see {@link Linear#loadModel(File)} - */ - public static Model load(File file) throws IOException { - return Linear.loadModel(file); - } - - /** - * see {@link Linear#loadModel(Reader)} - */ - public static Model load(Reader inputReader) throws IOException { - return Linear.loadModel(inputReader); - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java b/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java deleted file mode 100644 index 
012b0a1..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Parameter.java +++ /dev/null @@ -1,120 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.copyOf; - - -public final class Parameter { - - double C; - - /** stopping criteria */ - double eps; - - SolverType solverType; - - double[] weight = null; - - int[] weightLabel = null; - - double p; - - public Parameter( SolverType solver, double C, double eps ) { - this(solver, C, eps, 0.1); - } - - public Parameter( SolverType solverType, double C, double eps, double p ) { - setSolverType(solverType); - setC(C); - setEps(eps); - setP(p); - } - - /** - *

nr_weight, weight_label, and weight are used to change the penalty - * for some classes (If the weight for a class is not changed, it is - * set to 1). This is useful for training classifier using unbalanced - * input data or with asymmetric misclassification cost.

- * - *

Each weight[i] corresponds to weight_label[i], meaning that - * the penalty of class weight_label[i] is scaled by a factor of weight[i].

- * - *

If you do not want to change penalty for any of the classes, - * just set nr_weight to 0.

- */ - public void setWeights(double[] weights, int[] weightLabels) { - if (weights == null) throw new IllegalArgumentException("'weight' must not be null"); - if (weightLabels == null || weightLabels.length != weights.length) - throw new IllegalArgumentException("'weightLabels' must have same length as 'weight'"); - this.weightLabel = copyOf(weightLabels, weightLabels.length); - this.weight = copyOf(weights, weights.length); - } - - /** - * @see #setWeights(double[], int[]) - */ - public double[] getWeights() { - return copyOf(weight, weight.length); - } - - /** - * @see #setWeights(double[], int[]) - */ - public int[] getWeightLabels() { - return copyOf(weightLabel, weightLabel.length); - } - - /** - * the number of weights - * @see #setWeights(double[], int[]) - */ - public int getNumWeights() { - if (weight == null) return 0; - return weight.length; - } - - /** - * C is the cost of constraints violation. (we usually use 1 to 1000) - */ - public void setC(double C) { - if (C <= 0) throw new IllegalArgumentException("C must not be <= 0"); - this.C = C; - } - - public double getC() { - return C; - } - - /** - * eps is the stopping criterion. (we usually use 0.01). 
- */ - public void setEps(double eps) { - if (eps <= 0) throw new IllegalArgumentException("eps must not be <= 0"); - this.eps = eps; - } - - public double getEps() { - return eps; - } - - public void setSolverType(SolverType solverType) { - if (solverType == null) throw new IllegalArgumentException("solver type must not be null"); - this.solverType = solverType; - } - - public SolverType getSolverType() { - return solverType; - } - - - /** - * set the epsilon in loss function of epsilon-SVR (default 0.1) - */ - public void setP(double p) { - if (p < 0) throw new IllegalArgumentException("p must not be less than 0"); - this.p = p; - } - - public double getP() { - return p; - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Predict.java b/src/main/java/de/bwaldvogel/denseliblinear/Predict.java deleted file mode 100644 index 4b60801..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Predict.java +++ /dev/null @@ -1,193 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.atof; -import static de.bwaldvogel.denseliblinear.Linear.atoi; -import static de.bwaldvogel.denseliblinear.Linear.closeQuietly; -import static de.bwaldvogel.denseliblinear.Linear.info; -import static de.bwaldvogel.denseliblinear.Linear.printf; - -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.util.Formatter; -import java.util.NoSuchElementException; -import java.util.StringTokenizer; -import java.util.regex.Pattern; - -public class Predict { - - private static boolean flag_predict_probability = false; - - private static final Pattern COLON = Pattern.compile(":"); - - /** - *

- * Note: The streams are NOT closed - *

- */ - static void doPredict(BufferedReader reader, Writer writer, Model model) throws IOException { - int correct = 0; - int total = 0; - double error = 0; - double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0; - - final int nr_class = model.getNrClass(); - double[] prob_estimates = null; - int n; - final int nr_feature = model.getNrFeature(); - if (model.bias >= 0) - n = nr_feature + 1; - else - n = nr_feature; - - if (flag_predict_probability && !model.isProbabilityModel()) { - throw new IllegalArgumentException("probability output is only supported for logistic regression"); - } - - final Formatter out = new Formatter(writer); - - if (flag_predict_probability) { - final int[] labels = model.getLabels(); - prob_estimates = new double[nr_class]; - - printf(out, "labels"); - for (int j = 0; j < nr_class; j++) - printf(out, " %d", labels[j]); - printf(out, "\n"); - } - - String line = null; - while ((line = reader.readLine()) != null) { - final double[] nodes = new double[n]; - final StringTokenizer st = new StringTokenizer(line, " \t\n"); - double target_label; - try { - final String label = st.nextToken(); - target_label = atof(label); - } catch (final NoSuchElementException e) { - throw new RuntimeException("Wrong input format at line " + (total + 1), e); - } - - while (st.hasMoreTokens()) { - final String[] split = COLON.split(st.nextToken(), 2); - if (split == null || split.length < 2) { - throw new RuntimeException("Wrong input format at line " + (total + 1)); - } - - try { - final int idx = atoi(split[0]); - final double val = atof(split[1]); - - // feature indices larger than those in training are not - // used - if (idx <= nr_feature) { - nodes[idx - 1] = val; - } - } catch (final NumberFormatException e) { - throw new RuntimeException("Wrong input format at line " + (total + 1), e); - } - } - - if (model.bias >= 0) { - nodes[n - 1] = model.bias; - } - - double predict_label; - - if (flag_predict_probability) { - assert prob_estimates != null; - 
predict_label = Linear.predictProbability(model, nodes, prob_estimates); - printf(out, "%g", predict_label); - for (int j = 0; j < model.nr_class; j++) - printf(out, " %g", prob_estimates[j]); - printf(out, "\n"); - } else { - predict_label = Linear.predict(model, nodes); - printf(out, "%g\n", predict_label); - } - - if (predict_label == target_label) { - ++correct; - } - - error += (predict_label - target_label) * (predict_label - target_label); - sump += predict_label; - sumt += target_label; - sumpp += predict_label * predict_label; - sumtt += target_label * target_label; - sumpt += predict_label * target_label; - ++total; - } - - if (model.solverType.isSupportVectorRegression()) // - { - info("Mean squared error = %g (regression)%n", error / total); - info("Squared correlation coefficient = %g (regression)%n", // - ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt)) - / ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt))); - } else { - info("Accuracy = %g%% (%d/%d)%n", (double) correct / total * 100, correct, total); - } - } - - private static void exit_with_help() { - System.out - .printf("Usage: predict [options] test_file model_file output_file%n" // - + "options:%n" // - + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only%n" // - + "-q quiet mode (no outputs)%n"); - System.exit(1); - } - - public static void main(String[] argv) throws IOException { - int i; - - // parse options - for (i = 0; i < argv.length; i++) { - if (argv[i].charAt(0) != '-') - break; - ++i; - switch (argv[i - 1].charAt(1)) { - case 'b': - try { - flag_predict_probability = (atoi(argv[i]) != 0); - } catch (final NumberFormatException e) { - exit_with_help(); - } - break; - - case 'q': - i--; - Linear.disableDebugOutput(); - break; - - default: - System.err.printf("unknown option: -%d%n", argv[i - 1].charAt(1)); - exit_with_help(); - break; - } - } - if (i >= argv.length || argv.length <= 
i + 2) { - exit_with_help(); - } - - BufferedReader reader = null; - Writer writer = null; - try { - reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), Linear.FILE_CHARSET)); - writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), Linear.FILE_CHARSET)); - - final Model model = Linear.loadModel(new File(argv[i + 1])); - doPredict(reader, writer, model); - } finally { - closeQuietly(reader); - closeQuietly(writer); - } - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Problem.java b/src/main/java/de/bwaldvogel/denseliblinear/Problem.java deleted file mode 100644 index 2ede79d..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Problem.java +++ /dev/null @@ -1,62 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import java.io.File; -import java.io.IOException; - -/** - *

- * Describes the problem - *

- * - * For example, if we have the following training data: - * - *
- *  LABEL       ATTR1   ATTR2   ATTR3   ATTR4   ATTR5
- *  -----       -----   -----   -----   -----   -----
- *  1           0       0.1     0.2     0       0
- *  2           0       0.1     0.3    -1.2     0
- *  1           0.4     0       0       0       0
- *  2           0       0.1     0       1.4     0.5
- *  3          -0.1    -0.2     0.1     1.1     0.1
- * 
- *  and bias = 1, then the components of problem are:
- * 
- *  l = 5
- *  n = 6
- * 
- *  y -> 1 2 1 2 3
- * 
- *  x -> [ ] -> (2,0.1) (3,0.2) (6,1) (-1,?)
- *       [ ] -> (2,0.1) (3,0.3) (4,-1.2) (6,1) (-1,?)
- *       [ ] -> (1,0.4) (6,1) (-1,?)
- *       [ ] -> (2,0.1) (4,1.4) (5,0.5) (6,1) (-1,?)
- *       [ ] -> (1,-0.1) (2,-0.2) (3,0.1) (4,1.1) (5,0.1) (6,1) (-1,?)
- * 
- */ -public class Problem { - - /** the number of training data */ - public int l; - - /** the number of features (including the bias feature if bias >= 0) */ - public int n; - - /** an array containing the target values */ - public double[] y; - - /** dense array of features */ - public double[][] x; - - /** - * If bias >= 0, we assume that one additional feature is added to the - * end of each data instance - */ - public double bias; - - /** - * see {@link Train#readProblem(File, double)} - */ - public static Problem readFromFile(File file, double bias) throws IOException, InvalidInputDataException { - return Train.readProblem(file, bias); - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java b/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java deleted file mode 100644 index 20c13e9..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/SolverMCSVM_CS.java +++ /dev/null @@ -1,293 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.copyOf; -import static de.bwaldvogel.denseliblinear.Linear.info; -import static de.bwaldvogel.denseliblinear.Linear.swap; - -/** - * A coordinate descent algorithm for multi-class support vector machines by - * Crammer and Singer - * - *
- * min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
- * s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
- * 
- * where e^m_i = 0 if y_i = m,
- * e^m_i = 1 if y_i != m,
- * C^m_i = C if m = y_i,
- * C^m_i = 0 if m != y_i,
- * and w_m(\alpha) = \sum_i \alpha^m_i x_i
- * 
- * Given:
- * x, y, C
- * eps is the stopping tolerance
- * 
- * solution will be put in w
- * 
- * See Appendix of LIBLINEAR paper, Fan et al. (2008)
- * 
- */ -class SolverMCSVM_CS { - - private final double[] B; - private final double[] C; - private final double eps; - private final double[] G; - private final int max_iter; - private final int w_size, l; - private final int nr_class; - private final Problem prob; - - public SolverMCSVM_CS(Problem prob, int nr_class, double[] C) { - this(prob, nr_class, C, 0.1); - } - - public SolverMCSVM_CS(Problem prob, int nr_class, double[] C, double eps) { - this(prob, nr_class, C, eps, 100000); - } - - public SolverMCSVM_CS(Problem prob, int nr_class, double[] weighted_C, double eps, int max_iter) { - this.w_size = prob.n; - this.l = prob.l; - this.nr_class = nr_class; - this.eps = eps; - this.max_iter = max_iter; - this.prob = prob; - this.C = weighted_C; - this.B = new double[nr_class]; - this.G = new double[nr_class]; - } - - private int GETI(int i) { - return (int) prob.y[i]; - } - - private boolean be_shrunk(int i, int m, int yi, double alpha_i, double minG) { - double bound = 0; - if (m == yi) - bound = C[GETI(i)]; - if (alpha_i == bound && G[m] < minG) - return true; - return false; - } - - public void solve(double[] w) { - int i, m, s; - int iter = 0; - final double[] alpha = new double[l * nr_class]; - final double[] alpha_new = new double[nr_class]; - final int[] index = new int[l]; - final double[] QD = new double[l]; - final int[] d_ind = new int[nr_class]; - final double[] d_val = new double[nr_class]; - final int[] alpha_index = new int[nr_class * l]; - final int[] y_index = new int[l]; - int active_size = l; - final int[] active_size_i = new int[l]; - double eps_shrink = Math.max(10.0 * eps, 1.0); // stopping tolerance for - // shrinking - boolean start_from_all = true; - - // Initial alpha can be set here. 
Note that - // sum_m alpha[i*nr_class+m] = 0, for all i=1,...,l-1 - // alpha[i*nr_class+m] <= C[GETI(i)] if prob->y[i] == m - // alpha[i*nr_class+m] <= 0 if prob->y[i] != m - // If initial alpha isn't zero, uncomment the for loop below to - // initialize w - for (i = 0; i < l * nr_class; i++) - alpha[i] = 0; - - for (i = 0; i < w_size * nr_class; i++) - w[i] = 0; - for (i = 0; i < l; i++) { - for (m = 0; m < nr_class; m++) - alpha_index[i * nr_class + m] = m; - QD[i] = 0; - for (final double val : prob.x[i]) { - QD[i] += val * val; - - // Uncomment the for loop if initial alpha isn't zero - // for(m=0; mindex-1)*nr_class+m] += alpha[i*nr_class+m]*val; - } - active_size_i[i] = nr_class; - y_index[i] = (int) prob.y[i]; - index[i] = i; - } - - final DoubleArrayPointer alpha_i = new DoubleArrayPointer(alpha, 0); - final IntArrayPointer alpha_index_i = new IntArrayPointer(alpha_index, 0); - - while (iter < max_iter) { - double stopping = Double.NEGATIVE_INFINITY; - - for (i = 0; i < active_size; i++) { - // int j = i+rand()%(active_size-i); - final int j = i + Linear.random.nextInt(active_size - i); - swap(index, i, j); - } - for (s = 0; s < active_size; s++) { - - i = index[s]; - final double Ai = QD[i]; - // double *alpha_i = &alpha[i*nr_class]; - alpha_i.setOffset(i * nr_class); - - // int *alpha_index_i = &alpha_index[i*nr_class]; - alpha_index_i.setOffset(i * nr_class); - - if (Ai > 0) { - for (m = 0; m < active_size_i[i]; m++) - G[m] = 1; - if (y_index[i] < active_size_i[i]) - G[y_index[i]] = 0; - - for (int ind = 0; ind < prob.x[i].length; ind++) { - // double *w_i = &w[ind*nr_class]; - final int w_offset = ind * nr_class; - for (m = 0; m < active_size_i[i]; m++) - // G[m] += w_i[alpha_index_i[m]]*(prob.x[i][ind); - G[m] += w[w_offset + alpha_index_i.get(m)] * prob.x[i][ind]; - - } - - double minG = Double.POSITIVE_INFINITY; - double maxG = Double.NEGATIVE_INFINITY; - for (m = 0; m < active_size_i[i]; m++) { - if (alpha_i.get(alpha_index_i.get(m)) < 0 && G[m] < 
minG) - minG = G[m]; - if (G[m] > maxG) - maxG = G[m]; - } - if (y_index[i] < active_size_i[i]) { - if (alpha_i.get((int) prob.y[i]) < C[GETI(i)] && G[y_index[i]] < minG) { - minG = G[y_index[i]]; - } - } - - for (m = 0; m < active_size_i[i]; m++) { - if (be_shrunk(i, m, y_index[i], alpha_i.get(alpha_index_i.get(m)), minG)) { - active_size_i[i]--; - while (active_size_i[i] > m) { - if (!be_shrunk(i, active_size_i[i], y_index[i], - alpha_i.get(alpha_index_i.get(active_size_i[i])), minG)) - { - swap(alpha_index_i, m, active_size_i[i]); - swap(G, m, active_size_i[i]); - if (y_index[i] == active_size_i[i]) - y_index[i] = m; - else if (y_index[i] == m) - y_index[i] = active_size_i[i]; - break; - } - active_size_i[i]--; - } - } - } - - if (active_size_i[i] <= 1) { - active_size--; - swap(index, s, active_size); - s--; - continue; - } - - if (maxG - minG <= 1e-12) - continue; - else - stopping = Math.max(maxG - minG, stopping); - - for (m = 0; m < active_size_i[i]; m++) - B[m] = G[m] - Ai * alpha_i.get(alpha_index_i.get(m)); - - solve_sub_problem(Ai, y_index[i], C[GETI(i)], active_size_i[i], alpha_new); - int nz_d = 0; - for (m = 0; m < active_size_i[i]; m++) { - final double d = alpha_new[m] - alpha_i.get(alpha_index_i.get(m)); - alpha_i.set(alpha_index_i.get(m), alpha_new[m]); - if (Math.abs(d) >= 1e-12) { - d_ind[nz_d] = alpha_index_i.get(m); - d_val[nz_d] = d; - nz_d++; - } - } - - for (int ind = 0; ind < prob.x[i].length; ind++) { - // double *w_i = &w[ind*nr_class]; - final int w_offset = ind * nr_class; - for (m = 0; m < nz_d; m++) { - w[w_offset + d_ind[m]] += d_val[m] * prob.x[i][ind]; - } - } - } - } - - iter++; - - if (iter % 10 == 0) { - info("."); - } - - if (stopping < eps_shrink) { - if (stopping < eps && start_from_all == true) - break; - else { - active_size = l; - for (i = 0; i < l; i++) - active_size_i[i] = nr_class; - info("*"); - eps_shrink = Math.max(eps_shrink / 2, eps); - start_from_all = true; - } - } else - start_from_all = false; - } - - 
info("%noptimization finished, #iter = %d%n", iter); - if (iter >= max_iter) - info("%nWARNING: reaching max number of iterations%n"); - - // calculate objective value - double v = 0; - int nSV = 0; - for (i = 0; i < w_size * nr_class; i++) - v += w[i] * w[i]; - v = 0.5 * v; - for (i = 0; i < l * nr_class; i++) { - v += alpha[i]; - if (Math.abs(alpha[i]) > 0) - nSV++; - } - for (i = 0; i < l; i++) - v -= alpha[i * nr_class + (int) prob.y[i]]; - info("Objective value = %f%n", v); - info("nSV = %d%n", nSV); - - } - - private void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double[] alpha_new) { - - int r; - assert active_i <= B.length; // no padding - final double[] D = copyOf(B, active_i); - // clone(D, B, active_i); - - if (yi < active_i) - D[yi] += A_i * C_yi; - - // qsort(D, active_i, sizeof(double), compare_double); - ArraySorter.reversedMergesort(D); - - double beta = D[0] - A_i * C_yi; - for (r = 1; r < active_i && beta < r * D[r]; r++) - beta += D[r]; - beta /= r; - - for (r = 0; r < active_i; r++) { - if (r == yi) - alpha_new[r] = Math.min(C_yi, (beta - B[r]) / A_i); - else - alpha_new[r] = Math.min(0.0, (beta - B[r]) / A_i); - } - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java b/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java deleted file mode 100644 index a792743..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/SolverType.java +++ /dev/null @@ -1,129 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import java.util.HashMap; -import java.util.Map; - - -public enum SolverType { - - /** - * L2-regularized logistic regression (primal) - * - * (fka L2_LR) - */ - L2R_LR(0, true, false), - - /** - * L2-regularized L2-loss support vector classification (dual) - * - * (fka L2LOSS_SVM_DUAL) - */ - L2R_L2LOSS_SVC_DUAL(1, false, false), - - /** - * L2-regularized L2-loss support vector classification (primal) - * - * (fka L2LOSS_SVM) - */ - L2R_L2LOSS_SVC(2, false, false), - - /** - * 
L2-regularized L1-loss support vector classification (dual) - * - * (fka L1LOSS_SVM_DUAL) - */ - L2R_L1LOSS_SVC_DUAL(3, false, false), - - /** - * multi-class support vector classification by Crammer and Singer - */ - MCSVM_CS(4, false, false), - - /** - * L1-regularized L2-loss support vector classification - * - * @since 1.5 - */ - L1R_L2LOSS_SVC(5, false, false), - - /** - * L1-regularized logistic regression - * - * @since 1.5 - */ - L1R_LR(6, true, false), - - /** - * L2-regularized logistic regression (dual) - * - * @since 1.7 - */ - L2R_LR_DUAL(7, true, false), - - /** - * L2-regularized L2-loss support vector regression (dual) - * - * @since 1.91 - */ - L2R_L2LOSS_SVR(11, false, true), - - /** - * L2-regularized L1-loss support vector regression (dual) - * - * @since 1.91 - */ - L2R_L2LOSS_SVR_DUAL(12, false, true), - - /** - * L2-regularized L2-loss support vector regression (primal) - * - * @since 1.91 - */ - L2R_L1LOSS_SVR_DUAL(13, false, true), - - ; - - private final boolean logisticRegressionSolver; - private final boolean supportVectorRegression; - private final int id; - - private SolverType( int id, boolean logisticRegressionSolver, boolean supportVectorRegression ) { - this.id = id; - this.logisticRegressionSolver = logisticRegressionSolver; - this.supportVectorRegression = supportVectorRegression; - } - - private static Map SOLVERS_BY_ID = new HashMap(); - static { - for (SolverType solverType : SolverType.values()) { - SolverType old = SOLVERS_BY_ID.put(Integer.valueOf(solverType.getId()), solverType); - if (old != null) throw new Error("duplicate solver type ID: " + solverType.getId()); - } - } - - public int getId() { - return id; - } - - public static SolverType getById(int id) { - SolverType solverType = SOLVERS_BY_ID.get(Integer.valueOf(id)); - if (solverType == null) { - throw new RuntimeException("found no solvertype for id " + id); - } - return solverType; - } - - /** - * @since 1.9 - */ - public boolean isLogisticRegressionSolver() { - 
return logisticRegressionSolver; - } - - /** - * @since 1.91 - */ - public boolean isSupportVectorRegression() { - return supportVectorRegression; - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Train.java b/src/main/java/de/bwaldvogel/denseliblinear/Train.java deleted file mode 100644 index 2a690fc..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Train.java +++ /dev/null @@ -1,420 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.atof; -import static de.bwaldvogel.denseliblinear.Linear.atoi; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.NoSuchElementException; -import java.util.StringTokenizer; - -public class Train { - - public static void main(String[] args) throws IOException, InvalidInputDataException { - new Train().run(args); - } - - private double bias = 1; - private boolean cross_validation = false; - private String inputFilename; - private String modelFilename; - private int nr_fold; - private Parameter param = null; - private Problem prob = null; - - private void do_cross_validation() { - - double total_error = 0; - double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; - final double[] target = new double[prob.l]; - - long start, stop; - start = System.currentTimeMillis(); - Linear.crossValidation(prob, param, nr_fold, target); - stop = System.currentTimeMillis(); - System.out.println("time: " + (stop - start) + " ms"); - - if (param.solverType.isSupportVectorRegression()) { - for (int i = 0; i < prob.l; i++) { - final double y = prob.y[i]; - final double v = target[i]; - total_error += (v - y) * (v - y); - sumv += v; - sumy += y; - sumvv += v * v; - sumyy += y * y; - sumvy += v * y; - } - System.out.printf("Cross Validation Mean squared error = %g%n", total_error / prob.l); - System.out.printf("Cross Validation Squared correlation 
coefficient = %g%n", // - ((prob.l * sumvy - sumv * sumy) * (prob.l * sumvy - sumv * sumy)) - / ((prob.l * sumvv - sumv * sumv) * (prob.l * sumyy - sumy * sumy))); - } else { - int total_correct = 0; - for (int i = 0; i < prob.l; i++) - if (target[i] == prob.y[i]) - ++total_correct; - - System.out.printf("correct: %d%n", total_correct); - System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l); - } - } - - private void exit_with_help() { - System.out.printf("Usage: train [options] training_set_file [model_file]%n" // - + "options:%n" - + "-s type : set type of solver (default 1)%n" - + " for multi-class classification%n" - + " 0 -- L2-regularized logistic regression (primal)%n" - + " 1 -- L2-regularized L2-loss support vector classification (dual)%n" - + " 2 -- L2-regularized L2-loss support vector classification (primal)%n" - + " 3 -- L2-regularized L1-loss support vector classification (dual)%n" - + " 4 -- support vector classification by Crammer and Singer%n" - + " 5 -- L1-regularized L2-loss support vector classification%n" - + " 6 -- L1-regularized logistic regression%n" - + " 7 -- L2-regularized logistic regression (dual)%n" - + " for regression%n" - + " 11 -- L2-regularized L2-loss support vector regression (primal)%n" - + " 12 -- L2-regularized L2-loss support vector regression (dual)%n" - + " 13 -- L2-regularized L1-loss support vector regression (dual)%n" - + "-c cost : set the parameter C (default 1)%n" - + "-p epsilon : set the epsilon in loss function of SVR (default 0.1)%n" - + "-e epsilon : set tolerance of termination criterion%n" - + " -s 0 and 2%n" + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n" - + " where f is the primal function and pos/neg are # of%n" - + " positive/negative data (default 0.01)%n" + " -s 11%n" - + " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)%n" - + " -s 1, 3, 4 and 7%n" + " Dual maximal violation <= eps; similar to libsvm (default 0.1)%n" - + " -s 5 and 6%n" - + " |f'(w)|_1 <= 
eps*min(pos,neg)/l*|f'(w0)|_1,%n" - + " where f is the primal function (default 0.01)%n" - + " -s 12 and 13\n" - + " |f'(alpha)|_1 <= eps |f'(alpha0)|,\n" - + " where f is the dual function (default 0.1)\n" - + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n" - + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n" - + "-v n: n-fold cross validation mode%n" - + "-q : quiet mode (no outputs)%n"); - System.exit(1); - } - - Problem getProblem() { - return prob; - } - - double getBias() { - return bias; - } - - Parameter getParameter() { - return param; - } - - void parse_command_line(String argv[]) { - int i; - - // eps: see setting below - param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY, 0.1); - // default values - bias = -1; - cross_validation = false; - - // parse options - for (i = 0; i < argv.length; i++) { - if (argv[i].charAt(0) != '-') - break; - if (++i >= argv.length) - exit_with_help(); - switch (argv[i - 1].charAt(1)) { - case 's': - param.solverType = SolverType.getById(atoi(argv[i])); - break; - case 'c': - param.setC(atof(argv[i])); - break; - case 'p': - param.setP(atof(argv[i])); - break; - case 'e': - param.setEps(atof(argv[i])); - break; - case 'B': - bias = atof(argv[i]); - break; - case 'w': - final int weightLabel = atoi(argv[i - 1].substring(2)); - final double weight = atof(argv[i]); - param.weightLabel = addToArray(param.weightLabel, weightLabel); - param.weight = addToArray(param.weight, weight); - break; - case 'v': - cross_validation = true; - nr_fold = atoi(argv[i]); - if (nr_fold < 2) { - System.err.println("n-fold cross validation: n must >= 2"); - exit_with_help(); - } - break; - case 'q': - i--; - Linear.disableDebugOutput(); - break; - default: - System.err.println("unknown option"); - exit_with_help(); - } - } - - // determine filenames - - if (i >= argv.length) - exit_with_help(); - - inputFilename = argv[i]; - 
- if (i < argv.length - 1) - modelFilename = argv[i + 1]; - else { - int p = argv[i].lastIndexOf('/'); - ++p; // whew... - modelFilename = argv[i].substring(p) + ".model"; - } - - if (param.eps == Double.POSITIVE_INFINITY) { - switch (param.solverType) { - case L2R_LR: - case L2R_L2LOSS_SVC: - param.setEps(0.01); - break; - case L2R_L2LOSS_SVR: - param.setEps(0.001); - break; - case L2R_L2LOSS_SVC_DUAL: - case L2R_L1LOSS_SVC_DUAL: - case MCSVM_CS: - case L2R_LR_DUAL: - param.setEps(0.1); - break; - case L1R_L2LOSS_SVC: - case L1R_LR: - param.setEps(0.01); - break; - case L2R_L1LOSS_SVR_DUAL: - case L2R_L2LOSS_SVR_DUAL: - param.setEps(0.1); - break; - default: - throw new IllegalStateException("unknown solver type: " + param.solverType); - } - } - } - - /** - * reads a problem from LibSVM format - * - * @param file - * the SVM file - * @throws IOException - * obviously in case of any I/O exception ;) - * @throws InvalidInputDataException - * if the input file is not correctly formatted - */ - static int readProblemFeatureDim(File file) throws IOException, InvalidInputDataException { - final BufferedReader fp = new BufferedReader(new FileReader(file)); - int max_index = 0; - int lineNr = 0; - - try { - while (true) { - final String line = fp.readLine(); - if (line == null) - break; - lineNr++; - - final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); - String token; - try { - token = st.nextToken(); - } catch (final NoSuchElementException e) { - throw new InvalidInputDataException("empty line", file, lineNr, e); - } - - final int m = st.countTokens() / 2; - - int indexBefore = 0; - for (int j = 0; j < m; j++) { - token = st.nextToken(); - int index; - try { - index = atoi(token); - } catch (final NumberFormatException e) { - throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); - } - - // assert that indices are valid and sorted - if (index < 0) - throw new InvalidInputDataException("invalid index: " + index, file, lineNr); - 
if (index <= indexBefore) - throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr); - indexBefore = index; - - token = st.nextToken(); - - if (index > max_index) { - max_index = index; - } - } - } - - return max_index; - } finally { - fp.close(); - } - } - - /** - * reads a problem from LibSVM format - * - * @param file - * the SVM file - * @throws IOException - * obviously in case of any I/O exception ;) - * @throws InvalidInputDataException - * if the input file is not correctly formatted - */ - public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException { - final BufferedReader fp = new BufferedReader(new FileReader(file)); - final List vy = new ArrayList(); - final List vx = new ArrayList(); - - int lineNr = 0; - - final int w_size = readProblemFeatureDim(file); - - try { - while (true) { - final String line = fp.readLine(); - if (line == null) - break; - lineNr++; - - final StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); - String token; - try { - token = st.nextToken(); - } catch (final NoSuchElementException e) { - throw new InvalidInputDataException("empty line", file, lineNr, e); - } - - try { - vy.add(atof(token)); - } catch (final NumberFormatException e) { - throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e); - } - - final int m = st.countTokens() / 2; - double[] x; - if (bias >= 0) { - x = new double[w_size + 1]; - } else { - x = new double[w_size]; - } - int indexBefore = 0; - for (int j = 0; j < m; j++) { - - token = st.nextToken(); - int index; - try { - index = atoi(token); - } catch (final NumberFormatException e) { - throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); - } - - // assert that indices are valid and sorted - if (index < 0) - throw new InvalidInputDataException("invalid index: " + index, file, lineNr); - if (index <= indexBefore) - throw new InvalidInputDataException("indices must 
be sorted in ascending order", file, lineNr); - indexBefore = index; - - token = st.nextToken(); - try { - final double value = atof(token); - x[index - 1] = value; - } catch (final NumberFormatException e) { - throw new InvalidInputDataException("invalid value: " + token, file, lineNr); - } - } - - vx.add(x); - } - - return constructProblem(vy, vx, w_size, bias); - } finally { - fp.close(); - } - } - - void readProblem(String filename) throws IOException, InvalidInputDataException { - prob = Train.readProblem(new File(filename), bias); - } - - private static int[] addToArray(int[] array, int newElement) { - final int length = array != null ? array.length : 0; - final int[] newArray = new int[length + 1]; - if (array != null && length > 0) { - System.arraycopy(array, 0, newArray, 0, length); - } - newArray[length] = newElement; - return newArray; - } - - private static double[] addToArray(double[] array, double newElement) { - final int length = array != null ? array.length : 0; - final double[] newArray = new double[length + 1]; - if (array != null && length > 0) { - System.arraycopy(array, 0, newArray, 0, length); - } - newArray[length] = newElement; - return newArray; - } - - private static Problem constructProblem(List vy, List vx, int max_index, double bias) { - final Problem prob = new Problem(); - prob.bias = bias; - prob.l = vy.size(); - prob.n = max_index; - if (bias >= 0) { - prob.n++; - } - prob.x = new double[prob.l][]; - for (int i = 0; i < prob.l; i++) { - prob.x[i] = vx.get(i); - - if (bias >= 0) { - prob.x[i][max_index] = bias; - } - } - - prob.y = new double[prob.l]; - for (int i = 0; i < prob.l; i++) - prob.y[i] = vy.get(i).doubleValue(); - - return prob; - } - - private void run(String[] args) throws IOException, InvalidInputDataException { - parse_command_line(args); - readProblem(inputFilename); - if (cross_validation) - do_cross_validation(); - else { - final Model model = Linear.train(prob, param); - Linear.saveModel(new File(modelFilename), 
model); - } - } -} diff --git a/src/main/java/de/bwaldvogel/denseliblinear/Tron.java b/src/main/java/de/bwaldvogel/denseliblinear/Tron.java deleted file mode 100644 index 1235175..0000000 --- a/src/main/java/de/bwaldvogel/denseliblinear/Tron.java +++ /dev/null @@ -1,260 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.info; - -/** - * Trust Region Newton Method optimization - */ -class Tron { - - private final Function fun_obj; - - private final double eps; - - private final int max_iter; - - public Tron( final Function fun_obj ) { - this(fun_obj, 0.1); - } - - public Tron( final Function fun_obj, double eps ) { - this(fun_obj, eps, 1000); - } - - public Tron( final Function fun_obj, double eps, int max_iter ) { - this.fun_obj = fun_obj; - this.eps = eps; - this.max_iter = max_iter; - } - - void tron(double[] w) { - // Parameters for updating the iterates. - double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75; - - // Parameters for updating the trust region size delta. - double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4; - - int n = fun_obj.get_nr_variable(); - int i, cg_iter; - double delta, snorm, one = 1.0; - double alpha, f, fnew, prered, actred, gs; - int search = 1, iter = 1; - double[] s = new double[n]; - double[] r = new double[n]; - double[] w_new = new double[n]; - double[] g = new double[n]; - - for (i = 0; i < n; i++) - w[i] = 0; - - f = fun_obj.fun(w); - fun_obj.grad(w, g); - delta = euclideanNorm(g); - double gnorm1 = delta; - double gnorm = gnorm1; - - if (gnorm <= eps * gnorm1) search = 0; - - iter = 1; - - while (iter <= max_iter && search != 0) { - cg_iter = trcg(delta, g, s, r); - - System.arraycopy(w, 0, w_new, 0, n); - daxpy(one, s, w_new); - - gs = dot(g, s); - prered = -0.5 * (gs - dot(s, r)); - fnew = fun_obj.fun(w_new); - - // Compute the actual reduction. - actred = f - fnew; - - // On the first iteration, adjust the initial step bound. 
- snorm = euclideanNorm(s); - if (iter == 1) delta = Math.min(delta, snorm); - - // Compute prediction alpha*snorm of the step. - if (fnew - f - gs <= 0) - alpha = sigma3; - else - alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs))); - - // Update the trust region bound according to the ratio of actual to - // predicted reduction. - if (actred < eta0 * prered) - delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2 * delta); - else if (actred < eta1 * prered) - delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma2 * delta)); - else if (actred < eta2 * prered) - delta = Math.max(sigma1 * delta, Math.min(alpha * snorm, sigma3 * delta)); - else - delta = Math.max(delta, Math.min(alpha * snorm, sigma3 * delta)); - - info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d%n", iter, actred, prered, delta, f, gnorm, cg_iter); - - if (actred > eta0 * prered) { - iter++; - System.arraycopy(w_new, 0, w, 0, n); - f = fnew; - fun_obj.grad(w, g); - - gnorm = euclideanNorm(g); - if (gnorm <= eps * gnorm1) break; - } - if (f < -1.0e+32) { - info("WARNING: f < -1.0e+32%n"); - break; - } - if (Math.abs(actred) <= 0 && prered <= 0) { - info("WARNING: actred and prered <= 0%n"); - break; - } - if (Math.abs(actred) <= 1.0e-12 * Math.abs(f) && Math.abs(prered) <= 1.0e-12 * Math.abs(f)) { - info("WARNING: actred and prered too small%n"); - break; - } - } - } - - private int trcg(double delta, double[] g, double[] s, double[] r) { - int n = fun_obj.get_nr_variable(); - double one = 1; - double[] d = new double[n]; - double[] Hd = new double[n]; - double rTr, rnewTrnew, cgtol; - - for (int i = 0; i < n; i++) { - s[i] = 0; - r[i] = -g[i]; - d[i] = r[i]; - } - cgtol = 0.1 * euclideanNorm(g); - - int cg_iter = 0; - rTr = dot(r, r); - - while (true) { - if (euclideanNorm(r) <= cgtol) break; - cg_iter++; - fun_obj.Hv(d, Hd); - - double alpha = rTr / dot(d, Hd); - daxpy(alpha, d, s); - if (euclideanNorm(s) > delta) { - info("cg reaches trust region boundary%n"); 
- alpha = -alpha; - daxpy(alpha, d, s); - - double std = dot(s, d); - double sts = dot(s, s); - double dtd = dot(d, d); - double dsq = delta * delta; - double rad = Math.sqrt(std * std + dtd * (dsq - sts)); - if (std >= 0) - alpha = (dsq - sts) / (std + rad); - else - alpha = (rad - std) / dtd; - daxpy(alpha, d, s); - alpha = -alpha; - daxpy(alpha, Hd, r); - break; - } - alpha = -alpha; - daxpy(alpha, Hd, r); - rnewTrnew = dot(r, r); - double beta = rnewTrnew / rTr; - scale(beta, d); - daxpy(one, r, d); - rTr = rnewTrnew; - } - - return (cg_iter); - } - - /** - * constant times a vector plus a vector - * - *
-     * vector2 += constant * vector1
-     * 
- * - * @since 1.8 - */ - private static void daxpy(double constant, double vector1[], double vector2[]) { - if (constant == 0) return; - - assert vector1.length == vector2.length; - for (int i = 0; i < vector1.length; i++) { - vector2[i] += constant * vector1[i]; - } - } - - /** - * returns the dot product of two vectors - * - * @since 1.8 - */ - private static double dot(double vector1[], double vector2[]) { - - double product = 0; - assert vector1.length == vector2.length; - for (int i = 0; i < vector1.length; i++) { - product += vector1[i] * vector2[i]; - } - return product; - - } - - /** - * returns the euclidean norm of a vector - * - * @since 1.8 - */ - private static double euclideanNorm(double vector[]) { - - int n = vector.length; - - if (n < 1) { - return 0; - } - - if (n == 1) { - return Math.abs(vector[0]); - } - - // this algorithm is (often) more accurate than just summing up the squares and taking the square-root afterwards - - double scale = 0; // scaling factor that is factored out - double sum = 1; // basic sum of squares from which scale has been factored out - for (int i = 0; i < n; i++) { - if (vector[i] != 0) { - double abs = Math.abs(vector[i]); - // try to get the best scaling factor - if (scale < abs) { - double t = scale / abs; - sum = 1 + sum * (t * t); - scale = abs; - } else { - double t = abs / scale; - sum += t * t; - } - } - } - - return scale * Math.sqrt(sum); - } - - /** - * scales a vector by a constant - * - * @since 1.8 - */ - private static void scale(double constant, double vector[]) { - if (constant == 1.0) return; - for (int i = 0; i < vector.length; i++) { - vector[i] *= constant; - } - - } -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java deleted file mode 100644 index d524c82..0000000 --- a/src/test/java/de/bwaldvogel/denseliblinear/ArrayPointerTest.java +++ /dev/null @@ -1,63 +0,0 @@ -package de.bwaldvogel.denseliblinear; - 
-import static org.fest.assertions.Assertions.assertThat; -import static org.fest.assertions.Fail.fail; - -import org.junit.Test; - -import de.bwaldvogel.denseliblinear.DoubleArrayPointer; -import de.bwaldvogel.denseliblinear.IntArrayPointer; - - -public class ArrayPointerTest { - - @Test - public void testGetIntArrayPointer() { - int[] foo = new int[] {1, 2, 3, 4, 6}; - IntArrayPointer pFoo = new IntArrayPointer(foo, 2); - assertThat(pFoo.get(0)).isEqualTo(3); - assertThat(pFoo.get(1)).isEqualTo(4); - assertThat(pFoo.get(2)).isEqualTo(6); - try { - pFoo.get(3); - fail("ArrayIndexOutOfBoundsException expected"); - } catch (ArrayIndexOutOfBoundsException e) {} - } - - @Test - public void testSetIntArrayPointer() { - int[] foo = new int[] {1, 2, 3, 4, 6}; - IntArrayPointer pFoo = new IntArrayPointer(foo, 2); - pFoo.set(2, 5); - assertThat(foo).isEqualTo(new int[] {1, 2, 3, 4, 5}); - try { - pFoo.set(3, 0); - fail("ArrayIndexOutOfBoundsException expected"); - } catch (ArrayIndexOutOfBoundsException e) {} - } - - @Test - public void testGetDoubleArrayPointer() { - double[] foo = new double[] {1, 2, 3, 4, 6}; - DoubleArrayPointer pFoo = new DoubleArrayPointer(foo, 2); - assertThat(pFoo.get(0)).isEqualTo(3); - assertThat(pFoo.get(1)).isEqualTo(4); - assertThat(pFoo.get(2)).isEqualTo(6); - try { - pFoo.get(3); - fail("ArrayIndexOutOfBoundsException expected"); - } catch (ArrayIndexOutOfBoundsException e) {} - } - - @Test - public void testSetDoubleArrayPointer() { - double[] foo = new double[] {1, 2, 3, 4, 6}; - DoubleArrayPointer pFoo = new DoubleArrayPointer(foo, 2); - pFoo.set(2, 5); - assertThat(foo).isEqualTo(new double[] {1, 2, 3, 4, 5}); - try { - pFoo.set(3, 0); - fail("ArrayIndexOutOfBoundsException expected"); - } catch (ArrayIndexOutOfBoundsException e) {} - } -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java deleted file mode 100644 index 38fb53a..0000000 --- 
a/src/test/java/de/bwaldvogel/denseliblinear/ArraySorterTest.java +++ /dev/null @@ -1,58 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static de.bwaldvogel.denseliblinear.Linear.swap; -import static org.fest.assertions.Assertions.assertThat; - -import java.util.Random; - -import org.junit.Test; - -import de.bwaldvogel.denseliblinear.ArraySorter; - - -public class ArraySorterTest { - - private Random random = new Random(); - - private void assertDescendingOrder(double[] array) { - double before = array[0]; - for (double d : array) { - // accept that case - if (d == 0.0 && before == -0.0) continue; - - assertThat(d).isLessThanOrEqualTo(before); - before = d; - } - } - - private void shuffleArray(double[] array) { - - for (int i = 0; i < array.length; i++) { - int j = random.nextInt(array.length); - swap(array, i, j); - } - } - - @Test - public void testReversedMergesort() { - - for (int k = 1; k <= 16 * 8096; k *= 2) { - // create random array - double[] array = new double[k]; - for (int i = 0; i < array.length; i++) { - array[i] = random.nextDouble(); - } - - ArraySorter.reversedMergesort(array); - assertDescendingOrder(array); - } - } - - @Test - public void testReversedMergesortWithMeanValues() { - double[] array = new double[] {1.0, -0.0, -1.1, 2.0, 3.0, 0.0, 4.0, -0.0, 0.0}; - shuffleArray(array); - ArraySorter.reversedMergesort(array); - assertDescendingOrder(array); - } -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java b/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java deleted file mode 100644 index 753715f..0000000 --- a/src/test/java/de/bwaldvogel/denseliblinear/LinearTest.java +++ /dev/null @@ -1,517 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static org.fest.assertions.Assertions.assertThat; -import static org.fest.assertions.Fail.fail; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -import java.io.File; -import 
java.io.IOException; -import java.io.Writer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Random; -import java.util.Set; -import java.util.TreeSet; - -import org.fest.assertions.Delta; -import org.junit.BeforeClass; -import org.junit.Test; -import org.powermock.api.mockito.PowerMockito; - -public class LinearTest { - - private static Random random = new Random(12345); - - @BeforeClass - public static void disableDebugOutput() { - // Linear.disableDebugOutput(); - } - - public static Model createRandomModel() { - final Model model = new Model(); - model.solverType = SolverType.L2R_LR; - model.bias = 2; - model.label = new int[] { 1, Integer.MAX_VALUE, 2 }; - model.w = new double[model.label.length * 300]; - for (int i = 0; i < model.w.length; i++) { - // precision should be at least 1e-4 - model.w[i] = Math.round(random.nextDouble() * 100000.0) / 10000.0; - } - - // force at least one value to be zero - model.w[random.nextInt(model.w.length)] = 0.0; - model.w[random.nextInt(model.w.length)] = -0.0; - - model.nr_feature = model.w.length / model.label.length - 1; - model.nr_class = model.label.length; - return model; - } - - public static Problem createRandomProblem(int numClasses) { - final Problem prob = new Problem(); - prob.bias = -1; - prob.l = random.nextInt(100) + 1; - prob.n = random.nextInt(100) + 1; - prob.x = new double[prob.l][]; - prob.y = new double[prob.l]; - - for (int i = 0; i < prob.l; i++) { - - prob.y[i] = random.nextInt(numClasses); - - final Set randomNumbers = new TreeSet(); - final int num = random.nextInt(prob.n); - for (int j = 0; j < num; j++) { - randomNumbers.add(random.nextInt(prob.n)); - } - final List randomIndices = new ArrayList(randomNumbers); - Collections.sort(randomIndices); - - prob.x[i] = new double[prob.n]; - for (int j = 0; j < randomIndices.size(); j++) { - prob.x[i][randomIndices.get(j)] = random.nextDouble(); - } - } - return prob; - } - - /** - * create a very simple 
problem and check if the clearly separated examples - * are recognized as such - */ - @Test - public void testTrainPredict() { - final Problem prob = new Problem(); - prob.bias = -1; - prob.l = 4; - prob.n = 4; - prob.x = new double[4][4]; - - prob.x[0][0] = 1; - prob.x[0][1] = 1; - - prob.x[1][2] = 1; - prob.x[2][2] = 1; - - prob.x[3][0] = 2; - prob.x[3][1] = 1; - prob.x[3][3] = 1; - - prob.y = new double[4]; - prob.y[0] = 0; - prob.y[1] = 1; - prob.y[2] = 1; - prob.y[3] = 0; - - for (final SolverType solver : SolverType.values()) { - for (double C = 0.1; C <= 100.; C *= 1.2) { - // compared the behavior with the C version - if (C < 0.2) - if (solver == SolverType.L1R_L2LOSS_SVC) - continue; - if (C < 0.7) - if (solver == SolverType.L1R_LR) - continue; - - if (solver.isSupportVectorRegression()) { - continue; - } - - final Parameter param = new Parameter(solver, C, 0.1, 0.1); - final Model model = Linear.train(prob, param); - - final double[] featureWeights = model.getFeatureWeights(); - if (solver == SolverType.MCSVM_CS) { - assertThat(featureWeights.length).isEqualTo(8); - } else { - assertThat(featureWeights.length).isEqualTo(4); - } - - int i = 0; - for (final double value : prob.y) { - final double prediction = Linear.predict(model, prob.x[i]); - assertThat(prediction).as("prediction with solver " + solver).isEqualTo(value); - if (model.isProbabilityModel()) { - final double[] estimates = new double[model.getNrClass()]; - final double probabilityPrediction = Linear.predictProbability(model, prob.x[i], estimates); - assertThat(probabilityPrediction).isEqualTo(prediction); - assertThat(estimates[(int) probabilityPrediction]).isGreaterThanOrEqualTo( - 1.0 / model.getNrClass()); - double estimationSum = 0; - for (final double estimate : estimates) { - estimationSum += estimate; - } - assertThat(estimationSum).isEqualTo(1.0, Delta.delta(0.001)); - } - i++; - } - } - } - } - - @Test - public void testCrossValidation() throws Exception { - - final int numClasses = 
random.nextInt(10) + 1; - - final Problem prob = createRandomProblem(numClasses); - - final Parameter param = new Parameter(SolverType.L2R_LR, 10, 0.01); - final int nr_fold = 10; - final double[] target = new double[prob.l]; - Linear.crossValidation(prob, param, nr_fold, target); - - for (final double clazz : target) { - assertThat(clazz).isGreaterThanOrEqualTo(0).isLessThan(numClasses); - } - } - - @Test - public void testLoadSaveModel() throws Exception { - - Model model = null; - for (final SolverType solverType : SolverType.values()) { - model = createRandomModel(); - model.solverType = solverType; - - final File tempFile = File.createTempFile("liblinear", "modeltest"); - tempFile.deleteOnExit(); - Linear.saveModel(tempFile, model); - - final Model loadedModel = Linear.loadModel(tempFile); - assertThat(loadedModel).isEqualTo(model); - } - } - - @Test - public void testPredictProbabilityWrongSolver() throws Exception { - final Problem prob = new Problem(); - prob.l = 1; - prob.n = 1; - prob.x = new double[prob.l][prob.n]; - prob.y = new double[prob.l]; - for (int i = 0; i < prob.l; i++) { - prob.y[i] = i; - } - - final SolverType solverType = SolverType.L2R_L1LOSS_SVC_DUAL; - final Parameter param = new Parameter(solverType, 10, 0.1); - final Model model = Linear.train(prob, param); - try { - Linear.predictProbability(model, prob.x[0], new double[1]); - fail("IllegalArgumentException expected"); - } catch (final IllegalArgumentException e) { - assertThat(e.getMessage()).isEqualTo("probability output is only supported for logistic regression." 
// - + " This is currently only supported by the following solvers:" // - + " L2R_LR, L1R_LR, L2R_LR_DUAL"); - } - } - - @Test - public void testRealloc() { - - int[] f = new int[] { 1, 2, 3 }; - f = Linear.copyOf(f, 5); - f[3] = 4; - f[4] = 5; - assertThat(f).isEqualTo(new int[] { 1, 2, 3, 4, 5 }); - } - - @Test - public void testAtoi() { - assertThat(Linear.atoi("+25")).isEqualTo(25); - assertThat(Linear.atoi("-345345")).isEqualTo(-345345); - assertThat(Linear.atoi("+0")).isEqualTo(0); - assertThat(Linear.atoi("0")).isEqualTo(0); - assertThat(Linear.atoi("2147483647")).isEqualTo(Integer.MAX_VALUE); - assertThat(Linear.atoi("-2147483648")).isEqualTo(Integer.MIN_VALUE); - } - - @Test(expected = NumberFormatException.class) - public void testAtoiInvalidData() { - Linear.atoi("+"); - } - - @Test(expected = NumberFormatException.class) - public void testAtoiInvalidData2() { - Linear.atoi("abc"); - } - - @Test(expected = NumberFormatException.class) - public void testAtoiInvalidData3() { - Linear.atoi(" "); - } - - @Test - public void testAtof() { - assertThat(Linear.atof("+25")).isEqualTo(25); - assertThat(Linear.atof("-25.12345678")).isEqualTo(-25.12345678); - assertThat(Linear.atof("0.345345299")).isEqualTo(0.345345299); - } - - @Test(expected = NumberFormatException.class) - public void testAtofInvalidData() { - Linear.atof("0.5t"); - } - - @Test - public void testSaveModelWithIOException() throws Exception { - final Model model = createRandomModel(); - - final Writer out = PowerMockito.mock(Writer.class); - - final IOException ioException = new IOException("some reason"); - - doThrow(ioException).when(out).flush(); - - try { - Linear.saveModel(out, model); - fail("IOException expected"); - } catch (final IOException e) { - assertThat(e).isEqualTo(ioException); - } - - verify(out).flush(); - verify(out, times(1)).close(); - } - - /** - * compared input/output values with the C version (1.51) - * - *
-	 * IN:
-	 * res prob.l = 4
-	 * res prob.n = 4
-	 * 0: (2,1) (4,1)
-	 * 1: (1,1)
-	 * 2: (3,1)
-	 * 3: (2,2) (3,1) (4,1)
-	 * 
-	 * TRANSPOSED:
-	 * 
-	 * res prob.l = 4
-	 * res prob.n = 4
-	 * 0: (2,1)
-	 * 1: (1,1) (4,2)
-	 * 2: (3,1) (4,1)
-	 * 3: (1,1) (4,1)
-	 * 
- */ - @Test - public void testTranspose() throws Exception { - final Problem prob = new Problem(); - prob.bias = -1; - prob.l = 4; - prob.n = 4; - prob.x = new double[4][4]; - - prob.x[0][1] = 1; - prob.x[0][3] = 1; - - prob.x[1][0] = 1; - prob.x[2][2] = 1; - - prob.x[3][1] = 2; - prob.x[3][2] = 1; - prob.x[3][3] = 1; - - prob.y = new double[4]; - prob.y[0] = 0; - prob.y[1] = 1; - prob.y[2] = 1; - prob.y[3] = 0; - - final Problem transposed = Linear.transpose(prob); - - assertThat(transposed.x[0].length).isEqualTo(4); - assertThat(transposed.x[1].length).isEqualTo(4); - assertThat(transposed.x[2].length).isEqualTo(4); - assertThat(transposed.x[3].length).isEqualTo(4); - - assertThat(transposed.x[0][1]).isEqualTo(1); - - assertThat(transposed.x[1][0]).isEqualTo(1); - assertThat(transposed.x[1][3]).isEqualTo(2); - - assertThat(transposed.x[2][2]).isEqualTo(1); - assertThat(transposed.x[2][3]).isEqualTo(1); - - assertThat(transposed.x[3][0]).isEqualTo(1); - assertThat(transposed.x[3][3]).isEqualTo(1); - - assertThat(transposed.y).isEqualTo(prob.y); - } - - /** - * - * compared input/output values with the C version (1.51) - * - *
-	 * IN:
-	 * res prob.l = 5
-	 * res prob.n = 10
-	 * 0: (1,7) (3,3) (5,2)
-	 * 1: (2,1) (4,5) (5,3) (7,4) (8,2)
-	 * 2: (1,9) (3,1) (5,1) (10,7)
-	 * 3: (1,2) (2,2) (3,9) (4,7) (5,8) (6,1) (7,5) (8,4)
-	 * 4: (3,1) (10,3)
-	 * 
-	 * TRANSPOSED:
-	 * 
-	 * res prob.l = 5
-	 * res prob.n = 10
-	 * 0: (1,7) (3,9) (4,2)
-	 * 1: (2,1) (4,2)
-	 * 2: (1,3) (3,1) (4,9) (5,1)
-	 * 3: (2,5) (4,7)
-	 * 4: (1,2) (2,3) (3,1) (4,8)
-	 * 5: (4,1)
-	 * 6: (2,4) (4,5)
-	 * 7: (2,2) (4,4)
-	 * 8:
-	 * 9: (3,7) (5,3)
-	 * 
- */ - @Test - public void testTranspose2() throws Exception { - final Problem prob = new Problem(); - prob.bias = -1; - prob.l = 5; - prob.n = 10; - prob.x = new double[5][10]; - - prob.x[0][0] = 7; - prob.x[0][2] = 3; - prob.x[0][4] = 2; - - prob.x[1][1] = 1; - prob.x[1][3] = 5; - prob.x[1][4] = 3; - prob.x[1][6] = 4; - prob.x[1][7] = 2; - - prob.x[2][0] = 9; - prob.x[2][2] = 1; - prob.x[2][4] = 1; - prob.x[2][9] = 7; - - prob.x[3][0] = 2; - prob.x[3][1] = 2; - prob.x[3][2] = 9; - prob.x[3][3] = 7; - prob.x[3][4] = 8; - prob.x[3][5] = 1; - prob.x[3][6] = 5; - prob.x[3][7] = 4; - - prob.x[4][2] = 1; - prob.x[4][9] = 3; - - prob.y = new double[5]; - prob.y[0] = 0; - prob.y[1] = 1; - prob.y[2] = 1; - prob.y[3] = 0; - prob.y[4] = 1; - - final Problem transposed = Linear.transpose(prob); - - assertThat(transposed.x[0]).hasSize(5); - assertThat(transposed.x[1]).hasSize(5); - assertThat(transposed.x[2]).hasSize(5); - assertThat(transposed.x[3]).hasSize(5); - assertThat(transposed.x[4]).hasSize(5); - assertThat(transposed.x[5]).hasSize(5); - assertThat(transposed.x[7]).hasSize(5); - assertThat(transposed.x[7]).hasSize(5); - assertThat(transposed.x[8]).hasSize(5); - assertThat(transposed.x[9]).hasSize(5); - - assertThat(transposed.x[0][0]).isEqualTo(7); - assertThat(transposed.x[0][2]).isEqualTo(9); - assertThat(transposed.x[0][3]).isEqualTo(2); - - assertThat(transposed.x[1][1]).isEqualTo(1); - assertThat(transposed.x[1][3]).isEqualTo(2); - - assertThat(transposed.x[2][0]).isEqualTo(3); - assertThat(transposed.x[2][2]).isEqualTo(1); - assertThat(transposed.x[2][3]).isEqualTo(9); - assertThat(transposed.x[2][4]).isEqualTo(1); - - assertThat(transposed.x[3][1]).isEqualTo(5); - assertThat(transposed.x[3][3]).isEqualTo(7); - - assertThat(transposed.x[4][0]).isEqualTo(2); - assertThat(transposed.x[4][1]).isEqualTo(3); - assertThat(transposed.x[4][2]).isEqualTo(1); - assertThat(transposed.x[4][3]).isEqualTo(8); - - assertThat(transposed.x[5][3]).isEqualTo(1); - - 
assertThat(transposed.x[6][1]).isEqualTo(4); - assertThat(transposed.x[6][3]).isEqualTo(5); - - assertThat(transposed.x[7][1]).isEqualTo(2); - assertThat(transposed.x[7][3]).isEqualTo(4); - - assertThat(transposed.x[9][2]).isEqualTo(7); - assertThat(transposed.x[9][4]).isEqualTo(3); - - assertThat(transposed.y).isEqualTo(prob.y); - } - - /** - * compared input/output values with the C version (1.51) - * - * IN: res prob.l = 3 res prob.n = 4 0: (1,2) (3,1) (4,3) 1: (1,9) (2,7) - * (3,3) (4,3) 2: (2,1) - * - * TRANSPOSED: - * - * res prob.l = 3 * res prob.n = 4 0: (1,2) (2,9) 1: (2,7) (3,1) 2: (1,1) - * (2,3) 3: (1,3) (2,3) - * - */ - @Test - public void testTranspose3() throws Exception { - - final Problem prob = new Problem(); - prob.l = 3; - prob.n = 4; - prob.y = new double[3]; - prob.x = new double[3][4]; - - prob.x[0][0] = 2; - prob.x[0][2] = 1; - prob.x[0][3] = 3; - prob.x[1][0] = 9; - prob.x[1][1] = 7; - prob.x[1][2] = 3; - prob.x[1][3] = 3; - - prob.x[2][1] = 1; - - final Problem transposed = Linear.transpose(prob); - assertThat(transposed.x).hasSize(4); - assertThat(transposed.x[0]).hasSize(3); - assertThat(transposed.x[1]).hasSize(3); - assertThat(transposed.x[2]).hasSize(3); - assertThat(transposed.x[3]).hasSize(3); - - assertThat(transposed.x[0][0]).isEqualTo(2); - assertThat(transposed.x[0][1]).isEqualTo(9); - - assertThat(transposed.x[1][1]).isEqualTo(7); - assertThat(transposed.x[1][2]).isEqualTo(1); - - assertThat(transposed.x[2][0]).isEqualTo(1); - assertThat(transposed.x[2][1]).isEqualTo(3); - - assertThat(transposed.x[3][0]).isEqualTo(3); - assertThat(transposed.x[3][1]).isEqualTo(3); - } -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java b/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java deleted file mode 100644 index c0c55d6..0000000 --- a/src/test/java/de/bwaldvogel/denseliblinear/ParameterTest.java +++ /dev/null @@ -1,127 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static 
org.fest.assertions.Assertions.assertThat; -import static org.junit.Assert.fail; - -import org.junit.Before; -import org.junit.Test; - -import de.bwaldvogel.denseliblinear.Parameter; -import de.bwaldvogel.denseliblinear.SolverType; - - -public class ParameterTest { - - private Parameter _param; - - @Before - public void setUp() { - _param = new Parameter(SolverType.L2R_L1LOSS_SVC_DUAL, 100, 1e-3); - } - - @Test - public void testSetWeights() { - - assertThat(_param.weight).isNull(); - assertThat(_param.getNumWeights()).isEqualTo(0); - - double[] weights = new double[] {0, 1, 2, 3, 4, 5}; - int[] weightLabels = new int[] {1, 1, 1, 1, 2, 3}; - _param.setWeights(weights, weightLabels); - - assertThat(_param.getNumWeights()).isEqualTo(6); - - // assert parameter uses a copy - weights[0]++; - assertThat(_param.getWeights()[0]).isEqualTo(0); - weightLabels[0]++; - assertThat(_param.getWeightLabels()[0]).isEqualTo(1); - - weights = new double[] {0, 1, 2, 3, 4, 5}; - weightLabels = new int[] {1}; - try { - _param.setWeights(weights, weightLabels); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("same").contains("length"); - } - } - - @Test - public void testGetWeights() { - double[] weights = new double[] {0, 1, 2, 3, 4, 5}; - int[] weightLabels = new int[] {1, 1, 1, 1, 2, 3}; - _param.setWeights(weights, weightLabels); - - assertThat(_param.getWeights()).isEqualTo(weights); - _param.getWeights()[0]++; // shouldn't change the parameter as we should get a copy - assertThat(_param.getWeights()).isEqualTo(weights); - - assertThat(_param.getWeightLabels()).isEqualTo(weightLabels); - _param.getWeightLabels()[0]++; // shouldn't change the parameter as we should get a copy - assertThat(_param.getWeightLabels()[0]).isEqualTo(1); - } - - @Test - public void testSetC() { - _param.setC(0.0001); - assertThat(_param.getC()).isEqualTo(0.0001); - _param.setC(1); - _param.setC(100); - 
assertThat(_param.getC()).isEqualTo(100); - _param.setC(Double.MAX_VALUE); - - try { - _param.setC(-1); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); - } - - try { - _param.setC(0); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); - } - } - - @Test - public void testSetEps() { - _param.setEps(0.0001); - assertThat(_param.getEps()).isEqualTo(0.0001); - _param.setEps(1); - _param.setEps(100); - assertThat(_param.getEps()).isEqualTo(100); - _param.setEps(Double.MAX_VALUE); - - try { - _param.setEps(-1); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); - } - - try { - _param.setEps(0); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("must").contains("not").contains("<= 0"); - } - } - - @Test - public void testSetSolverType() { - for (SolverType type : SolverType.values()) { - _param.setSolverType(type); - assertThat(_param.getSolverType()).isEqualTo(type); - } - try { - _param.setSolverType(null); - fail("IllegalArgumentException expected"); - } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("must").contains("not").contains("null"); - } - } - -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java b/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java deleted file mode 100644 index 7c411fe..0000000 --- a/src/test/java/de/bwaldvogel/denseliblinear/PredictTest.java +++ /dev/null @@ -1,57 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static org.fest.assertions.Assertions.assertThat; -import static org.mockito.Mockito.mock; - -import java.io.BufferedReader; -import java.io.PrintStream; 
-import java.io.StringReader; -import java.io.StringWriter; -import java.io.Writer; - -import org.junit.Before; -import org.junit.Test; - -import de.bwaldvogel.denseliblinear.Model; -import de.bwaldvogel.denseliblinear.Predict; - - -public class PredictTest { - - private Model testModel = LinearTest.createRandomModel(); - private StringBuilder sb = new StringBuilder(); - private Writer writer = new StringWriter(); - - @Before - public void setUp() { - System.setOut(mock(PrintStream.class)); // dev/null - assertThat(testModel.getNrClass()).isGreaterThanOrEqualTo(2); - assertThat(testModel.getNrFeature()).isGreaterThanOrEqualTo(10); - } - - private void testWithLines(StringBuilder sb) throws Exception { - BufferedReader reader = new BufferedReader(new StringReader(sb.toString())); - - Predict.doPredict(reader, writer, testModel); - } - - @Test(expected = RuntimeException.class) - public void testDoPredictCorruptLine() throws Exception { - sb.append(testModel.label[0]).append(" abc").append("\n"); - testWithLines(sb); - } - - @Test(expected = RuntimeException.class) - public void testDoPredictCorruptLine2() throws Exception { - sb.append(testModel.label[0]).append(" 1:").append("\n"); - testWithLines(sb); - } - - @Test - public void testDoPredict() throws Exception { - sb.append(testModel.label[0]).append(" 1:0.32393").append("\n"); - sb.append(testModel.label[1]).append(" 2:-71.555 9:88223").append("\n"); - testWithLines(sb); - assertThat(writer.toString()).isNotEmpty(); - } -} diff --git a/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java b/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java deleted file mode 100644 index 75558c9..0000000 --- a/src/test/java/de/bwaldvogel/denseliblinear/TrainTest.java +++ /dev/null @@ -1,210 +0,0 @@ -package de.bwaldvogel.denseliblinear; - -import static org.fest.assertions.Assertions.assertThat; - -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileWriter; -import java.util.ArrayList; -import 
java.util.Collection; - -import org.junit.Test; - -public class TrainTest { - - @Test - public void testParseCommandLine() { - final Train train = new Train(); - - for (final SolverType solver : SolverType.values()) { - train.parse_command_line(new String[] { "-B", "5.3", "-s", "" + solver.getId(), "-p", "0.01", "model-filename" }); - final Parameter param = train.getParameter(); - assertThat(param.solverType).isEqualTo(solver); - // check default eps - if (solver.getId() == 0 || solver.getId() == 2 // - || solver.getId() == 5 || solver.getId() == 6) - { - assertThat(param.eps).isEqualTo(0.01); - } else if (solver.getId() == 7) { - assertThat(param.eps).isEqualTo(0.1); - } else if (solver.getId() == 11) { - assertThat(param.eps).isEqualTo(0.001); - } else { - assertThat(param.eps).isEqualTo(0.1); - } - // check if bias is set - assertThat(train.getBias()).isEqualTo(5.3); - assertThat(param.p).isEqualTo(0.01); - } - } - - @Test - // https://github.com/bwaldvogel/liblinear-java/issues/4 - public void - testParseWeights() throws Exception - { - final Train train = new Train(); - train.parse_command_line(new String[] { "-v", "10", "-c", "10", "-w1", "1.234", "model-filename" }); - Parameter parameter = train.getParameter(); - assertThat(parameter.weightLabel).isEqualTo(new int[] { 1 }); - assertThat(parameter.weight).isEqualTo(new double[] { 1.234 }); - - train.parse_command_line(new String[] { "-w1", "1.234", "-w2", "0.12", "-w3", "7", "model-filename" }); - parameter = train.getParameter(); - assertThat(parameter.weightLabel).isEqualTo(new int[] { 1, 2, 3 }); - assertThat(parameter.weight).isEqualTo(new double[] { 1.234, 0.12, 7 }); - } - - @Test - public void testReadProblem() throws Exception { - - final File file = File.createTempFile("svm", "test"); - file.deleteOnExit(); - - final Collection lines = new ArrayList(); - lines.add("1 1:1 3:1 4:1 6:1"); - lines.add("2 2:1 3:1 5:1 7:1"); - lines.add("1 3:1 5:1"); - lines.add("1 1:1 4:1 7:1"); - lines.add("2 4:1 5:1 
7:1"); - final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); - try { - for (final String line : lines) - writer.append(line).append("\n"); - } finally { - writer.close(); - } - - final Train train = new Train(); - train.readProblem(file.getAbsolutePath()); - - final Problem prob = train.getProblem(); - assertThat(prob.bias).isEqualTo(1); - assertThat(prob.y).hasSize(lines.size()); - assertThat(prob.y).isEqualTo(new double[] { 1, 2, 1, 1, 2 }); - assertThat(prob.n).isEqualTo(8); - assertThat(prob.l).isEqualTo(prob.y.length); - assertThat(prob.x).hasSize(prob.y.length); - - for (final double[] nodes : prob.x) { - - assertThat(nodes.length).isLessThanOrEqualTo(prob.n); - for (int ind = 0; ind < prob.n; ind++) { - // bias term - if (prob.bias >= 0 && ind == prob.n - 1) { - // assertThat(ind).isEqualTo(prob.n); - assertThat(nodes[ind]).isEqualTo(prob.bias); - } else { - assertThat(ind).isLessThan(prob.n); - } - } - } - } - - /** - * unit-test for Issue #1 - * (http://github.com/bwaldvogel/liblinear-java/issues#issue/1) - */ - @Test - public void testReadProblemEmptyLine() throws Exception { - - final File file = File.createTempFile("svm", "test"); - file.deleteOnExit(); - - final Collection lines = new ArrayList(); - lines.add("1 1:1 3:1 4:1 6:1"); - lines.add("2 "); - final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); - try { - for (final String line : lines) - writer.append(line).append("\n"); - } finally { - writer.close(); - } - - final Problem prob = Train.readProblem(file, -1.0); - assertThat(prob.bias).isEqualTo(-1); - assertThat(prob.y).hasSize(lines.size()); - assertThat(prob.y).isEqualTo(new double[] { 1, 2 }); - assertThat(prob.n).isEqualTo(6); - assertThat(prob.l).isEqualTo(prob.y.length); - assertThat(prob.x).hasSize(prob.y.length); - - assertThat(prob.x[0]).hasSize(6); - assertThat(prob.x[1]).hasSize(6); - } - - @Test(expected = InvalidInputDataException.class) - public void testReadUnsortedProblem() throws 
Exception { - final File file = File.createTempFile("svm", "test"); - file.deleteOnExit(); - - final Collection lines = new ArrayList(); - lines.add("1 1:1 3:1 4:1 6:1"); - lines.add("2 2:1 3:1 5:1 7:1"); - lines.add("1 3:1 5:1 4:1"); // here's the mistake: not correctly - // sorted - - final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); - try { - for (final String line : lines) - writer.append(line).append("\n"); - } finally { - writer.close(); - } - - final Train train = new Train(); - train.readProblem(file.getAbsolutePath()); - } - - @Test(expected = InvalidInputDataException.class) - public void testReadProblemWithInvalidIndex() throws Exception { - final File file = File.createTempFile("svm", "test"); - file.deleteOnExit(); - - final Collection lines = new ArrayList(); - lines.add("1 1:1 3:1 4:1 6:1"); - lines.add("2 2:1 3:1 5:1 -4:1"); - - final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); - try { - for (final String line : lines) - writer.append(line).append("\n"); - } finally { - writer.close(); - } - - final Train train = new Train(); - try { - train.readProblem(file.getAbsolutePath()); - } catch (final InvalidInputDataException e) { - throw e; - } - } - - @Test(expected = InvalidInputDataException.class) - public void testReadWrongProblem() throws Exception { - final File file = File.createTempFile("svm", "test"); - file.deleteOnExit(); - - final Collection lines = new ArrayList(); - lines.add("1 1:1 3:1 4:1 6:1"); - lines.add("2 2:1 3:1 5:1 7:1"); - lines.add("1 3:1 5:a"); // here's the mistake: incomplete line - - final BufferedWriter writer = new BufferedWriter(new FileWriter(file)); - try { - for (final String line : lines) - writer.append(line).append("\n"); - } finally { - writer.close(); - } - - final Train train = new Train(); - try { - train.readProblem(file.getAbsolutePath()); - } catch (final InvalidInputDataException e) { - throw e; - } - } -} From e7bb94c79c70f0eaf3530dafd2de8286c2da8c37 Mon Sep 17 
00:00:00 2001 From: Jonathon Hare Date: Wed, 17 Jul 2013 16:37:29 +0100 Subject: [PATCH 4/4] documentation --- README.md | 4 +--- pom.xml | 17 +++++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 02d2ac3..a271bef 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,4 @@ -[![Build Status](https://travis-ci.org/bwaldvogel/liblinear-java.png?branch=master)](https://travis-ci.org/bwaldvogel/liblinear-java) - -This is the Java version of LIBLINEAR. +This is the Java version of LIBLINEAR with added support for dense features. The project site of the original `C++` version is located at http://www.csie.ntu.edu.tw/~cjlin/liblinear/ diff --git a/pom.xml b/pom.xml index d3b28e3..6c97ecf 100644 --- a/pom.xml +++ b/pom.xml @@ -1,11 +1,11 @@ 4.0.0 de.bwaldvogel - liblinear + liblinear-dense jar liblinear 1.92 - Java port of Liblinear + Java port of Liblinear, with added support for dense features http://www.bwaldvogel.de/liblinear-java/ @@ -23,9 +23,9 @@ - scm:git:git@github.com:bwaldvogel/liblinear-java.git - scm:git:git@github.com:bwaldvogel/liblinear-java.git - git@github.com:bwaldvogel/liblinear-java.git + scm:git:git@github.com:jonhare/liblinear-java.git + scm:git:git@github.com:jonhare/liblinear-java.git + git@github.com:jonhare/liblinear-java.git @@ -34,10 +34,15 @@ Benedikt Waldvogel mail at bwaldvogel.de + + jonhare + Jonathon Hare + jsh2@ecs.soton.ac.uk + - liblinear-${project.version} + liblinear-dense-${project.version} maven-compiler-plugin