Skip to content

Commit cb261ff

Browse files
author
yinchuandong
committed
完成svm_predict
1 parent 9de4d83 commit cb261ff

3 files changed

Lines changed: 785 additions & 0 deletions

File tree

src/train/Predict.java

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
package train;
2+
3+
import java.awt.image.BufferedImage;
4+
import java.io.BufferedReader;
5+
import java.io.File;
6+
import java.io.FileNotFoundException;
7+
import java.io.FileReader;
8+
import java.io.FilenameFilter;
9+
import java.io.IOException;
10+
import java.io.PrintWriter;
11+
import java.util.ArrayList;
12+
import java.util.HashMap;
13+
import java.util.Iterator;
14+
15+
import javax.imageio.ImageIO;
16+
17+
import svmHelper.svm_predict;
18+
import svmHelper.svm_train;
19+
20+
21+
public class Predict {
22+
23+
/**
24+
* 类标号map,如:a=>1 b=>2
25+
*/
26+
private HashMap<String, Integer> labelMap = null;
27+
28+
private HashMap<String, Integer[][]> imageMap = null;
29+
30+
public Predict(){
31+
init();
32+
}
33+
34+
private void init(){
35+
labelMap = new HashMap<String, Integer>();
36+
imageMap = new HashMap<String, Integer[][]>();
37+
38+
loadImageLabel();
39+
loadImage();
40+
}
41+
42+
/**
43+
* 加载类标号
44+
*/
45+
private void loadImageLabel(){
46+
BufferedReader reader = null;
47+
try {
48+
reader = new BufferedReader(new FileReader(new File("svm/label.txt")));
49+
String buff = null;
50+
while((buff = reader.readLine()) != null){
51+
String[] arr = buff.split(" ");
52+
labelMap.put(arr[0], Integer.parseInt(arr[1]));
53+
}
54+
55+
System.out.println("load image label finish!");
56+
57+
} catch (FileNotFoundException e) {
58+
e.printStackTrace();
59+
} catch (IOException e) {
60+
e.printStackTrace();
61+
} finally{
62+
if (reader != null) {
63+
try {
64+
reader.close();
65+
} catch (IOException e) {
66+
e.printStackTrace();
67+
}
68+
}
69+
}
70+
}
71+
72+
73+
private void loadImage(){
74+
File dir = new File("4_scale/");
75+
//只列出jpg
76+
File[] files = dir.listFiles(new FilenameFilter() {
77+
78+
public boolean isJpg(String file){
79+
if (file.toLowerCase().endsWith(".jpg")){
80+
return true;
81+
}else{
82+
return false;
83+
}
84+
}
85+
86+
@Override
87+
public boolean accept(File dir, String name) {
88+
// TODO Auto-generated method stub
89+
return isJpg(name);
90+
}
91+
});
92+
93+
for (File file : files) {
94+
try {
95+
transferToMap(file);
96+
} catch (Exception e) {
97+
e.printStackTrace();
98+
}
99+
}
100+
101+
System.out.println("load mage end");
102+
103+
}
104+
105+
/**
106+
* 获得类标号
107+
* @param className
108+
* @return
109+
*/
110+
private int getClassLabel(String className){
111+
if(labelMap.containsKey(className)){
112+
return labelMap.get(className);
113+
}else{
114+
return -1;
115+
}
116+
}
117+
118+
119+
/**
120+
* 将image 转换到 map中
121+
* @param file
122+
* @throws IOException
123+
*/
124+
private void transferToMap(File file) throws IOException{
125+
BufferedImage image = ImageIO.read(file);
126+
int width = image.getWidth();
127+
int height = image.getHeight();
128+
Integer[][] imgArr = new Integer[height][width];
129+
130+
for (int y = 0; y < height; y++) {
131+
for (int x = 0; x < width; x++) {
132+
//黑色点标记为1
133+
int value = ImageUtil.isBlack(image.getRGB(x, y)) ? 1 : 0;
134+
imgArr[y][x] = value;
135+
}
136+
}
137+
138+
this.imageMap.put(file.getName(), imgArr);
139+
}
140+
141+
/**
142+
* 转成svm 测试集的格式
143+
*/
144+
public void svmFormat(){
145+
146+
PrintWriter writer = null;
147+
try {
148+
writer = new PrintWriter(new File("svm/svm.test"));
149+
Iterator<String> iterator = this.imageMap.keySet().iterator();
150+
151+
while (iterator.hasNext()) {
152+
String fileName = (String) iterator.next();
153+
String className = ImageUtil.getImgClass(fileName);
154+
int classLabel = getClassLabel(className);
155+
156+
String tmpLine = classLabel + " ";
157+
Integer[][] imageArr = this.imageMap.get(fileName);
158+
159+
int index = 1;
160+
for (int i = 0; i < imageArr.length; i++) {
161+
for (int j = 0; j < imageArr[i].length; j++) {
162+
tmpLine += index + ":" + imageArr[i][j] + " ";
163+
index ++;
164+
}
165+
}
166+
writer.write(tmpLine + "\r\n");
167+
writer.flush();
168+
// System.out.println(tmpLine);
169+
}
170+
} catch (Exception e) {
171+
e.printStackTrace();
172+
} finally{
173+
if (writer != null) {
174+
writer.close();
175+
}
176+
}
177+
178+
}
179+
180+
public static void run() throws IOException{
181+
//predict参数
182+
String[] parg = {"svm/svm.test","svm/svm.model","svm/result.txt"};
183+
184+
System.out.println("训练开始");
185+
svm_predict.main(parg);
186+
System.out.println("训练结束");
187+
}
188+
189+
public static void main(String[] args){
190+
// Predict model = new Predict();
191+
// model.svmFormat();
192+
try {
193+
run();
194+
} catch (IOException e) {
195+
e.printStackTrace();
196+
}
197+
}
198+
199+
}

0 commit comments

Comments
 (0)