Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Refactor out different test cases to separate tests
The previous solution had two test conditions (strict and not strict)
and several scikit-learn versions, because of two distinct changes
within scikit-learn (the removal of min_impurity_split in 1.0, and the
restructuring of public/private models in 0.24).
I refactored out the separate test cases to greatly simplify the
individual tests, and I added a test case for scikit-learn>=1.0,
which was previously not covered.
  • Loading branch information
PGijsbers committed Oct 17, 2022
commit 6bf77afb1f8039c8f4766a8d847ede2b8348cad9
67 changes: 45 additions & 22 deletions tests/test_flows/test_flow_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self):
)

@unittest.skipIf(
LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1"
LooseVersion(sklearn.__version__) == "0.19.1",
reason="Requires scikit-learn!=0.19.1, because target flow is from that version.",
)
def test_get_flow_reinstantiate_model_wrong_version(self):
# Note that CI does not test against 0.19.1.
def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self):
openml.config.server = self.production_server
_, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3]
if sklearn_major > 23:
flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23
flow_sklearn_version = "0.23.1"
else:
flow = 8175
flow_sklearn_version = "0.19.1"
expected = (
"Trying to deserialize a model with dependency "
"sklearn=={} not satisfied.".format(flow_sklearn_version)
)
flow = 8175
expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied."
self.assertRaisesRegex(
ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True
ValueError,
expected,
openml.flows.get_flow,
flow_id=flow,
reinstantiate=True,
strict_version=True,
)
if LooseVersion(sklearn.__version__) > "0.19.1":
# 0.18 actually can't deserialize this because of incompatibility
flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False)
# ensure that a new flow was created
assert flow.flow_id is None
assert "sklearn==0.19.1" not in flow.dependencies
assert "sklearn>=0.19.1" not in flow.dependencies

@unittest.skipIf(
LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0",
reason="Requires scikit-learn < 1.0.1."
# Because scikit-learn dropped min_impurity_split hyperparameter in 1.0,
# and the requested flow is from 1.0.0 exactly.
)
def test_get_flow_reinstantiate_flow_not_strict_post_1(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==1.0.0" not in flow.dependencies

@unittest.skipIf(
(LooseVersion(sklearn.__version__) < "0.23.2")
or ("1.0" < LooseVersion(sklearn.__version__)),
reason="Requires scikit-learn 0.23.2 or ~0.24."
# Because these still have min_impurity_split, but with new scikit-learn module structure."
)
def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==0.23.1" not in flow.dependencies

@unittest.skipIf(
"0.23" < LooseVersion(sklearn.__version__),
reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.",
)
def test_get_flow_reinstantiate_flow_not_strict_pre_023(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==0.19.1" not in flow.dependencies

def test_get_flow_id(self):
if self.long_version:
Expand Down