Skip to content

Commit 2e99e8e

Browse files
committed
Define the "!applyref" tag
1 parent 1e47fa6 commit 2e99e8e

3 files changed

Lines changed: 78 additions & 31 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
tmp.py
132+
tmp.yaml

hyperpyyaml/core.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Authors
44
* Peter Plantinga 2020
55
* Aku Rouhe 2020
6+
* Jianchen Li 2022
67
"""
78

89
import re
@@ -15,10 +16,10 @@
1516
import functools
1617
import collections
1718
import ruamel.yaml
18-
from ruamel.yaml.representer import RepresenterError
1919
import operator as op
2020
from io import StringIO
2121
from collections import OrderedDict
22+
from ruamel.yaml.representer import RepresenterError
2223

2324

2425
# NOTE: Empty dict as default parameter is fine here since overrides are never
@@ -164,7 +165,7 @@ def load_hyperpyyaml(
164165
yaml.Loader.add_multi_constructor("!new:", _construct_object)
165166
yaml.Loader.add_multi_constructor("!name:", _construct_name)
166167
yaml.Loader.add_multi_constructor("!module:", _construct_module)
167-
# yaml.Loader.add_multi_constructor("!apply:", _apply_function)
168+
yaml.Loader.add_multi_constructor("!apply:", _apply_function)
168169

169170
# NOTE: Here we apply a somewhat dirty trick.
170171
# We change the yaml object construction to be deep=True by default.
@@ -325,7 +326,7 @@ def resolve_references(yaml_stream, overrides=None, overrides_must_match=False):
325326

326327

327328
def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
328-
"""A recursive function for resolving ``!ref`` and ``!copy`` tags.
329+
"""A recursive function for resolving ``!ref``, ``!copy`` and ``!applyref`` tags.
329330

330331
Loads additional yaml files if ``!include:`` tags are used.
331332
Also throws an error if ``!PLACEHOLDER`` tags are encountered.
@@ -385,7 +386,7 @@ def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
385386

386387
# Include external yaml files
387388
elif tag_value.startswith("!include:"):
388-
filename = tag_value[len("!include:") :]
389+
filename = tag_value[len("!include:"):]
389390

390391
# Update overrides with child keys
391392
if isinstance(current_node, dict):
@@ -404,9 +405,9 @@ def _walk_tree_and_resolve(key, current_node, tree, overrides, file_path):
404405
current_node = ruamel_yaml.load(included_yaml)
405406

406407
# Get the return value of a function
407-
elif tag_value.startswith("!apply:"):
408-
function = tag_value[len("!apply:") :]
409-
current_node = _apply_function(function, current_node)
408+
elif tag_value.startswith("!applyref:"):
409+
function = tag_value[len("!applyref:"):]
410+
current_node = _applyref_function(function, current_node)
410411

411412
# Return node after all resolution is done.
412413
return current_node
@@ -441,10 +442,12 @@ def _get_args(node):
441442
# Example:
442443
# seed: 1024
443444
# __set_seed: !apply:libs.support.utils.set_all_seed
444-
# args: [!ref <seed>]
445-
# kwargs: {deterministic: True}
446-
if "args" in kwargs and "kwargs" in kwargs and len(kwargs) == 2:
447-
return kwargs['args'], kwargs['kwargs']
445+
# _args:
446+
# - !ref <seed>
447+
# _kwargs:
448+
# deterministic: True
449+
if "_args" in kwargs and "_kwargs" in kwargs and len(kwargs) == 2:
450+
return kwargs['_args'], kwargs['_kwargs']
448451
else:
449452
return [], kwargs
450453
# SequenceNode
@@ -513,26 +516,7 @@ def _construct_module(loader, module_name, node):
513516
return module
514517

515518

516-
# def _apply_function(loader, callable_string, node):
517-
# callable_ = pydoc.locate(callable_string)
518-
# if callable_ is None:
519-
# raise ImportError("There is no such callable as %s" % callable_string)
520-
#
521-
# if not inspect.isroutine(callable_):
522-
# raise ValueError(
523-
# f"!apply:{callable_string} should be a callable, but is {callable_}"
524-
# )
525-
#
526-
# try:
527-
# args, kwargs = _load_node(loader, node)
528-
# return callable_(*args, **kwargs)
529-
# except TypeError as e:
530-
# err_msg = "Invalid argument to callable %s" % callable_string
531-
# e.args = (err_msg, *e.args)
532-
# raise
533-
534-
535-
def _apply_function(callable_string, node):
519+
def _apply_function(loader, callable_string, node):
536520
callable_ = pydoc.locate(callable_string)
537521
if callable_ is None:
538522
raise ImportError("There is no such callable as %s" % callable_string)
@@ -542,6 +526,25 @@ def _apply_function(callable_string, node):
542526
f"!apply:{callable_string} should be a callable, but is {callable_}"
543527
)
544528

529+
try:
530+
args, kwargs = _load_node(loader, node)
531+
return callable_(*args, **kwargs)
532+
except TypeError as e:
533+
err_msg = "Invalid argument to callable %s" % callable_string
534+
e.args = (err_msg, *e.args)
535+
raise
536+
537+
538+
def _applyref_function(callable_string, node):
539+
callable_ = pydoc.locate(callable_string)
540+
if callable_ is None:
541+
raise ImportError("There is no such callable as %s" % callable_string)
542+
543+
if not inspect.isroutine(callable_):
544+
raise ValueError(
545+
f"!applyref:{callable_string} should be a callable, but is {callable_}"
546+
)
547+
545548
try:
546549
args, kwargs = _get_args(node)
547550
out = callable_(*args, **kwargs)

tests/test_hyperyaml.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,47 @@ def test_load_hyperpyyaml(tmpdir):
177177
assert things["c"].kwargs["thing1"]() == "a string"
178178
assert things["c"].specific_key() == "a string"
179179

180+
# Applyref tag
181+
yaml = """
182+
a: 1
183+
b: 2
184+
c: !applyref:sum [[!ref <a>, !ref <b>]]
185+
d: !ref <c>-<c>
186+
"""
187+
things = load_hyperpyyaml(yaml)
188+
assert things["d"] == 0
189+
190+
# Applyref method
191+
yaml = """
192+
a: "A STRING"
193+
common_kwargs:
194+
thing1: !ref <a.lower>
195+
thing2: 2
196+
c: !applyref:hyperpyyaml.TestThing.from_keys
197+
args:
198+
- 1
199+
- 2
200+
kwargs: !ref <common_kwargs>
201+
"""
202+
things = load_hyperpyyaml(yaml)
203+
assert things["c"][:12] == "<hyperpyyaml"
204+
205+
yaml = """
206+
a: "A STRING"
207+
common_kwargs:
208+
thing1: !ref <a.lower>
209+
thing2: 2
210+
c: !applyref:hyperpyyaml.TestThing.from_keys
211+
_args: []
212+
_kwargs:
213+
args:
214+
- 1
215+
- 2
216+
kwargs: !ref <common_kwargs>
217+
"""
218+
things = load_hyperpyyaml(yaml)
219+
assert things["c"][:12] == "<hyperpyyaml"
220+
180221
# Refattr:
181222
yaml = """
182223
thing1: "A string"

0 commit comments

Comments
 (0)