Skip to content

Commit f00b9c5

Browse files
change kmeans base
1 parent 98ae4e1 commit f00b9c5

9 files changed

Lines changed: 609 additions & 123 deletions

File tree

.idea/machine_learning_python.iml

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

Lines changed: 424 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kmeans/kmeans.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

kmeans/kmeans_base.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
2+
import random
3+
4+
from collections import defaultdict
5+
6+
from sklearn.cluster import KMeans
7+
import numpy as np
8+
9+
from utils.misc_utils import distance, check_random_state
10+
from utils.data_generater import makeRandomPoint
11+
12+
13+
14+
class KMeansBase(object):
    """Minimal k-means clusterer using Lloyd's iteration.

    Parameters
    ----------
    n_clusters : int
        Number of clusters to look for (stored as ``self.k``).
    init : {"random", "k-means++"}
        Strategy used to choose the initial centres.
    max_iter : int
        Upper bound on the number of assign/update rounds in ``fit``.
    random_state : None | int | np.random.RandomState
        Seed material handed to ``check_random_state``.
    """

    def __init__(self, n_clusters=8, init="random", max_iter=300, random_state=None):
        self.k = n_clusters
        self.init = init
        self.max_iter = max_iter
        self.random_state = random_state

    def fit(self, dataset):
        """Cluster ``dataset`` and return an iterator of (label, point) pairs.

        ``dataset`` may be any array-like of points; it is converted to an
        ndarray once and that array is used for every later step.
        """
        # Work on the converted array everywhere: the helpers need ``.shape``
        # and fancy indexing, which a plain list does not provide.
        dataset = np.array(dataset)
        self.dataset = dataset
        k_points = self._init_centroids(dataset)
        assignments = self.assign_points(dataset, k_points)
        old_assignments = None
        for _ in range(self.max_iter):
            # Converged: no point changed cluster since the previous round.
            if assignments == old_assignments:
                break
            error = self.update_error(dataset, assignments)
            print("error", error)
            new_centers = self.update_centers(dataset, assignments)
            old_assignments = assignments
            assignments = self.assign_points(dataset, new_centers)
        return zip(assignments, dataset)

    def _init_centroids(self, dataset):
        """Pick the initial centres from the data points.

        Returns at most ``self.k`` rows of ``dataset``.

        Raises
        ------
        ValueError
            If ``self.init`` is not a supported strategy.
        """
        random_state = check_random_state(self.random_state)
        dataset = np.asarray(dataset)
        n_samples = dataset.shape[0]
        if self.init == "random":
            # k distinct sample indices drawn uniformly at random.
            seeds = random_state.permutation(n_samples)[:self.k]
            return dataset[seeds]
        if self.init == "k-means++":
            return self._kmeans_plus_plus(dataset, random_state)
        raise ValueError(
            "init must be 'random' or 'k-means++', got %r" % (self.init,))

    def _kmeans_plus_plus(self, dataset, random_state):
        """Choose centres by k-means++ seeding.

        The first centre is uniform over the samples; each further centre is
        drawn with probability proportional to the squared distance to the
        nearest centre already chosen.
        """
        n_samples = dataset.shape[0]
        if n_samples == 0:
            return dataset[:0]
        # Flatten each sample to a float row so the distance maths works for
        # integer data and for 1-D datasets alike.
        points = dataset.reshape(n_samples, -1).astype(float)
        n_centers = min(self.k, n_samples)

        first = random_state.randint(n_samples)
        chosen = [first]
        closest_sq = np.sum(np.square(points - points[first]), axis=1)
        while len(chosen) < n_centers:
            total = closest_sq.sum()
            if total > 0:
                pick = random_state.choice(n_samples, p=closest_sq / total)
            else:
                # Every sample coincides with a chosen centre; any pick is
                # equally good.
                pick = random_state.randint(n_samples)
            chosen.append(pick)
            closest_sq = np.minimum(
                closest_sq, np.sum(np.square(points - points[pick]), axis=1))
        return dataset[chosen]

    def point_avg(self, points):
        """Return the centroid of ``points`` (points of equal dimension)."""
        return np.mean(points, axis=0)

    def _group_points(self, dataset, assignments):
        """Map each label in ``assignments`` to the list of its points."""
        groups = defaultdict(list)
        for label, point in zip(assignments, dataset):
            groups[label].append(point)
        return groups

    def update_centers(self, dataset, assignments):
        """Return the list of new centres for the current assignment.

        ``assignments[i]`` is the cluster label of ``dataset[i]``. Centres
        are returned in ascending label order, so when no cluster is empty
        ``centers[label]`` is the mean of that label's points. Labels with no
        points produce no centre.
        """
        groups = self._group_points(dataset, assignments)
        return [self.point_avg(groups[label]) for label in sorted(groups)]

    def update_error(self, dataset, assignments):
        """Return the clustering error for the current assignment.

        For every cluster this adds the square root of the summed squared
        deviations of its points from the cluster mean.
        """
        error = 0
        for points in self._group_points(dataset, assignments).values():
            center = self.point_avg(points)
            error += np.sqrt(np.sum(np.square(np.asarray(points) - center)))
        return error

    def assign_points(self, dataset, centers):
        """Return, for each point, the index of its nearest centre.

        Ties go to the lowest index. With no centres every point gets
        label 0. The result is a list of plain ``int`` so that two
        assignments can be compared with ``==``.
        """
        points = np.asarray(dataset, dtype=float)
        centers = np.asarray(centers, dtype=float)
        n_points = points.shape[0] if points.ndim else 0
        if n_points == 0:
            return []
        n_centers = centers.shape[0] if centers.ndim else 0
        if n_centers == 0:
            return [0] * n_points
        points = points.reshape(n_points, -1)
        centers = centers.reshape(n_centers, -1)
        # One column per centre keeps the temporaries at n_points x dim
        # instead of n_points x n_centers x dim.
        dists = np.empty((n_points, n_centers))
        for j, center in enumerate(centers):
            dists[:, j] = np.sqrt(np.sum(np.square(points - center), axis=1))
        return np.argmin(dists, axis=1).tolist()
