Skip to content

Commit b285a22

Browse files
authored
Create blog52_01_Image_nb(像素统计).py
1 parent c15f7ba commit b285a22

1 file changed

Lines changed: 73 additions & 0 deletions

File tree

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
By: Eastmount CSDN xiuzhang 2024-04-12
4+
"""
5+
import os
6+
import cv2
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
from sklearn.naive_bayes import BernoulliNB
10+
from sklearn.model_selection import train_test_split
11+
from sklearn.metrics import confusion_matrix, classification_report
12+
13+
#-----------------------------------------------------------------------
14+
#第一步 读取数据集并划分训练集
15+
X = [] #定义图像名称
16+
Y = [] #定义图像分类类标
17+
Z = [] #定义图像像素
18+
19+
#遍历文件夹读取图片
20+
for i in range(0, 12):
21+
for f in os.listdir("final_data/%s" % i):
22+
X.append("final_data//" +str(i) + "//" + str(f))
23+
Y.append(i)
24+
X = np.array(X)
25+
Y = np.array(Y)
26+
print(X[:2])
27+
28+
X_train, X_test, y_train, y_test = train_test_split(X, Y,
29+
test_size=0.3, random_state=1)
30+
print(len(X_train), len(X_test), len(y_train), len(y_test))
31+
#3696 1584 3696 1584
32+
33+
#-----------------------------------------------------------------------
34+
#第二步 图像读取及转换为像素直方图
35+
#训练集
36+
XX_train = []
37+
for i in X_train:
38+
image = cv2.imread(i)
39+
img = cv2.resize(image, (32,32),
40+
interpolation=cv2.INTER_CUBIC)
41+
hist = cv2.calcHist([img], [0,1], None,
42+
[256,256], [0.0,255.0,0.0,255.0])
43+
XX_train.append(((hist/255).flatten()))
44+
45+
#测试集
46+
XX_test = []
47+
for i in X_test:
48+
image = cv2.imread(i)
49+
img = cv2.resize(image, (32,32),
50+
interpolation=cv2.INTER_CUBIC)
51+
hist = cv2.calcHist([img], [0,1], None,
52+
[256,256], [0.0,255.0,0.0,255.0])
53+
XX_test.append(((hist/255).flatten()))
54+
55+
#-----------------------------------------------------------------------
56+
#第三步 基于机器学习的图像分类处理
57+
clf = BernoulliNB().fit(XX_train, y_train)
58+
predictions_labels = clf.predict(XX_test)
59+
print('预测结果:')
60+
print(predictions_labels)
61+
print('算法评价:')
62+
print(classification_report(y_test, predictions_labels,digits=4))
63+
64+
#输出前10张图片及预测结果
65+
k = 0
66+
while k<10:
67+
print(X_test[k])
68+
image = cv2.imread(X_test[k])
69+
print(predictions_labels[k])
70+
cv2.imshow("img", image)
71+
cv2.waitKey(0)
72+
cv2.destroyAllWindows()
73+
k = k + 1

0 commit comments

Comments
 (0)