@@ -168,7 +168,7 @@ def test_serialize_model(self):
168168 ("splitter" , '"best"' ),
169169 )
170170 )
171- else :
171+ elif LooseVersion ( sklearn . __version__ ) < "1.0" :
172172 fixture_parameters = OrderedDict (
173173 (
174174 ("class_weight" , "null" ),
@@ -186,6 +186,24 @@ def test_serialize_model(self):
186186 ("splitter" , '"best"' ),
187187 )
188188 )
189+ else :
190+ fixture_parameters = OrderedDict (
191+ (
192+ ("class_weight" , "null" ),
193+ ("criterion" , '"entropy"' ),
194+ ("max_depth" , "null" ),
195+ ("max_features" , '"auto"' ),
196+ ("max_leaf_nodes" , "2000" ),
197+ ("min_impurity_decrease" , "0.0" ),
198+ ("min_samples_leaf" , "1" ),
199+ ("min_samples_split" , "2" ),
200+ ("min_weight_fraction_leaf" , "0.0" ),
201+ ("presort" , presort_val ),
202+ ("random_state" , "null" ),
203+ ("splitter" , '"best"' ),
204+ )
205+ )
206+
189207 if LooseVersion (sklearn .__version__ ) >= "0.22" :
190208 fixture_parameters .update ({"ccp_alpha" : "0.0" })
191209 fixture_parameters .move_to_end ("ccp_alpha" , last = False )
@@ -249,7 +267,7 @@ def test_serialize_model_clustering(self):
249267 ("verbose" , "0" ),
250268 )
251269 )
252- else :
270+ elif LooseVersion ( sklearn . __version__ ) < "1.0" :
253271 fixture_parameters = OrderedDict (
254272 (
255273 ("algorithm" , '"auto"' ),
@@ -265,6 +283,34 @@ def test_serialize_model_clustering(self):
265283 ("verbose" , "0" ),
266284 )
267285 )
286+ elif LooseVersion (sklearn .__version__ ) < "1.1" :
287+ fixture_parameters = OrderedDict (
288+ (
289+ ("algorithm" , '"auto"' ),
290+ ("copy_x" , "true" ),
291+ ("init" , '"k-means++"' ),
292+ ("max_iter" , "300" ),
293+ ("n_clusters" , "8" ),
294+ ("n_init" , "10" ),
295+ ("random_state" , "null" ),
296+ ("tol" , "0.0001" ),
297+ ("verbose" , "0" ),
298+ )
299+ )
300+ else :
301+ fixture_parameters = OrderedDict (
302+ (
303+ ("algorithm" , '"lloyd"' ),
304+ ("copy_x" , "true" ),
305+ ("init" , '"k-means++"' ),
306+ ("max_iter" , "300" ),
307+ ("n_clusters" , "8" ),
308+ ("n_init" , "10" ),
309+ ("random_state" , "null" ),
310+ ("tol" , "0.0001" ),
311+ ("verbose" , "0" ),
312+ )
313+ )
268314 fixture_structure = {"sklearn.cluster.{}.KMeans" .format (cluster_name ): []}
269315
270316 serialization , _ = self ._serialization_test_helper (
@@ -1335,12 +1381,19 @@ def test__get_fn_arguments_with_defaults(self):
13351381 (sklearn .tree .DecisionTreeClassifier .__init__ , 14 ),
13361382 (sklearn .pipeline .Pipeline .__init__ , 2 ),
13371383 ]
1338- else :
1384+ elif sklearn_version < "1.0" :
13391385 fns = [
13401386 (sklearn .ensemble .RandomForestRegressor .__init__ , 18 ),
13411387 (sklearn .tree .DecisionTreeClassifier .__init__ , 13 ),
13421388 (sklearn .pipeline .Pipeline .__init__ , 2 ),
13431389 ]
1390+ else :
1391+ # Tested with 1.0 and 1.1
1392+ fns = [
1393+ (sklearn .ensemble .RandomForestRegressor .__init__ , 17 ),
1394+ (sklearn .tree .DecisionTreeClassifier .__init__ , 12 ),
1395+ (sklearn .pipeline .Pipeline .__init__ , 2 ),
1396+ ]
13441397
13451398 for fn , num_params_with_defaults in fns :
13461399 defaults , defaultless = self .extension ._get_fn_arguments_with_defaults (fn )
0 commit comments