Skip to content

Commit 2c0b983

Browse files
committed
Add scikit-learn 1.0 and 1.1 values for test (openml#1168)
* Add scikit-learn 1.0 and 1.1 values for test DecisionTree and RandomForestRegressor have one less default hyperparameter: `min_impurity_split` * Remove min_impurity_split requirements for >=1.0 * Update KMeans checks for scikit-learn 1.0 and 1.1
1 parent c4ae8cd commit 2c0b983

1 file changed

Lines changed: 56 additions & 3 deletions

File tree

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_serialize_model(self):
168168
("splitter", '"best"'),
169169
)
170170
)
171-
else:
171+
elif LooseVersion(sklearn.__version__) < "1.0":
172172
fixture_parameters = OrderedDict(
173173
(
174174
("class_weight", "null"),
@@ -186,6 +186,24 @@ def test_serialize_model(self):
186186
("splitter", '"best"'),
187187
)
188188
)
189+
else:
190+
fixture_parameters = OrderedDict(
191+
(
192+
("class_weight", "null"),
193+
("criterion", '"entropy"'),
194+
("max_depth", "null"),
195+
("max_features", '"auto"'),
196+
("max_leaf_nodes", "2000"),
197+
("min_impurity_decrease", "0.0"),
198+
("min_samples_leaf", "1"),
199+
("min_samples_split", "2"),
200+
("min_weight_fraction_leaf", "0.0"),
201+
("presort", presort_val),
202+
("random_state", "null"),
203+
("splitter", '"best"'),
204+
)
205+
)
206+
189207
if LooseVersion(sklearn.__version__) >= "0.22":
190208
fixture_parameters.update({"ccp_alpha": "0.0"})
191209
fixture_parameters.move_to_end("ccp_alpha", last=False)
@@ -249,7 +267,7 @@ def test_serialize_model_clustering(self):
249267
("verbose", "0"),
250268
)
251269
)
252-
else:
270+
elif LooseVersion(sklearn.__version__) < "1.0":
253271
fixture_parameters = OrderedDict(
254272
(
255273
("algorithm", '"auto"'),
@@ -265,6 +283,34 @@ def test_serialize_model_clustering(self):
265283
("verbose", "0"),
266284
)
267285
)
286+
elif LooseVersion(sklearn.__version__) < "1.1":
287+
fixture_parameters = OrderedDict(
288+
(
289+
("algorithm", '"auto"'),
290+
("copy_x", "true"),
291+
("init", '"k-means++"'),
292+
("max_iter", "300"),
293+
("n_clusters", "8"),
294+
("n_init", "10"),
295+
("random_state", "null"),
296+
("tol", "0.0001"),
297+
("verbose", "0"),
298+
)
299+
)
300+
else:
301+
fixture_parameters = OrderedDict(
302+
(
303+
("algorithm", '"lloyd"'),
304+
("copy_x", "true"),
305+
("init", '"k-means++"'),
306+
("max_iter", "300"),
307+
("n_clusters", "8"),
308+
("n_init", "10"),
309+
("random_state", "null"),
310+
("tol", "0.0001"),
311+
("verbose", "0"),
312+
)
313+
)
268314
fixture_structure = {"sklearn.cluster.{}.KMeans".format(cluster_name): []}
269315

270316
serialization, _ = self._serialization_test_helper(
@@ -1335,12 +1381,19 @@ def test__get_fn_arguments_with_defaults(self):
13351381
(sklearn.tree.DecisionTreeClassifier.__init__, 14),
13361382
(sklearn.pipeline.Pipeline.__init__, 2),
13371383
]
1338-
else:
1384+
elif sklearn_version < "1.0":
13391385
fns = [
13401386
(sklearn.ensemble.RandomForestRegressor.__init__, 18),
13411387
(sklearn.tree.DecisionTreeClassifier.__init__, 13),
13421388
(sklearn.pipeline.Pipeline.__init__, 2),
13431389
]
1390+
else:
1391+
# Tested with 1.0 and 1.1
1392+
fns = [
1393+
(sklearn.ensemble.RandomForestRegressor.__init__, 17),
1394+
(sklearn.tree.DecisionTreeClassifier.__init__, 12),
1395+
(sklearn.pipeline.Pipeline.__init__, 2),
1396+
]
13441397

13451398
for fn, num_params_with_defaults in fns:
13461399
defaults, defaultless = self.extension._get_fn_arguments_with_defaults(fn)

0 commit comments

Comments
 (0)