|
| 1 | +/* SVM Classifier */ |
| 2 | + |
| 3 | +package smile.classification; |
| 4 | + |
| 5 | +import smile.base.svm.KernelMachine; |
| 6 | +import smile.base.svm.LinearKernelMachine; |
| 7 | +import smile.base.svm.LASVM; |
| 8 | +import smile.util.SparseArray; |
| 9 | +import smile.math.kernel.BinarySparseLinearKernel; |
| 10 | +import smile.math.kernel.LinearKernel; |
| 11 | +import smile.math.kernel.MercerKernel; |
| 12 | +import smile.math.kernel.SparseLinearKernel; |
| 13 | + |
| 14 | +/** |
| 15 | + * Support vector machines for classification. The basic support vector machine |
| 16 | + * is a binary linear classifier which chooses the hyperplane that represents |
| 17 | + * the largest separation, or margin, between the two classes. If such a |
| 18 | + * hyperplane exists, it is known as the maximum-margin hyperplane and the |
| 19 | + * linear classifier it defines is known as a maximum margin classifier. |
| 20 | + * <p> |
| 21 | + * If there exists no hyperplane that can perfectly split the positive and |
| 22 | + * negative instances, the soft margin method will choose a hyperplane |
| 23 | + * that splits the instances as cleanly as possible, while still maximizing |
| 24 | + * the distance to the nearest cleanly split instances. |
| 25 | + * <p> |
| 26 | + * The nonlinear SVMs are created by applying the kernel trick to |
| 27 | + * maximum-margin hyperplanes. The resulting algorithm is formally similar, |
| 28 | + * except that every dot product is replaced by a nonlinear kernel function. |
| 29 | + * This allows the algorithm to fit the maximum-margin hyperplane in a |
| 30 | + * transformed feature space. The transformation may be nonlinear and |
| 31 | + * the transformed space be high dimensional. For example, the feature space |
| 32 | + * corresponding Gaussian kernel is a Hilbert space of infinite dimension. |
| 33 | + * Thus though the classifier is a hyperplane in the high-dimensional feature |
| 34 | + * space, it may be nonlinear in the original input space. Maximum margin |
| 35 | + * classifiers are well regularized, so the infinite dimension does not spoil |
| 36 | + * the results. |
| 37 | + * <p> |
| 38 | + * The effectiveness of SVM depends on the selection of kernel, the kernel's |
| 39 | + * parameters, and soft margin parameter C. Given a kernel, best combination |
| 40 | + * of C and kernel's parameters is often selected by a grid-search with |
| 41 | + * cross validation. |
| 42 | + * <p> |
| 43 | + * The dominant approach for creating multi-class SVMs is to reduce the |
| 44 | + * single multi-class problem into multiple binary classification problems. |
| 45 | + * Common methods for such reduction is to build binary classifiers which |
| 46 | + * distinguish between (i) one of the labels to the rest (one-versus-all) |
| 47 | + * or (ii) between every pair of classes (one-versus-one). Classification |
| 48 | + * of new instances for one-versus-all case is done by a winner-takes-all |
| 49 | + * strategy, in which the classifier with the highest output function assigns |
| 50 | + * the class. For the one-versus-one approach, classification |
| 51 | + * is done by a max-wins voting strategy, in which every classifier assigns |
| 52 | + * the instance to one of the two classes, then the vote for the assigned |
| 53 | + * class is increased by one vote, and finally the class with most votes |
| 54 | + * determines the instance classification. |
| 55 | + |
| 56 | +public class SVM<T> extends KernelMachine<T> implements Classifier<T> { |
| 57 | + /** |
| 58 | + * Constructor. |
| 59 | + * @param kernel Kernel function. |
| 60 | + * @param instances The instances in the kernel machine, e.g. support vectors. |
| 61 | + * @param weight The weights of instances. |
| 62 | + * @param b The intercept; |
| 63 | + */ |
| 64 | + public SVM(MercerKernel<T> kernel, T[] instances, double[] weight, double b) { |
| 65 | + super(kernel, instances, weight, b); |
| 66 | + } |
| 67 | + |
| 68 | + @Override |
| 69 | + public int predict(T x) { |
| 70 | + return f(x) > 0 ? +1 : -1; |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * Fits a binary-class linear SVM. |
| 75 | + * @param x training samples. |
| 76 | + * @param y training labels. |
| 77 | + * @param C the soft margin penalty parameter. |
| 78 | + * @param tol the tolerance of convergence test. |
| 79 | + */ |
| 80 | + public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol) { |
| 81 | + LASVM<double[]> lasvm = new LASVM<>(new LinearKernel(), C, tol); |
| 82 | + KernelMachine<double[]> svm = lasvm.fit(x, y); |
| 83 | + |
| 84 | + return new Classifier<double[]>() { |
| 85 | + LinearKernelMachine model = LinearKernelMachine.of(svm); |
| 86 | + |
| 87 | + @Override |
| 88 | + public int predict(double[] x) { |
| 89 | + return model.f(x) > 0 ? +1 : -1; |
| 90 | + } |
| 91 | + }; |
| 92 | + } |
| 93 | + |
| 94 | + /** |
| 95 | + * Fits a binary-class linear SVM of binary sparse data. |
| 96 | + * @param x training samples. |
| 97 | + * @param y training labels. |
| 98 | + * @param p the dimension of input vector. |
| 99 | + * @param C the soft margin penalty parameter. |
| 100 | + * @param tol the tolerance of convergence test. |
| 101 | + */ |
| 102 | + public static Classifier<int[]> fit(int[][] x, int[] y, int p, double C, double tol) { |
| 103 | + LASVM<int[]> lasvm = new LASVM<>(new BinarySparseLinearKernel(), C, tol); |
| 104 | + KernelMachine<int[]> svm = lasvm.fit(x, y); |
| 105 | + |
| 106 | + return new Classifier<int[]>() { |
| 107 | + LinearKernelMachine model = LinearKernelMachine.binary(p, svm); |
| 108 | + |
| 109 | + @Override |
| 110 | + public int predict(int[] x) { |
| 111 | + return model.f(x) > 0 ? +1 : -1; |
| 112 | + } |
| 113 | + }; |
| 114 | + } |
| 115 | + |
| 116 | + /** |
| 117 | + * Fits a binary-class linear SVM. |
| 118 | + * @param x training samples. |
| 119 | + * @param y training labels. |
| 120 | + * @param p the dimension of input vector. |
| 121 | + * @param C the soft margin penalty parameter. |
| 122 | + * @param tol the tolerance of convergence test. |
| 123 | + */ |
| 124 | + public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, int p, double C, double tol) { |
| 125 | + LASVM<SparseArray> lasvm = new LASVM<>(new SparseLinearKernel(), C, tol); |
| 126 | + KernelMachine<SparseArray> svm = lasvm.fit(x, y); |
| 127 | + |
| 128 | + return new Classifier<SparseArray>() { |
| 129 | + LinearKernelMachine model = LinearKernelMachine.sparse(p, svm); |
| 130 | + |
| 131 | + @Override |
| 132 | + public int predict(SparseArray x) { |
| 133 | + return model.f(x) > 0 ? +1 : -1; |
| 134 | + } |
| 135 | + }; |
| 136 | + } |
| 137 | + |
| 138 | + /** |
| 139 | + * Fits a binary-class SVM. |
| 140 | + * @param x training samples. |
| 141 | + * @param y training labels. |
| 142 | + * @param kernel the kernel function. |
| 143 | + * @param C the soft margin penalty parameter. |
| 144 | + * @param tol the tolerance of convergence test. |
| 145 | + */ |
| 146 | + public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol) { |
| 147 | + LASVM<T> lasvm = new LASVM<>(kernel, C, tol); |
| 148 | + return lasvm.fit(x, y).toSVM(); |
| 149 | + } |
| 150 | +} |
0 commit comments