Skip to content

Commit 80b4af0

Browse files
author
linyiqun
committed
K-最近邻算法的实现
K-最近邻算法的实现
1 parent 8fb5639 commit 80b4af0

5 files changed

Lines changed: 300 additions & 0 deletions

File tree

DataMining_KNN/Client.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package DataMining_KNN;
2+
3+
import java.util.ArrayList;
4+
import java.util.Collections;
5+
import java.util.Comparator;
6+
import java.util.List;
7+
8+
9+
/**
10+
* k×î½üÁÚËã·¨³¡¾°ÀàÐÍ
11+
* @author lyq
12+
*
13+
*/
14+
public class Client {
15+
public static void main(String[] args){
16+
String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
17+
String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";
18+
19+
KNNTool tool = new KNNTool(trainDataPath, testDataPath);
20+
tool.knnCompute(3);
21+
22+
}
23+
24+
25+
26+
}

DataMining_KNN/KNNTool.java

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
package DataMining_KNN;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.io.IOException;
7+
import java.util.ArrayList;
8+
import java.util.Arrays;
9+
import java.util.Collection;
10+
import java.util.Collections;
11+
import java.util.Comparator;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
import org.apache.activemq.filter.ComparisonExpression;
16+
17+
/**
18+
* k最近邻算法工具类
19+
*
20+
* @author lyq
21+
*
22+
*/
23+
public class KNNTool {
24+
// 为4个类别设置权重,默认权重比一致
25+
public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
26+
// 测试数据
27+
private String testDataPath;
28+
// 训练集数据地址
29+
private String trainDataPath;
30+
// 分类的不同类型
31+
private ArrayList<String> classTypes;
32+
// 结果数据
33+
private ArrayList<Sample> resultSamples;
34+
// 训练集数据列表容器
35+
private ArrayList<Sample> trainSamples;
36+
// 训练集数据
37+
private String[][] trainData;
38+
// 测试集数据
39+
private String[][] testData;
40+
41+
public KNNTool(String trainDataPath, String testDataPath) {
42+
this.trainDataPath = trainDataPath;
43+
this.testDataPath = testDataPath;
44+
readDataFormFile();
45+
}
46+
47+
/**
48+
* 从文件中阅读测试数和训练数据集
49+
*/
50+
private void readDataFormFile() {
51+
ArrayList<String[]> tempArray;
52+
53+
tempArray = fileDataToArray(trainDataPath);
54+
trainData = new String[tempArray.size()][];
55+
tempArray.toArray(trainData);
56+
57+
classTypes = new ArrayList<>();
58+
for (String[] s : tempArray) {
59+
if (!classTypes.contains(s[0])) {
60+
// 添加类型
61+
classTypes.add(s[0]);
62+
}
63+
}
64+
65+
tempArray = fileDataToArray(testDataPath);
66+
testData = new String[tempArray.size()][];
67+
tempArray.toArray(testData);
68+
}
69+
70+
/**
71+
* 将文件转为列表数据输出
72+
*
73+
* @param filePath
74+
* 数据文件的内容
75+
*/
76+
private ArrayList<String[]> fileDataToArray(String filePath) {
77+
File file = new File(filePath);
78+
ArrayList<String[]> dataArray = new ArrayList<String[]>();
79+
80+
try {
81+
BufferedReader in = new BufferedReader(new FileReader(file));
82+
String str;
83+
String[] tempArray;
84+
while ((str = in.readLine()) != null) {
85+
tempArray = str.split(" ");
86+
dataArray.add(tempArray);
87+
}
88+
in.close();
89+
} catch (IOException e) {
90+
e.getStackTrace();
91+
}
92+
93+
return dataArray;
94+
}
95+
96+
/**
97+
* 计算样本特征向量的欧几里得距离
98+
*
99+
* @param f1
100+
* 待比较样本1
101+
* @param f2
102+
* 待比较样本2
103+
* @return
104+
*/
105+
private int computeEuclideanDistance(Sample s1, Sample s2) {
106+
String[] f1 = s1.getFeatures();
107+
String[] f2 = s2.getFeatures();
108+
// 欧几里得距离
109+
int distance = 0;
110+
111+
for (int i = 0; i < f1.length; i++) {
112+
int subF1 = Integer.parseInt(f1[i]);
113+
int subF2 = Integer.parseInt(f2[i]);
114+
115+
distance += (subF1 - subF2) * (subF1 - subF2);
116+
}
117+
118+
return distance;
119+
}
120+
121+
/**
122+
* 计算K最近邻
123+
* @param k
124+
* 在多少的k范围内
125+
*/
126+
public void knnCompute(int k) {
127+
String className = "";
128+
String[] tempF = null;
129+
Sample temp;
130+
resultSamples = new ArrayList<>();
131+
trainSamples = new ArrayList<>();
132+
// 分类类别计数
133+
HashMap<String, Integer> classCount;
134+
// 类别权重比
135+
HashMap<String, Integer> classWeight = new HashMap<>();
136+
// 首先讲测试数据转化到结果数据中
137+
for (String[] s : testData) {
138+
temp = new Sample(s);
139+
resultSamples.add(temp);
140+
}
141+
142+
for (String[] s : trainData) {
143+
className = s[0];
144+
tempF = new String[s.length - 1];
145+
System.arraycopy(s, 1, tempF, 0, s.length - 1);
146+
temp = new Sample(className, tempF);
147+
trainSamples.add(temp);
148+
}
149+
150+
// 离样本最近排序的的训练集数据
151+
ArrayList<Sample> kNNSample = new ArrayList<>();
152+
// 计算训练数据集中离样本数据最近的K个训练集数据
153+
for (Sample s : resultSamples) {
154+
classCount = new HashMap<>();
155+
int index = 0;
156+
for (String type : classTypes) {
157+
// 开始时计数为0
158+
classCount.put(type, 0);
159+
classWeight.put(type, classWeightArray[index++]);
160+
}
161+
for (Sample tS : trainSamples) {
162+
int dis = computeEuclideanDistance(s, tS);
163+
tS.setDistance(dis);
164+
}
165+
166+
Collections.sort(trainSamples);
167+
kNNSample.clear();
168+
// 挑选出前k个数据作为分类标准
169+
for (int i = 0; i < trainSamples.size(); i++) {
170+
if (i < k) {
171+
kNNSample.add(trainSamples.get(i));
172+
} else {
173+
break;
174+
}
175+
}
176+
// 判定K个训练数据的多数的分类标准
177+
for (Sample s1 : kNNSample) {
178+
int num = classCount.get(s1.getClassName());
179+
// 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小
180+
num += classWeight.get(s1.getClassName());
181+
classCount.put(s1.getClassName(), num);
182+
}
183+
184+
int maxCount = 0;
185+
// 筛选出k个训练集数据中最多的一个分类
186+
for (Map.Entry entry : classCount.entrySet()) {
187+
if ((Integer) entry.getValue() > maxCount) {
188+
maxCount = (Integer) entry.getValue();
189+
s.setClassName((String) entry.getKey());
190+
}
191+
}
192+
193+
System.out.print("测试数据特征:");
194+
for (String s1 : s.getFeatures()) {
195+
System.out.print(s1 + " ");
196+
}
197+
System.out.println("分类:" + s.getClassName());
198+
}
199+
}
200+
}

