Skip to content

Commit 655da23

Browse files
authored
Add files via upload
1 parent 123f2ac commit 655da23

1 file changed

Lines changed: 96 additions & 0 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
import cv2
4+
import numpy as np
5+
from sklearn.cross_validation import train_test_split
6+
from sklearn.metrics import confusion_matrix, classification_report
7+
8+
#----------------------------------------------------------------------------------
9+
# 第一步 切分训练集和测试集
10+
#----------------------------------------------------------------------------------
11+
12+
X = [] #定义图像名称
13+
Y = [] #定义图像分类类标
14+
Z = [] #定义图像像素
15+
16+
for i in range(0, 10):
17+
#遍历文件夹,读取图片
18+
for f in os.listdir("photo/%s" % i):
19+
#获取图像名称
20+
X.append("photo//" +str(i) + "//" + str(f))
21+
#获取图像类标即为文件夹名称
22+
Y.append(i)
23+
24+
X = np.array(X)
25+
Y = np.array(Y)
26+
27+
#随机率为100% 选取其中的30%作为测试集
28+
X_train, X_test, y_train, y_test = train_test_split(X, Y,
29+
test_size=0.3, random_state=1)
30+
31+
print len(X_train), len(X_test), len(y_train), len(y_test)
32+
33+
#----------------------------------------------------------------------------------
34+
# 第二步 图像读取及转换为像素直方图
35+
#----------------------------------------------------------------------------------
36+
37+
#训练集
38+
XX_train = []
39+
for i in X_train:
40+
#读取图像
41+
#print i
42+
image = cv2.imread(i)
43+
44+
#图像像素大小一致
45+
img = cv2.resize(image, (256,256),
46+
interpolation=cv2.INTER_CUBIC)
47+
48+
#计算图像直方图并存储至X数组
49+
hist = cv2.calcHist([img], [0,1], None,
50+
[256,256], [0.0,255.0,0.0,255.0])
51+
52+
XX_train.append(((hist/255).flatten()))
53+
54+
#测试集
55+
XX_test = []
56+
for i in X_test:
57+
#读取图像
58+
#print i
59+
image = cv2.imread(i)
60+
61+
#图像像素大小一致
62+
img = cv2.resize(image, (256,256),
63+
interpolation=cv2.INTER_CUBIC)
64+
65+
#计算图像直方图并存储至X数组
66+
hist = cv2.calcHist([img], [0,1], None,
67+
[256,256], [0.0,255.0,0.0,255.0])
68+
69+
XX_test.append(((hist/255).flatten()))
70+
71+
#----------------------------------------------------------------------------------
72+
# 第三步 基于朴素贝叶斯的图像分类处理
73+
#----------------------------------------------------------------------------------
74+
75+
from sklearn.naive_bayes import BernoulliNB
76+
clf = BernoulliNB().fit(XX_train, y_train)
77+
predictions_labels = clf.predict(XX_test)
78+
79+
print u'预测结果:'
80+
print predictions_labels
81+
82+
print u'算法评价:'
83+
print (classification_report(y_test, predictions_labels))
84+
85+
#输出前10张图片及预测结果
86+
k = 0
87+
while k<10:
88+
#读取图像
89+
print X_test[k]
90+
image = cv2.imread(X_test[k])
91+
print predictions_labels[k]
92+
#显示图像
93+
cv2.imshow("img", image)
94+
cv2.waitKey(0)
95+
cv2.destroyAllWindows()
96+
k = k + 1

0 commit comments

Comments
 (0)