@@ -338,14 +338,15 @@ def test_serialize_model_clustering(self):
338338 )
339339 )
340340 else :
341+ n_init = '"warn"' if LooseVersion (sklearn .__version__ ) >= "1.2" else "10"
341342 fixture_parameters = OrderedDict (
342343 (
343344 ("algorithm" , '"lloyd"' ),
344345 ("copy_x" , "true" ),
345346 ("init" , '"k-means++"' ),
346347 ("max_iter" , "300" ),
347348 ("n_clusters" , "8" ),
348- ("n_init" , "10" ),
349+ ("n_init" , n_init ),
349350 ("random_state" , "null" ),
350351 ("tol" , "0.0001" ),
351352 ("verbose" , "0" ),
@@ -358,13 +359,13 @@ def test_serialize_model_clustering(self):
358359 )
359360 structure = serialization .get_structure ("name" )
360361
361- self . assertEqual ( serialization .name , fixture_name )
362- self . assertEqual ( serialization .class_name , fixture_name )
363- self . assertEqual ( serialization .custom_name , fixture_short_name )
364- self . assertEqual ( serialization .description , fixture_description )
365- self . assertEqual ( serialization .parameters , fixture_parameters )
366- self . assertEqual ( serialization .dependencies , version_fixture )
367- self . assertDictEqual ( structure , fixture_structure )
362+ assert serialization .name == fixture_name
363+ assert serialization .class_name == fixture_name
364+ assert serialization .custom_name == fixture_short_name
365+ assert serialization .description == fixture_description
366+ assert serialization .parameters == fixture_parameters
367+ assert serialization .dependencies == version_fixture
368+ assert structure == fixture_structure
368369
369370 def test_serialize_model_with_subcomponent (self ):
370371 model = sklearn .ensemble .AdaBoostClassifier (
@@ -1449,22 +1450,19 @@ def test_deserialize_complex_with_defaults(self):
14491450 pipe_orig = sklearn .pipeline .Pipeline (steps = steps )
14501451
14511452 pipe_adjusted = sklearn .clone (pipe_orig )
1452- if LooseVersion (sklearn .__version__ ) < "0.23" :
1453- params = {
1454- "Imputer__strategy" : "median" ,
1455- "OneHotEncoder__sparse" : False ,
1456- "Estimator__n_estimators" : 10 ,
1457- "Estimator__base_estimator__n_estimators" : 10 ,
1458- "Estimator__base_estimator__base_estimator__learning_rate" : 0.1 ,
1459- }
1460- else :
1461- params = {
1462- "Imputer__strategy" : "mean" ,
1463- "OneHotEncoder__sparse" : True ,
1464- "Estimator__n_estimators" : 50 ,
1465- "Estimator__base_estimator__n_estimators" : 10 ,
1466- "Estimator__base_estimator__base_estimator__learning_rate" : 0.1 ,
1467- }
1453+ impute_strategy = "median" if LooseVersion (sklearn .__version__ ) < "0.23" else "mean"
1454+ sparse = LooseVersion (sklearn .__version__ ) >= "0.23"
1455+ estimator_name = (
1456+ "base_estimator" if LooseVersion (sklearn .__version__ ) < "1.2" else "estimator"
1457+ )
1458+ params = {
1459+ "Imputer__strategy" : impute_strategy ,
1460+ "OneHotEncoder__sparse" : sparse ,
1461+ "Estimator__n_estimators" : 10 ,
1462+ f"Estimator__{ estimator_name } __n_estimators" : 10 ,
1463+ f"Estimator__{ estimator_name } __{ estimator_name } __learning_rate" : 0.1 ,
1464+ }
1465+
14681466 pipe_adjusted .set_params (** params )
14691467 flow = self .extension .model_to_flow (pipe_adjusted )
14701468 pipe_deserialized = self .extension .flow_to_model (flow , initialize_with_defaults = True )
0 commit comments