Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,14 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
# will be part of the name (in brackets)
sub_components_names = ""
for key in subcomponents:
if isinstance(subcomponents[key], str):
Comment thread
mfeurer marked this conversation as resolved.
Outdated
name = subcomponents[key]
else:
name = subcomponents[key].name
if key in subcomponents_explicit:
sub_components_names += "," + key + "=" + subcomponents[key].name
sub_components_names += "," + key + "=" + name
else:
sub_components_names += "," + subcomponents[key].name
sub_components_names += "," + name

if sub_components_names:
# slice operation on string in order to get rid of leading comma
Expand Down Expand Up @@ -769,21 +773,25 @@ def _get_external_version_string(
external_versions.add(openml_version)
external_versions.add(sklearn_version)
for visitee in sub_components.values():
if isinstance(visitee, str):
Comment thread
mfeurer marked this conversation as resolved.
continue
for external_version in visitee.external_version.split(','):
external_versions.add(external_version)
return ','.join(list(sorted(external_versions)))

def _check_multiple_occurence_of_component_in_flow(
self,
model: Any,
sub_components: Dict[str, OpenMLFlow],
sub_components: Dict[str, Any],
Comment thread
mfeurer marked this conversation as resolved.
Outdated
) -> None:
to_visit_stack = [] # type: List[OpenMLFlow]
to_visit_stack.extend(sub_components.values())
known_sub_components = set() # type: Set[str]
while len(to_visit_stack) > 0:
visitee = to_visit_stack.pop()
if visitee.name in known_sub_components:
if isinstance(visitee, str):
Comment thread
mfeurer marked this conversation as resolved.
Outdated
known_sub_components.add(visitee)
elif visitee.name in known_sub_components:
raise ValueError('Found a second occurence of component %s when '
'trying to serialize %s.' % (visitee.name, model))
else:
Expand All @@ -796,7 +804,7 @@ def _extract_information_from_model(
) -> Tuple[
'OrderedDict[str, Optional[str]]',
'OrderedDict[str, Optional[Dict]]',
'OrderedDict[str, OpenMLFlow]',
'OrderedDict[str, Any]',
Comment thread
mfeurer marked this conversation as resolved.
Outdated
Set,
]:
# This function contains four "global" states and is quite long and
Expand All @@ -820,7 +828,7 @@ def _extract_information_from_model(
def flatten_all(list_):
""" Flattens arbitrary depth lists of lists (e.g. [[1,2],[3,[1]]] -> [1,2,3,1]). """
for el in list_:
if isinstance(el, (list, tuple)):
if isinstance(el, (list, tuple)) and len(el) > 0:
yield from flatten_all(el)
else:
yield el
Expand Down Expand Up @@ -860,7 +868,16 @@ def flatten_all(list_):
# length 3 is for ColumnTransformer
msg = 'Length of tuple does not match assumptions'
raise ValueError(msg)
if not isinstance(sub_component, (OpenMLFlow, type(None))):

if isinstance(sub_component, str):
Comment thread
Neeratyoy marked this conversation as resolved.
if sub_component != 'drop' and sub_component != 'passthrough':
msg = 'Second item of tuple does not match assumptions. ' \
'If string, can be only \'drop\' or \'passthrough\' but' \
'got %s' % sub_component
raise ValueError(msg)
else:
pass
elif not isinstance(sub_component, (OpenMLFlow, type(None))):
msg = 'Second item of tuple does not match assumptions. ' \
'Expected OpenMLFlow, got %s' % type(sub_component)
raise TypeError(msg)
Expand Down