Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixing sklearn version compatibility issue
  • Loading branch information
Neeratyoy committed Nov 2, 2020
commit 3c3b3d851a81e91a51794fe06c19368fcadc443d
17 changes: 13 additions & 4 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,13 +1551,22 @@ def check_if_model_fitted(self, model: Any) -> bool:
try:
# check if model is fitted
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

check_is_fitted(model) # raises a NotFittedError if the model has not been trained
# Creating random dummy data of arbitrary size
dummy_data = np.random.uniform(size=(10, 3))
# Using 'predict' instead of 'sklearn.utils.validation.check_is_fitted' for a more
# robust check that works across sklearn versions and models. Internally, 'predict'
# should call 'check_is_fitted' for every concerned attribute, thus offering a more
# assured check than explicit calls to 'check_is_fitted'
model.predict(dummy_data)
# Will reach here if the model was fit on a dataset with 3 features
return True
except NotFittedError:
# model is not fitted, as is required
except NotFittedError: # needs to be the first exception to be caught
# Model is not fitted, as is required
return False
except ValueError:
# Will reach here if the model was fit on a dataset with more or less than 3 features
return True

def _run_model_on_fold(
self,
Expand Down