Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactoring unit test
  • Loading branch information
Neeratyoy committed Feb 25, 2021
commit b5d019057da55871267ca1a4aa6fd8a522894b5b
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ def test_paralizable_check(self):
# using this param distribution should raise an exception
illegal_param_dist = {"base__n_jobs": [-1, 0, 1]}
# using this param distribution should not raise an exception
legal_param_dist = {"base__max_depth": [2, 3, 4]}
legal_param_dist = {"n_estimators": [2, 3, 4]}

legal_models = [
sklearn.ensemble.RandomForestClassifier(),
Expand Down Expand Up @@ -1282,12 +1282,19 @@ def test_paralizable_check(self):

can_measure_cputime_answers = [True, False, False, True, False, False, True, False, False]
can_measure_walltime_answers = [True, True, False, True, True, False, True, True, False]
if LooseVersion(sklearn.__version__) < "0.20":
has_refit_time = [False, False, False, False, False, False, False, False, False]
else:
has_refit_time = [False, False, False, False, False, False, True, True, False]

for model, allowed_cputime, allowed_walltime in zip(
legal_models, can_measure_cputime_answers, can_measure_walltime_answers
X, y = sklearn.datasets.load_iris(return_X_y=True)
for model, allowed_cputime, allowed_walltime, refit_time in zip(
legal_models, can_measure_cputime_answers, can_measure_walltime_answers, has_refit_time
):
self.assertEqual(self.extension._can_measure_cputime(model), allowed_cputime)
self.assertEqual(self.extension._can_measure_wallclocktime(model), allowed_walltime)
model.fit(X, y)
self.assertEqual(refit_time, hasattr(model, "refit_time_"))

for model in illegal_models:
with self.assertRaises(PyOpenMLError):
Expand Down Expand Up @@ -2253,15 +2260,3 @@ def column_transformer_pipe(task_id):
run2 = column_transformer_pipe(23) # only numeric
TestBase._mark_entity_for_removal("run", run2.run_id)
self.assertEqual(run1.setup_id, run2.setup_id)

def test_for_refit_time_in_basesearchCV(self):
X, y = sklearn.datasets.load_iris(return_X_y=True)
rs = sklearn.model_selection.GridSearchCV(
estimator=sklearn.ensemble.RandomForestClassifier(),
param_grid={"n_estimators": [2, 3, 4, 5]},
)
rs.fit(X, y)
if LooseVersion(sklearn.__version__) < "0.20":
self.assertFalse(hasattr(rs, "refit_time_"))
else:
self.assertTrue(hasattr(rs, "refit_time_"))