Skip to content

Commit 6cb8f77

Browse files
alimuldalsaran-t
authored andcommitted
Add support for evaluating arbitrary nested structures of "variation" callables
This introduces a dependency on `dm-tree`. I also corrected the name of the `dm-env` package in `setup.py` and `requirements.txt` (the module is called `dm_env`, the package is called `dm-env`). PiperOrigin-RevId: 280667371 Change-Id: I5d23f354482fa9e8cad84e88bba54f20d25cf3d3
1 parent 6ffe864 commit 6cb8f77

3 files changed

Lines changed: 18 additions & 11 deletions

File tree

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 The dm_control Authors.
1+
# Copyright 2019 The dm_control Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,22 +13,27 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
"""Utility function for evaluating callables or constants."""
16+
"""Utilities for handling nested structures of callables or constants."""
1717

1818
from __future__ import absolute_import
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import tree
2223

23-
def evaluate(x, *args, **kwargs):
24-
"""Evaluates a callable or constant value.
24+
25+
def evaluate(structure, *args, **kwargs):
26+
"""Evaluates a arbitrarily nested structure of callables or constant values.
2527
2628
Args:
27-
x: Either a callable or a constant value.
28-
*args: Positional arguments passed to `x` if `x` is callable.
29-
**kwargs: Keyword arguments passed to `x` if `x` is callable.
29+
structure: An arbitrarily nested structure of callables or constant values.
30+
By "structures", we mean lists, tuples, namedtuples, or dicts.
31+
*args: Positional arguments passed to each callable in `structure`.
32+
**kwargs: Keyword arguments passed to each callable in `structure.
3033
3134
Returns:
32-
Either the result of calling `x` if `x` is callable or else `x`.
35+
The same nested structure, with each callable replaced by the value returned
36+
by calling it.
3337
"""
34-
return x(*args, **kwargs) if callable(x) else x
38+
return tree.map_structure(
39+
lambda x: x(*args, **kwargs) if callable(x) else x, structure)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
absl-py==0.7.0
2-
dm_env
2+
dm-env
3+
dm-tree
34
enum34==1.1.6
45
future==0.17.1
56
futures==3.1.1

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def find_data_files(package_dir, patterns):
172172
install_requires=[
173173
'absl-py>=0.7.0',
174174
'enum34; python_version < "3.4"',
175-
'dm_env',
175+
'dm-env',
176+
'dm-tree',
176177
'future',
177178
'futures; python_version == "2.7"',
178179
'glfw',

0 commit comments

Comments
 (0)