100+
101+
102+
103+
104+
105+
106+
if __name__ == "__main__":
    # Demo: cluster the iris data into three groups and print every
    # (label, sample) pair produced by ``fit``.
    # ``datasets`` was never imported in this module, so import it here;
    # without it the first line of the demo raised NameError.
    from sklearn import datasets

    iris = datasets.load_iris()
    km = KMeansBase(3)
    for k in km.fit(iris.data):
        print(k)

    # scikit-learn estimator constructed for reference; it is not fitted here.
    kmeans = KMeans(init='k-means++', n_clusters= 10, n_init=10)
114+
115+
# pointList = []
116+
# numPoints = 10000
117+
# dim = 1000
118+
# numClusters = 10
119+
# k = 0
120+
# for i in range(0,numClusters):
121+
# num = int(numPoints/numClusters)
122+
# p = makeRandomPoint(num,dim,k)
123+
# k += 5
124+
# pointList += p.tolist()
125+
#
126+
# start = time.time()
127+
# config= k_means(np.array(pointList), numClusters)
128+
# print("Time taken:",time.time() - start)

kmeans/kmeans_main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from kmeans.kmeans_base import KMeansBase
2+
from sklearn import datasets
3+
4+
if __name__ == "__main__":
    # Entry point: run the home-grown k-means (3 clusters) on the iris
    # samples and print each item yielded by ``fit``.
    iris = datasets.load_iris()
    model = KMeansBase(3)
    for item in model.fit(iris.data):
        print(item)
346 Bytes
Binary file not shown.
1.08 KB
Binary file not shown.

utils/misc_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numpy as np
2+
import numbers
3+
4+
def distance(point1, point2):
    """Return the Euclidean distance between two points.

    Both arguments may be ndarrays or plain sequences of numbers. They are
    converted to float arrays first so that unsigned-integer inputs cannot
    wrap around when subtracted.
    """
    delta = np.asarray(point1, dtype=float) - np.asarray(point2, dtype=float)
    return np.sqrt(np.sum(np.square(delta)))
6+
7+
8+
def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance.

    Parameters
    ----------
    seed : None | int | instance of RandomState
        If seed is None (or the ``np.random`` module itself), return the
        RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.

    Raises
    ------
    ValueError
        If seed is none of the above.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Wrap seed in a 1-tuple: with a bare ``% seed`` a tuple seed would be
    # unpacked as the format arguments and raise TypeError instead.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))

0 commit comments

Comments
 (0)