Skip to content

Commit d5d6855

Browse files
author
linyiqun
committed
SVM支持向量机算法实现例子,通过libsvm包的方式
SVM支持向量机算法实现例子,通过libsvm包的方式
1 parent bbbb32f commit d5d6855

11 files changed

Lines changed: 3005 additions & 0 deletions

DataMining_SVM/Client.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package DataMining_SVM;
2+
3+
/**
4+
* SVM支持向量机场景调用类
5+
* @author lyq
6+
*
7+
*/
8+
public class Client {
9+
public static void main(String[] args){
10+
//训练集数据文件路径
11+
String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
12+
//测试数据文件路径
13+
String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
14+
}
15+
16+
}

DataMining_SVM/SVM.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package DataMining_SVM;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
9+
import DataMining_SVM.libsvm.svm;
10+
import DataMining_SVM.libsvm.svm_model;
11+
import DataMining_SVM.libsvm.svm_node;
12+
import DataMining_SVM.libsvm.svm_parameter;
13+
import DataMining_SVM.libsvm.svm_problem;
14+
15+
public class SVM {
16+
public static void main(String[] args) {
17+
// 定义训练集点a{10.0, 10.0} 和 点b{-10.0, -10.0},对应lable为{1.0, -1.0}
18+
List<Double> label = new ArrayList<Double>();
19+
List<svm_node[]> nodeSet = new ArrayList<svm_node[]>();
20+
getData(nodeSet, label, "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt");
21+
22+
int dataRange = nodeSet.get(0).length;
23+
svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // 训练集的向量表
24+
for (int i = 0; i < datas.length; i++) {
25+
for (int j = 0; j < dataRange; j++) {
26+
datas[i][j] = nodeSet.get(i)[j];
27+
}
28+
}
29+
double[] lables = new double[label.size()]; // a,b 对应的lable
30+
for (int i = 0; i < lables.length; i++) {
31+
lables[i] = label.get(i);
32+
}
33+
34+
// 定义svm_problem对象
35+
svm_problem problem = new svm_problem();
36+
problem.l = nodeSet.size(); // 向量个数
37+
problem.x = datas; // 训练集向量表
38+
problem.y = lables; // 对应的lable数组
39+
40+
// 定义svm_parameter对象
41+
svm_parameter param = new svm_parameter();
42+
param.svm_type = svm_parameter.EPSILON_SVR;
43+
param.kernel_type = svm_parameter.LINEAR;
44+
param.cache_size = 100;
45+
param.eps = 0.00001;
46+
param.C = 1.9;
47+
// 训练SVM分类模型
48+
System.out.println(svm.svm_check_parameter(problem, param));
49+
// 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。
50+
svm_model model = svm.svm_train(problem, param);
51+
// svm.svm_train()训练出SVM分类模型
52+
53+
// 获取测试数据
54+
List<Double> testlabel = new ArrayList<Double>();
55+
List<svm_node[]> testnodeSet = new ArrayList<svm_node[]>();
56+
getData(testnodeSet, testlabel, "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt");
57+
58+
svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // 训练集的向量表
59+
for (int i = 0; i < testdatas.length; i++) {
60+
for (int j = 0; j < dataRange; j++) {
61+
testdatas[i][j] = testnodeSet.get(i)[j];
62+
}
63+
}
64+
double[] testlables = new double[testlabel.size()]; // a,b 对应的lable
65+
for (int i = 0; i < testlables.length; i++) {
66+
testlables[i] = testlabel.get(i);
67+
}
68+
69+
// 预测测试数据的lable
70+
double err = 0.0;
71+
for (int i = 0; i < testdatas.length; i++) {
72+
double truevalue = testlables[i];
73+
System.out.print(truevalue + " ");
74+
double predictValue = svm.svm_predict(model, testdatas[i]);
75+
System.out.println(predictValue);
76+
err += Math.abs(predictValue - truevalue);
77+
}
78+
System.out.println("err=" + err / datas.length);
79+
}
80+
81+
public static void getData(List<svm_node[]> nodeSet, List<Double> label,
82+
String filename) {
83+
try {
84+
85+
FileReader fr = new FileReader(new File(filename));
86+
BufferedReader br = new BufferedReader(fr);
87+
String line = null;
88+
while ((line = br.readLine()) != null) {
89+
String[] datas = line.split(",");
90+
svm_node[] vector = new svm_node[datas.length - 1];
91+
for (int i = 0; i < datas.length - 1; i++) {
92+
svm_node node = new svm_node();
93+
node.index = i + 1;
94+
node.value = Double.parseDouble(datas[i]);
95+
vector[i] = node;
96+
}
97+
nodeSet.add(vector);
98+
double lablevalue = Double.parseDouble(datas[datas.length - 1]);
99+
label.add(lablevalue);
100+
}
101+
} catch (Exception e) {
102+
e.printStackTrace();
103+
}
104+
105+
}
106+
}
107+

DataMining_SVM/SVMTool.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package DataMining_SVM;
2+
3+
/**
4+
* SVMÖ§³ÖÏòÁ¿»ú¹¤¾ßÀà
5+
* @author lyq
6+
*
7+
*/
8+
public class SVMTool {
9+
10+
}

0 commit comments

Comments
 (0)