Skip to content

Commit 0ffd847

Browse files
Merge pull request sPredictorX1708#552 from Ishita-0112/master
Create SVM.java
2 parents 87c1ca1 + 1f7f0b6 commit 0ffd847

1 file changed

Lines changed: 150 additions & 0 deletions

File tree

Machine Learning/SVM.java

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* SVM Classifier */
2+
3+
package smile.classification;
4+
5+
import smile.base.svm.KernelMachine;
6+
import smile.base.svm.LinearKernelMachine;
7+
import smile.base.svm.LASVM;
8+
import smile.util.SparseArray;
9+
import smile.math.kernel.BinarySparseLinearKernel;
10+
import smile.math.kernel.LinearKernel;
11+
import smile.math.kernel.MercerKernel;
12+
import smile.math.kernel.SparseLinearKernel;
13+
14+
/**
15+
* Support vector machines for classification. The basic support vector machine
16+
* is a binary linear classifier which chooses the hyperplane that represents
17+
* the largest separation, or margin, between the two classes. If such a
18+
* hyperplane exists, it is known as the maximum-margin hyperplane and the
19+
* linear classifier it defines is known as a maximum margin classifier.
20+
* <p>
21+
* If there exists no hyperplane that can perfectly split the positive and
22+
* negative instances, the soft margin method will choose a hyperplane
23+
* that splits the instances as cleanly as possible, while still maximizing
24+
* the distance to the nearest cleanly split instances.
25+
* <p>
26+
* The nonlinear SVMs are created by applying the kernel trick to
27+
* maximum-margin hyperplanes. The resulting algorithm is formally similar,
28+
* except that every dot product is replaced by a nonlinear kernel function.
29+
* This allows the algorithm to fit the maximum-margin hyperplane in a
30+
* transformed feature space. The transformation may be nonlinear and
31+
* the transformed space be high dimensional. For example, the feature space
32+
* corresponding Gaussian kernel is a Hilbert space of infinite dimension.
33+
* Thus though the classifier is a hyperplane in the high-dimensional feature
34+
* space, it may be nonlinear in the original input space. Maximum margin
35+
* classifiers are well regularized, so the infinite dimension does not spoil
36+
* the results.
37+
* <p>
38+
* The effectiveness of SVM depends on the selection of kernel, the kernel's
39+
* parameters, and soft margin parameter C. Given a kernel, best combination
40+
* of C and kernel's parameters is often selected by a grid-search with
41+
* cross validation.
42+
* <p>
43+
* The dominant approach for creating multi-class SVMs is to reduce the
44+
* single multi-class problem into multiple binary classification problems.
45+
* Common methods for such reduction is to build binary classifiers which
46+
* distinguish between (i) one of the labels to the rest (one-versus-all)
47+
* or (ii) between every pair of classes (one-versus-one). Classification
48+
* of new instances for one-versus-all case is done by a winner-takes-all
49+
* strategy, in which the classifier with the highest output function assigns
50+
* the class. For the one-versus-one approach, classification
51+
* is done by a max-wins voting strategy, in which every classifier assigns
52+
* the instance to one of the two classes, then the vote for the assigned
53+
* class is increased by one vote, and finally the class with most votes
54+
* determines the instance classification.
55+
56+
public class SVM<T> extends KernelMachine<T> implements Classifier<T> {
57+
/**
58+
* Constructor.
59+
* @param kernel Kernel function.
60+
* @param instances The instances in the kernel machine, e.g. support vectors.
61+
* @param weight The weights of instances.
62+
* @param b The intercept;
63+
*/
64+
public SVM(MercerKernel<T> kernel, T[] instances, double[] weight, double b) {
65+
super(kernel, instances, weight, b);
66+
}
67+
68+
@Override
69+
public int predict(T x) {
70+
return f(x) > 0 ? +1 : -1;
71+
}
72+
73+
/**
74+
* Fits a binary-class linear SVM.
75+
* @param x training samples.
76+
* @param y training labels.
77+
* @param C the soft margin penalty parameter.
78+
* @param tol the tolerance of convergence test.
79+
*/
80+
public static Classifier<double[]> fit(double[][] x, int[] y, double C, double tol) {
81+
LASVM<double[]> lasvm = new LASVM<>(new LinearKernel(), C, tol);
82+
KernelMachine<double[]> svm = lasvm.fit(x, y);
83+
84+
return new Classifier<double[]>() {
85+
LinearKernelMachine model = LinearKernelMachine.of(svm);
86+
87+
@Override
88+
public int predict(double[] x) {
89+
return model.f(x) > 0 ? +1 : -1;
90+
}
91+
};
92+
}
93+
94+
/**
95+
* Fits a binary-class linear SVM of binary sparse data.
96+
* @param x training samples.
97+
* @param y training labels.
98+
* @param p the dimension of input vector.
99+
* @param C the soft margin penalty parameter.
100+
* @param tol the tolerance of convergence test.
101+
*/
102+
public static Classifier<int[]> fit(int[][] x, int[] y, int p, double C, double tol) {
103+
LASVM<int[]> lasvm = new LASVM<>(new BinarySparseLinearKernel(), C, tol);
104+
KernelMachine<int[]> svm = lasvm.fit(x, y);
105+
106+
return new Classifier<int[]>() {
107+
LinearKernelMachine model = LinearKernelMachine.binary(p, svm);
108+
109+
@Override
110+
public int predict(int[] x) {
111+
return model.f(x) > 0 ? +1 : -1;
112+
}
113+
};
114+
}
115+
116+
/**
117+
* Fits a binary-class linear SVM.
118+
* @param x training samples.
119+
* @param y training labels.
120+
* @param p the dimension of input vector.
121+
* @param C the soft margin penalty parameter.
122+
* @param tol the tolerance of convergence test.
123+
*/
124+
public static Classifier<SparseArray> fit(SparseArray[] x, int[] y, int p, double C, double tol) {
125+
LASVM<SparseArray> lasvm = new LASVM<>(new SparseLinearKernel(), C, tol);
126+
KernelMachine<SparseArray> svm = lasvm.fit(x, y);
127+
128+
return new Classifier<SparseArray>() {
129+
LinearKernelMachine model = LinearKernelMachine.sparse(p, svm);
130+
131+
@Override
132+
public int predict(SparseArray x) {
133+
return model.f(x) > 0 ? +1 : -1;
134+
}
135+
};
136+
}
137+
138+
/**
139+
* Fits a binary-class SVM.
140+
* @param x training samples.
141+
* @param y training labels.
142+
* @param kernel the kernel function.
143+
* @param C the soft margin penalty parameter.
144+
* @param tol the tolerance of convergence test.
145+
*/
146+
public static <T> SVM<T> fit(T[] x, int[] y, MercerKernel<T> kernel, double C, double tol) {
147+
LASVM<T> lasvm = new LASVM<>(kernel, C, tol);
148+
return lasvm.fit(x, y).toSVM();
149+
}
150+
}

0 commit comments

Comments
 (0)