33Authors
44 * Peter Plantinga 2020
55 * Aku Rouhe 2020
6+ * Jianchen Li 2022
67"""
78
89import re
1516import functools
1617import collections
1718import ruamel.yaml
18- from ruamel.yaml.representer import RepresenterError
1919import operator as op
2020from io import StringIO
2121from collections import OrderedDict
22+ from ruamel.yaml.representer import RepresenterError
2223
2324
2425# NOTE: Empty dict as default parameter is fine here since overrides are never
@@ -164,7 +165,7 @@ def load_hyperpyyaml(
164165 yaml.Loader.add_multi_constructor("!new:", _construct_object)
165166 yaml.Loader.add_multi_constructor("!name:", _construct_name)
166167 yaml.Loader.add_multi_constructor("!module:", _construct_module)
167- # yaml.Loader.add_multi_constructor("!apply:", _apply_function)
168+ yaml.Loader.add_multi_constructor("!apply:", _apply_function)
168169
169170 # NOTE: Here we apply a somewhat dirty trick.
170171 # We change the yaml object construction to be deep=True by default.
@@ -325,7 +326,7 @@ def resolve_references(yaml_stream, overrides=None, overrides_must_match=False):
325326
326327
327328def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
328- """A recursive function for resolving ``!ref`` and ``!copy `` tags.
329+ """A recursive function for resolving ``!ref``, ``!copy`` and ``!applyref `` tags.
329330
330331 Loads additional yaml files if ``!include:`` tags are used.
331332 Also throws an error if ``!PLACEHOLDER`` tags are encountered.
@@ -385,7 +386,7 @@ def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
385386
386387 # Include external yaml files
387388 elif tag_value.startswith("!include:"):
388- filename = tag_value[len("!include:") :]
389+ filename = tag_value[len("!include:"):]
389390
390391 # Update overrides with child keys
391392 if isinstance(current_node, dict):
@@ -404,9 +405,9 @@ def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
404405 current_node = ruamel_yaml.load(included_yaml)
405406
406407 # Get the return value of a function
407- elif tag_value.startswith("!apply :"):
408- function = tag_value[len("!apply :") :]
409- current_node = _apply_function (function, current_node)
408+ elif tag_value.startswith("!applyref :"):
409+ function = tag_value[len("!applyref :"):]
410+ current_node = _applyref_function (function, current_node)
410411
411412 # Return node after all resolution is done.
412413 return current_node
@@ -441,10 +442,12 @@ def _get_args(node):
441442 # Example:
442443 # seed: 1024
443444 # __set_seed: !apply:libs.support.utils.set_all_seed
444- # args: [!ref <seed>]
445- # kwargs: {deterministic: True}
446- if "args" in kwargs and "kwargs" in kwargs and len(kwargs) == 2:
447- return kwargs['args'], kwargs['kwargs']
445+ # _args:
446+ # - !ref <seed>
447+ # _kwargs:
448+ # deterministic: True
449+ if "_args" in kwargs and "_kwargs" in kwargs and len(kwargs) == 2:
450+ return kwargs['_args'], kwargs['_kwargs']
448451 else:
449452 return [], kwargs
450453 # SequenceNode
@@ -513,26 +516,7 @@ def _construct_module(loader, module_name, node):
513516 return module
514517
515518
516- # def _apply_function(loader, callable_string, node):
517- # callable_ = pydoc.locate(callable_string)
518- # if callable_ is None:
519- # raise ImportError("There is no such callable as %s" % callable_string)
520- #
521- # if not inspect.isroutine(callable_):
522- # raise ValueError(
523- # f"!apply:{callable_string} should be a callable, but is {callable_}"
524- # )
525- #
526- # try:
527- # args, kwargs = _load_node(loader, node)
528- # return callable_(*args, **kwargs)
529- # except TypeError as e:
530- # err_msg = "Invalid argument to callable %s" % callable_string
531- # e.args = (err_msg, *e.args)
532- # raise
533-
534-
535- def _apply_function(callable_string, node):
519+ def _apply_function(loader, callable_string, node):
536520 callable_ = pydoc.locate(callable_string)
537521 if callable_ is None:
538522 raise ImportError("There is no such callable as %s" % callable_string)
@@ -542,6 +526,25 @@ def _apply_function(callable_string, node):
542526 f"!apply:{callable_string} should be a callable, but is {callable_}"
543527 )
544528
529+ try:
530+ args, kwargs = _load_node(loader, node)
531+ return callable_(*args, **kwargs)
532+ except TypeError as e:
533+ err_msg = "Invalid argument to callable %s" % callable_string
534+ e.args = (err_msg, *e.args)
535+ raise
536+
537+
538+ def _applyref_function(callable_string, node):
539+ callable_ = pydoc.locate(callable_string)
540+ if callable_ is None:
541+ raise ImportError("There is no such callable as %s" % callable_string)
542+
543+ if not inspect.isroutine(callable_):
544+ raise ValueError(
545+ f"!applyref:{callable_string} should be a callable, but is {callable_}"
546+ )
547+
545548 try:
546549 args, kwargs = _get_args(node)
547550 out = callable_(*args, **kwargs)
0 commit comments