Skip to content

Commit a954ce2

Browse files
authored
ADD raise exception when failing to create sklearn flow (#479)
* ADD raise exception when failing to create sklearn flow * Update changelog
1 parent 351e2b7 commit a954ce2

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

doc/progress.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Changelog
1414
* Added serialize run / deserialize run function (for saving runs on disk before uploading)
1515
* FIX: fixed bug related to listing functions (returns correct listing size)
1616
* made openml.utils.list_all a hidden function (should be accessed only by the respective listing functions)
17+
* Improve error handling for issue `#479 <https://github.com/openml/openml-python/pull/479>`_:
18+
the OpenML connector fails earlier and with a better error message when
19+
failing to create a flow from the OpenML description.
1720

1821
0.3.0
1922
~~~~~

openml/flows/flow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,13 @@ def _from_dict(cls, xml_dict):
313313
# try to parse to a model because not everything that can be
314314
# deserialized has to come from scikit-learn. If it can't be
315315
# serialized, but comes from scikit-learn this is worth an exception
316-
try:
316+
if (
317+
arguments['external_version'].startswith('sklearn==')
318+
or ',sklearn==' in arguments['external_version']
319+
):
317320
from .sklearn_converter import flow_to_sklearn
318321
model = flow_to_sklearn(flow)
319-
except Exception as e:
320-
if arguments['external_version'].startswith('sklearn'):
321-
raise e
322+
else:
322323
model = None
323324
flow.model = model
324325

openml/flows/sklearn_converter.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,8 @@ def _deserialize_model(flow, **kwargs):
394394
parameter_dict[name] = rval
395395

396396
module_name = model_name.rsplit('.', 1)
397-
try:
398-
model_class = getattr(importlib.import_module(module_name[0]),
399-
module_name[1])
400-
except:
401-
warnings.warn('Cannot create model %s for flow.' % model_name)
402-
return None
397+
model_class = getattr(importlib.import_module(module_name[0]),
398+
module_name[1])
403399

404400
return model_class(**parameter_dict)
405401

0 commit comments

Comments
 (0)