Skip to content

Commit 5100997

Browse files
committed
Add dm_control.composer, a Python framework for constructing MuJoCo-based RL environments
Note: this is a dummy revision ID that is sufficiently new to include all the changes here. This change doesn't correspond to any single revision in Piper. PiperOrigin-RevId: 216541434
1 parent 62ae972 commit 5100997

35 files changed

Lines changed: 4682 additions & 0 deletions

dm_control/composer/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Module containing abstract base classes for Composer environments."""
17+
18+
from dm_control.composer.arena import Arena
19+
from dm_control.composer.constants import * # pylint: disable=wildcard-import
20+
from dm_control.composer.define import cached_property
21+
from dm_control.composer.define import observable
22+
from dm_control.composer.entity import Entity
23+
from dm_control.composer.entity import FreePropObservableMixin
24+
from dm_control.composer.entity import ModelWrapperEntity
25+
from dm_control.composer.entity import Observables
26+
from dm_control.composer.environment import Environment
27+
from dm_control.composer.environment import HOOK_NAMES
28+
from dm_control.composer.initializer import Initializer
29+
from dm_control.composer.robot import Robot
30+
from dm_control.composer.task import NullTask
31+
from dm_control.composer.task import Task

dm_control/composer/arena.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""The base empty arena that defines global settings for Composer."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
from dm_control import mjcf
25+
from dm_control.composer import entity as entity_module
26+
27+
_ARENA_XML_PATH = os.path.join(os.path.dirname(__file__), 'arena.xml')
28+
29+
30+
class Arena(entity_module.Entity):
31+
"""The base empty arena that defines global settings for Composer."""
32+
33+
def _build(self, name=None):
34+
"""Initializes this arena.
35+
36+
Args:
37+
name: (optional) A string, the name of this arena. If `None`, use the
38+
model name defined in the MJCF file.
39+
"""
40+
self._mjcf_root = mjcf.from_path(_ARENA_XML_PATH)
41+
if name:
42+
self._mjcf_root.model = name
43+
44+
def add_free_entity(self, entity):
45+
"""Includes an entity in the arena as a free-moving body."""
46+
frame = self.attach(entity)
47+
frame.add('freejoint')
48+
return frame
49+
50+
@property
51+
def mjcf_model(self):
52+
return self._mjcf_root

dm_control/composer/arena.xml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<mujoco model="base">
2+
<compiler coordinate='local' angle='radian' eulerseq='xyz' boundmass='1e-5' boundinertia='1e-11'/>
3+
4+
<option cone='elliptic' noslip_iterations='5' noslip_tolerance='0' timestep='0.002'/>
5+
6+
<visual>
7+
<map znear='0.01'/>
8+
<headlight diffuse='.6 .6 .6' ambient='.3 .3 .3' specular='0 0 0'/>
9+
<scale forcewidth='0.01' contactwidth='0.06' contactheight='0.01' jointwidth='.01' framewidth='.01' framelength='.3'/>
10+
</visual>
11+
</mujoco>

dm_control/composer/constants.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Module defining constant values for Composer."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
SENSOR_SITES_GROUP = 4

dm_control/composer/define.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Decorators for Entity methods returning elements and observables."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import abc
23+
import threading
24+
25+
26+
class cached_property(property): # pylint: disable=invalid-name
27+
"""A property that is evaluated only once per object instance."""
28+
29+
def __init__(self, func, doc=None):
30+
super(cached_property, self).__init__(fget=func, doc=doc)
31+
self.lock = threading.RLock()
32+
33+
def __get__(self, obj, cls):
34+
if obj is None:
35+
return self
36+
name = self.fget.__name__
37+
obj_dict = obj.__dict__
38+
try:
39+
# Try returning a precomputed value without locking first.
40+
# Profiling shows that the lock takes up a non-trivial amount of time.
41+
return obj_dict[name]
42+
except KeyError:
43+
# The value hasn't been computed, now we have to lock.
44+
with self.lock:
45+
try:
46+
# Check again whether another thread has already computed the value.
47+
return obj_dict[name]
48+
except KeyError:
49+
# Otherwise call the function, cache the result, and return it
50+
return obj_dict.setdefault(name, self.fget(obj))
51+
52+
53+
# A decorator for base.Observables methods returning an observable. This
54+
# decorator should be used by abstract base classes to indicate sub-classes need
55+
# to implement a corresponding @observavble annotated method.
56+
abstract_observable = abc.abstractproperty # pylint: disable=invalid-name
57+
58+
59+
class observable(cached_property): # pylint: disable=invalid-name
60+
"""A decorator for base.Observables methods returning an observable.
61+
62+
The body of the decorated function is evaluated at Entity construction time
63+
and the observable is cached.
64+
"""
65+
pass

0 commit comments

Comments
 (0)