Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Refactoring to inspect.signature in extensions
  • Loading branch information
Neeratyoy committed Jul 31, 2020
commit 023ac406af607da934ce15c0933d14b8665a19c5
6 changes: 3 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ env:
- TEST_DIR=/tmp/test_dir/
- MODULE=openml
matrix:
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
Comment thread
mfeurer marked this conversation as resolved.
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.20.2"
# Checks for older scikit-learn versions (which also don't nicely work with
# Python3.7)
Expand Down
16 changes: 10 additions & 6 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,12 +994,16 @@ def _get_fn_arguments_with_defaults(self, fn_name: Callable) -> Tuple[Dict, Set]
a set with all parameters that do not have a default value
"""
# parameters with defaults are optional, all others are required.
signature = inspect.getfullargspec(fn_name)
if signature.defaults:
optional_params = dict(zip(reversed(signature.args), reversed(signature.defaults)))
else:
optional_params = dict()
required_params = {arg for arg in signature.args if arg not in optional_params}
parameters = inspect.signature(fn_name).parameters
required_params = set()
optional_params = dict()
for param in parameters.keys():
parameter = parameters.get(param)
default_val = parameter.default # type: ignore
if default_val is inspect.Signature.empty:
required_params.add(param)
else:
optional_params[param] = default_val
return optional_params, required_params

def _deserialize_model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1248,9 +1248,9 @@ def test__get_fn_arguments_with_defaults(self):
]
else:
fns = [
(sklearn.ensemble.RandomForestRegressor.__init__, 0),
(sklearn.tree.DecisionTreeClassifier.__init__, 0),
(sklearn.pipeline.Pipeline.__init__, 0),
(sklearn.ensemble.RandomForestRegressor.__init__, 18),
(sklearn.tree.DecisionTreeClassifier.__init__, 14),
(sklearn.pipeline.Pipeline.__init__, 2),
]

for fn, num_params_with_defaults in fns:
Expand All @@ -1259,8 +1259,7 @@ def test__get_fn_arguments_with_defaults(self):
self.assertIsInstance(defaultless, set)
# check whether we have both defaults and defaultless params
self.assertEqual(len(defaults), num_params_with_defaults)
if sklearn_version < "0.23":
self.assertGreater(len(defaultless), 0)
self.assertGreater(len(defaultless), 0)
# check no overlap
self.assertSetEqual(set(defaults.keys()), set(defaults.keys()) - defaultless)
self.assertSetEqual(defaultless, defaultless - set(defaults.keys()))
Expand Down