DataMining_KNN/Sample.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package DataMining_KNN;
2+
3+
/**
4+
* 样本数据类
5+
*
6+
* @author lyq
7+
*
8+
*/
9+
public class Sample implements Comparable<Sample>{
10+
// 样本数据的分类名称
11+
private String className;
12+
// 样本数据的特征向量
13+
private String[] features;
14+
//测试样本之间的间距值,以此做排序
15+
private Integer distance;
16+
17+
public Sample(String[] features){
18+
this.features = features;
19+
}
20+
21+
public Sample(String className, String[] features){
22+
this.className = className;
23+
this.features = features;
24+
}
25+
26+
public String getClassName() {
27+
return className;
28+
}
29+
30+
public void setClassName(String className) {
31+
this.className = className;
32+
}
33+
34+
public String[] getFeatures() {
35+
return features;
36+
}
37+
38+
public void setFeatures(String[] features) {
39+
this.features = features;
40+
}
41+
42+
public Integer getDistance() {
43+
return distance;
44+
}
45+
46+
public void setDistance(int distance) {
47+
this.distance = distance;
48+
}
49+
50+
@Override
51+
public int compareTo(Sample o) {
52+
// TODO Auto-generated method stub
53+
return this.getDistance().compareTo(o.getDistance());
54+
}
55+
56+
}
57+

DataMining_KNN/testInput.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
1 2 3 2 4
2+
2 3 4 2 1
3+
8 7 2 3 5
4+
-3 -2 2 4 0
5+
-4 -4 -4 -4 -4
6+
1 2 3 4 4
7+
4 4 3 2 1
8+
3 3 3 2 4
9+
0 0 1 1 -2

DataMining_KNN/trainInput.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
a 1 2 3 4 5
2+
b 5 4 3 2 1
3+
c 3 3 3 3 3
4+
d -3 -3 -3 -3 -3
5+
a 1 2 3 4 4
6+
b 4 4 3 2 1
7+
c 3 3 3 2 4
8+
d 0 0 1 1 -2

0 commit comments

Comments
 (0)