Skip to content

Commit 281f277

Browse files
committed
Update n_init parameter for sklearn 1.2
1 parent 133fc6c commit 281f277

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,15 @@ def test_serialize_model_clustering(self):
338338
)
339339
)
340340
else:
341+
n_init = '"warn"' if LooseVersion(sklearn.__version__) >= "1.2" else "10"
341342
fixture_parameters = OrderedDict(
342343
(
343344
("algorithm", '"lloyd"'),
344345
("copy_x", "true"),
345346
("init", '"k-means++"'),
346347
("max_iter", "300"),
347348
("n_clusters", "8"),
348-
("n_init", "10"),
349+
("n_init", n_init),
349350
("random_state", "null"),
350351
("tol", "0.0001"),
351352
("verbose", "0"),
@@ -358,13 +359,13 @@ def test_serialize_model_clustering(self):
358359
)
359360
structure = serialization.get_structure("name")
360361

361-
self.assertEqual(serialization.name, fixture_name)
362-
self.assertEqual(serialization.class_name, fixture_name)
363-
self.assertEqual(serialization.custom_name, fixture_short_name)
364-
self.assertEqual(serialization.description, fixture_description)
365-
self.assertEqual(serialization.parameters, fixture_parameters)
366-
self.assertEqual(serialization.dependencies, version_fixture)
367-
self.assertDictEqual(structure, fixture_structure)
362+
assert serialization.name == fixture_name
363+
assert serialization.class_name == fixture_name
364+
assert serialization.custom_name == fixture_short_name
365+
assert serialization.description == fixture_description
366+
assert serialization.parameters == fixture_parameters
367+
assert serialization.dependencies == version_fixture
368+
assert structure == fixture_structure
368369

369370
def test_serialize_model_with_subcomponent(self):
370371
model = sklearn.ensemble.AdaBoostClassifier(

0 commit comments

Comments
 (0)