Skip to content

Commit 84416d8

Browse files
authored
Add files via upload
1 parent b5afed7 commit 84416d8

1 file changed

Lines changed: 76 additions & 0 deletions

File tree

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
By: Eastmount CSDN xiuzhang 2024-04-12
4+
"""
5+
import os
6+
import cv2
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
from sklearn.linear_model import LogisticRegression
10+
from sklearn.model_selection import train_test_split
11+
from sklearn.metrics import confusion_matrix, classification_report
12+
13+
#-----------------------------------------------------------------------
14+
#第一步 读取数据集并划分训练集
15+
X = [] #定义图像名称
16+
Y = [] #定义图像分类类标
17+
Z = [] #定义图像像素
18+
19+
#遍历文件夹读取图片
20+
for i in range(0, 12):
21+
for f in os.listdir("final_data/%s" % i):
22+
X.append("final_data//" +str(i) + "//" + str(f))
23+
Y.append(i)
24+
X = np.array(X)
25+
Y = np.array(Y)
26+
print(X[:2])
27+
28+
X_train, X_test, y_train, y_test = train_test_split(X, Y,
29+
test_size=0.3, random_state=1)
30+
print(len(X_train), len(X_test), len(y_train), len(y_test))
31+
#3696 1584 3696 1584
32+
33+
#-----------------------------------------------------------------------
34+
#第二步 图像读取及转换为像素直方图
35+
#训练集
36+
XX_train = []
37+
for i in X_train:
38+
image = cv2.imread(i)
39+
img = cv2.resize(image, (64,64),
40+
interpolation=cv2.INTER_CUBIC)
41+
hist = cv2.calcHist([img], [0,1], None,
42+
[256,256], [0.0,255.0,0.0,255.0])
43+
XX_train.append(((hist/255).flatten()))
44+
45+
#测试集
46+
XX_test = []
47+
for i in X_test:
48+
image = cv2.imread(i)
49+
img = cv2.resize(image, (64,64),
50+
interpolation=cv2.INTER_CUBIC)
51+
hist = cv2.calcHist([img], [0,1], None,
52+
[256,256], [0.0,255.0,0.0,255.0])
53+
XX_test.append(((hist/255).flatten()))
54+
55+
#-----------------------------------------------------------------------
56+
#第三步 基于机器学习的图像分类处理
57+
clf = LogisticRegression(C=100.0,random_state=1)
58+
clf.fit(XX_train, y_train)
59+
predictions_labels = clf.predict(XX_test)
60+
print('预测结果:')
61+
print(predictions_labels)
62+
print('算法评价:')
63+
print(classification_report(y_test, predictions_labels,digits=4))
64+
65+
#输出前10张图片及预测结果
66+
k = 0
67+
while k<10:
68+
print(X_test[k])
69+
image = cv2.imread(X_test[k])
70+
print(predictions_labels[k])
71+
cv2.imshow("img", image)
72+
cv2.waitKey(0)
73+
cv2.destroyAllWindows()
74+
k = k + 1
75+
76+

0 commit comments

Comments
 (0)