@@ -53,13 +53,14 @@ def setUp(self):
5353
5454 self .kmeans = KMeans (self .num_centers ,
5555 initial_clusters = kmeans_ops .RANDOM_INIT ,
56- batch_size = self .batch_size ,
5756 use_mini_batch = self .use_mini_batch ,
58- steps = 30 ,
59- continue_training = True ,
60- config = run_config .RunConfig (tf_random_seed = 14 ),
57+ config = self .config (14 ),
6158 random_seed = 12 )
6259
60+ @staticmethod
61+ def config (tf_random_seed ):
62+ return run_config .RunConfig (tf_random_seed = tf_random_seed )
63+
6364 @property
6465 def batch_size (self ):
6566 return self .num_points
@@ -86,7 +87,7 @@ def make_random_points(centers, num_points, max_offset=20):
8687
8788 def test_clusters (self ):
8889 kmeans = self .kmeans
89- kmeans .fit (x = self .points , steps = 0 )
90+ kmeans .fit (x = self .points , steps = 1 , batch_size = 8 )
9091 clusters = kmeans .clusters ()
9192 self .assertAllEqual (list (clusters .shape ),
9293 [self .num_centers , self .num_dims ])
@@ -97,10 +98,11 @@ def test_fit(self):
9798 return
9899 kmeans = self .kmeans
99100 kmeans .fit (x = self .points ,
100- steps = 1 )
101+ steps = 1 , batch_size = self . batch_size )
101102 score1 = kmeans .score (x = self .points )
102103 kmeans .fit (x = self .points ,
103- steps = 15 * self .num_points // self .batch_size )
104+ steps = 15 * self .num_points // self .batch_size ,
105+ batch_size = self .batch_size )
104106 score2 = kmeans .score (x = self .points )
105107 self .assertTrue (score1 > score2 )
106108 self .assertNear (self .true_score , score2 , self .true_score * 0.05 )
@@ -111,39 +113,36 @@ def test_monitor(self):
111113 return
112114 kmeans = KMeans (self .num_centers ,
113115 initial_clusters = kmeans_ops .RANDOM_INIT ,
114- batch_size = self .batch_size ,
115116 use_mini_batch = self .use_mini_batch ,
116- # Force it to train forever until the monitor stops it.
117- steps = None ,
118- continue_training = True ,
119117 config = run_config .RunConfig (tf_random_seed = 14 ),
120118 random_seed = 12 )
121119
122120 kmeans .fit (x = self .points ,
123121 # Force it to train forever until the monitor stops it.
124122 steps = None ,
123+ batch_size = self .batch_size ,
125124 relative_tolerance = 1e-4 )
126125 score = kmeans .score (x = self .points )
127126 self .assertNear (self .true_score , score , self .true_score * 0.005 )
128127
129128 def test_infer (self ):
130129 kmeans = self .kmeans
131- kmeans .fit (x = self .points )
130+ kmeans .fit (x = self .points , steps = 10 , batch_size = 128 )
132131 clusters = kmeans .clusters ()
133132
134133 # Make a small test set
135134 points , true_assignments , true_offsets = self .make_random_points (clusters ,
136135 10 )
137136 # Test predict
138- assignments = kmeans .predict (points )
137+ assignments = kmeans .predict (points , batch_size = self . batch_size )
139138 self .assertAllEqual (assignments , true_assignments )
140139
141140 # Test score
142- score = kmeans .score (points )
141+ score = kmeans .score (points , batch_size = 128 )
143142 self .assertNear (score , np .sum (true_offsets ), 0.01 * score )
144143
145144 # Test transform
146- transform = kmeans .transform (points )
145+ transform = kmeans .transform (points , batch_size = 128 )
147146 true_transform = np .maximum (
148147 0 ,
149148 np .sum (np .square (points ), axis = 1 , keepdims = True ) -
@@ -161,12 +160,9 @@ def test_fit_with_cosine_distance(self):
161160 initial_clusters = kmeans_ops .RANDOM_INIT ,
162161 distance_metric = kmeans_ops .COSINE_DISTANCE ,
163162 use_mini_batch = self .use_mini_batch ,
164- batch_size = 4 ,
165- steps = 30 ,
166- continue_training = True ,
167- config = run_config .RunConfig (tf_random_seed = 2 ),
163+ config = self .config (2 ),
168164 random_seed = 12 )
169- kmeans .fit (x = points )
165+ kmeans .fit (x = points , steps = 10 , batch_size = 4 )
170166 centers = normalize (kmeans .clusters ())
171167 self .assertAllClose (np .sort (centers , axis = 0 ),
172168 np .sort (true_centers , axis = 0 ))
@@ -184,18 +180,16 @@ def test_transform_with_cosine_distance(self):
184180 initial_clusters = kmeans_ops .RANDOM_INIT ,
185181 distance_metric = kmeans_ops .COSINE_DISTANCE ,
186182 use_mini_batch = self .use_mini_batch ,
187- batch_size = 8 ,
188- continue_training = True ,
189- config = run_config .RunConfig (tf_random_seed = 3 ))
190- kmeans .fit (x = points , steps = 30 )
183+ config = self .config (3 ))
184+ kmeans .fit (x = points , steps = 30 , batch_size = 8 )
191185
192186 centers = normalize (kmeans .clusters ())
193187 self .assertAllClose (np .sort (centers , axis = 0 ),
194188 np .sort (true_centers , axis = 0 ),
195189 atol = 1e-2 )
196190
197191 true_transform = 1 - cosine_similarity (points , centers )
198- transform = kmeans .transform (points )
192+ transform = kmeans .transform (points , batch_size = 8 )
199193 self .assertAllClose (transform , true_transform , atol = 1e-3 )
200194
201195 def test_predict_with_cosine_distance (self ):
@@ -217,20 +211,18 @@ def test_predict_with_cosine_distance(self):
217211 initial_clusters = kmeans_ops .RANDOM_INIT ,
218212 distance_metric = kmeans_ops .COSINE_DISTANCE ,
219213 use_mini_batch = self .use_mini_batch ,
220- batch_size = 8 ,
221- continue_training = True ,
222- config = run_config .RunConfig (tf_random_seed = 3 ))
223- kmeans .fit (x = points , steps = 30 )
214+ config = self .config (3 ))
215+ kmeans .fit (x = points , steps = 30 , batch_size = 8 )
224216
225217 centers = normalize (kmeans .clusters ())
226218 self .assertAllClose (np .sort (centers , axis = 0 ),
227219 np .sort (true_centers , axis = 0 ), atol = 1e-2 )
228220
229- assignments = kmeans .predict (points )
221+ assignments = kmeans .predict (points , batch_size = 8 )
230222 self .assertAllClose (centers [assignments ],
231223 true_centers [true_assignments ], atol = 1e-2 )
232224
233- score = kmeans .score (points )
225+ score = kmeans .score (points , batch_size = 8 )
234226 self .assertAllClose (score , true_score , atol = 1e-2 )
235227
236228 def test_predict_with_cosine_distance_and_kmeans_plus_plus (self ):
@@ -254,29 +246,27 @@ def test_predict_with_cosine_distance_and_kmeans_plus_plus(self):
254246 initial_clusters = kmeans_ops .KMEANS_PLUS_PLUS_INIT ,
255247 distance_metric = kmeans_ops .COSINE_DISTANCE ,
256248 use_mini_batch = self .use_mini_batch ,
257- batch_size = 12 ,
258- continue_training = True ,
259- config = run_config .RunConfig (tf_random_seed = 3 ))
260- kmeans .fit (x = points , steps = 30 )
249+ config = self .config (3 ))
250+ kmeans .fit (x = points , steps = 30 , batch_size = 12 )
261251
262252 centers = normalize (kmeans .clusters ())
263253 self .assertAllClose (sorted (centers .tolist ()),
264254 sorted (true_centers .tolist ()),
265255 atol = 1e-2 )
266256
267- assignments = kmeans .predict (points )
257+ assignments = kmeans .predict (points , batch_size = 12 )
268258 self .assertAllClose (centers [assignments ],
269259 true_centers [true_assignments ], atol = 1e-2 )
270260
271- score = kmeans .score (points )
261+ score = kmeans .score (points , batch_size = 12 )
272262 self .assertAllClose (score , true_score , atol = 1e-2 )
273263
274264 def test_fit_raise_if_num_clusters_larger_than_num_points_random_init (self ):
275265 points = np .array ([[2.0 , 3.0 ], [1.6 , 8.2 ]])
276266
277267 with self .assertRaisesOpError ('less' ):
278268 kmeans = KMeans (num_clusters = 3 , initial_clusters = kmeans_ops .RANDOM_INIT )
279- kmeans .fit (x = points )
269+ kmeans .fit (x = points , steps = 10 , batch_size = 8 )
280270
281271 def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus (
282272 self ):
@@ -285,7 +275,7 @@ def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
285275 with self .assertRaisesOpError (AssertionError ):
286276 kmeans = KMeans (num_clusters = 3 ,
287277 initial_clusters = kmeans_ops .KMEANS_PLUS_PLUS_INIT )
288- kmeans .fit (x = points )
278+ kmeans .fit (x = points , steps = 10 , batch_size = 8 )
289279
290280
291281class MiniBatchKMeansTest (KMeansTest ):
0 commit comments