-
Notifications
You must be signed in to change notification settings - Fork 745
Expand file tree
/
Copy pathbricks.py
More file actions
710 lines (591 loc) · 26.8 KB
/
bricks.py
File metadata and controls
710 lines (591 loc) · 26.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
# Copyright 2019 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tasks involving assembly and/or disassembly of bricks."""
import collections
from absl import logging
from dm_control import composer
from dm_control.composer import initializers
from dm_control.composer import variation
from dm_control.composer.observation import observable
from dm_control.composer.variation import distributions
from dm_control.entities import props
from dm_control.manipulation.shared import arenas
from dm_control.manipulation.shared import cameras
from dm_control.manipulation.shared import constants
from dm_control.manipulation.shared import observations
from dm_control.manipulation.shared import registry
from dm_control.manipulation.shared import robots
from dm_control.manipulation.shared import tags
from dm_control.manipulation.shared import workspaces
from dm_control.mujoco.wrapper import mjbindings
from dm_control.utils import rewards
import numpy as np
mjlib = mjbindings.mjlib

# Geometry of a brick task: where props and the tool centre point may spawn,
# where the translucent goal-hint stack is drawn, and the arm's base offset.
_BrickWorkspace = collections.namedtuple(
    '_BrickWorkspace',
    ('prop_bbox', 'tcp_bbox', 'goal_hint_pos', 'goal_hint_quat', 'arm_offset'))

# Tiny vertical lift so a freshly placed prop does not start out intersecting
# the table surface.
_PROP_Z_OFFSET = 1e-6

_WORKSPACE = _BrickWorkspace(
    prop_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
                                     upper=(0.1, 0.1, _PROP_Z_OFFSET)),
    tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.15),
                                    upper=(0.1, 0.1, 0.4)),
    goal_hint_pos=(0.2, 0.1, 0.),
    goal_hint_quat=(-0.38268343, 0., 0., 0.92387953),
    arm_offset=robots.ARM_OFFSET,
)
# Alpha value of the visual goal hint representing the goal state for each task.
_HINT_ALPHA = 0.75
# Distance thresholds for the shaping rewards for getting the top brick close
# to the bottom brick, and for 'clicking' them together.
_CLOSE_THRESHOLD = 0.01
_CLICK_THRESHOLD = 0.001
# Sequence of colors for the brick(s).
_COLOR_VALUES, _COLOR_NAMES = list(
zip(
((1., 0., 0.), 'red'),
((0., 1., 0.), 'green'),
((0., 0., 1.), 'blue'),
((0., 1., 1.), 'cyan'),
((1., 0., 1.), 'magenta'),
((1., 1., 0.), 'yellow'),
))
class _Common(composer.Task):
  """Shared scaffolding for the brick stacking and reassembly tasks.

  Wires the hand onto the arm and the arm into the arena, creates one coloured
  Duplo brick per requested brick (each paired with a translucent, contactless
  goal-hint twin), and prepares the tool-centre-point initializer together
  with debugging sites that outline the spawn areas.
  """

  def __init__(self,
               arena,
               arm,
               hand,
               num_bricks,
               obs_settings,
               workspace,
               control_timestep):
    """Builds the components common to every brick task.

    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      num_bricks: The total number of bricks; must be between 2 and 6.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: A `_BrickWorkspace` instance.
      control_timestep: Float specifying the control timestep in seconds.

    Raises:
      ValueError: If `num_bricks` is not between 2 and 6.
    """
    if num_bricks < 2 or num_bricks > 6:
      raise ValueError('`num_bricks` must be between 2 and 6, got {}.'
                       .format(num_bricks))

    if num_bricks > 3:
      # MuJoCo's compiler picks buffer sizes that are too small once more than
      # three bricks are stacked, because every stacked pair produces many
      # contacts. These limits are enough for up to 6 stacked bricks.
      # TODO(b/78331644): It may be useful to log the size of `physics.model`
      #                   and `physics.data` after compilation to gauge the
      #                   impact of these changes on MuJoCo's memory footprint.
      arena.mjcf_model.size.nconmax = 400
      arena.mjcf_model.size.njmax = 1200

    self._arena = arena
    self._arm = arm
    self._hand = hand
    arm.attach(hand)
    arena.attach_offset(arm, offset=workspace.arm_offset)
    self.control_timestep = control_timestep

    # Custom front-facing camera observable.
    self._task_observables = cameras.add_camera_observables(
        arena, obs_settings, cameras.FRONT_CLOSE)

    prop_options = observations.make_options(
        obs_settings, observations.FREEPROP_OBSERVABLES)
    real_bricks, real_frames, hint_bricks = [], [], []
    # At most 6 bricks are allowed, so there is always a colour available.
    for rgb in _COLOR_VALUES[:num_bricks]:
      physical = props.Duplo(color=rgb, observable_options=prop_options)
      real_frames.append(arena.add_free_entity(physical))
      real_bricks.append(physical)
      # Translucent twin without contacts or observables; it only visualises
      # where this brick should end up.
      ghost = props.Duplo(color=rgb)
      _hintify(ghost, alpha=_HINT_ALPHA)
      arena.attach(ghost)
      hint_bricks.append(ghost)
    self._bricks = real_bricks
    self._brick_frames = real_frames
    self._goal_hint_bricks = hint_bricks

    # Pose at which the goal-hint stack is rendered.
    self._goal_hint_pos = workspace.goal_hint_pos
    self._goal_hint_quat = workspace.goal_hint_quat

    self._tcp_initializer = initializers.ToolCenterPointInitializer(
        self._hand, self._arm,
        position=distributions.Uniform(*workspace.tcp_bbox),
        quaternion=workspaces.DOWN_QUATERNION)

    # Visual-debugging sites outlining the spawn regions.
    worldbody = self.root_entity.mjcf_model.worldbody
    debug_regions = (
        (workspace.tcp_bbox, constants.GREEN, 'tcp_spawn_area'),
        (workspace.prop_bbox, constants.BLUE, 'prop_spawn_area'),
    )
    for bbox, rgba, site_name in debug_regions:
      workspaces.add_bbox_site(
          body=worldbody, lower=bbox.lower, upper=bbox.upper,
          rgba=rgba, name=site_name)

  @property
  def task_observables(self):
    return self._task_observables

  @property
  def root_entity(self):
    return self._arena

  @property
  def arm(self):
    return self._arm

  @property
  def hand(self):
    return self._hand
class Stack(_Common):
  """Build a stack of Duplo bricks."""

  def __init__(self,
               arena,
               arm,
               hand,
               num_bricks,
               target_height,
               moveable_base,
               randomize_order,
               obs_settings,
               workspace,
               control_timestep):
    """Initializes a new `Stack` task.

    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      num_bricks: The total number of bricks; must be between 2 and 6.
      target_height: The target number of bricks in the stack in order to get
        maximum reward. Must be between 2 and `num_bricks` (inclusive).
      moveable_base: Boolean specifying whether or not the bottom brick should
        be moveable.
      randomize_order: Boolean specifying whether to randomize the desired order
        of bricks in the stack at the start of each episode.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: A `_BrickWorkspace` instance.
      control_timestep: Float specifying the control timestep in seconds.

    Raises:
      ValueError: If `num_bricks` is not between 2 and 6, or if
        `target_height` is not between 2 and `num_bricks` (inclusive).
    """
    if not 2 <= target_height <= num_bricks:
      raise ValueError('`target_height` must be between 2 and {}, got {}.'
                       .format(num_bricks, target_height))
    super().__init__(
        arena=arena,
        arm=arm,
        hand=hand,
        num_bricks=num_bricks,
        obs_settings=obs_settings,
        workspace=workspace,
        control_timestep=control_timestep)
    self._moveable_base = moveable_base
    self._randomize_order = randomize_order
    self._target_height = target_height
    self._prop_bbox = workspace.prop_bbox

    # Indices (bottom to top) of the bricks that make up the goal stack.
    # Resampled (a random subset in random order) at the start of each episode
    # if `randomize_order` is True; otherwise always bricks 0..target_height-1.
    self._desired_order = np.arange(target_height)

    # In the random order case, expose the desired order to the agent as an
    # extra observable that uses the same settings as the `prop_pose`
    # observables.
    if randomize_order:
      desired_order_observable = observable.Generic(self._get_desired_order)
      desired_order_observable.configure(**obs_settings.prop_pose._asdict())
      self._task_observables['desired_order'] = desired_order_observable

  def _get_desired_order(self, physics):
    """Returns the desired brick order as a float array (observable callback)."""
    del physics  # Unused
    return self._desired_order.astype(np.double)

  def initialize_episode_mjcf(self, random_state):
    if self._randomize_order:
      self._desired_order = random_state.choice(
          len(self._bricks), size=self._target_height, replace=False)
      logging.info('Desired stack order (from bottom to top): [%s]',
                   ' '.join(_COLOR_NAMES[i] for i in self._desired_order))

    # If the base of the stack should be fixed, remove the freejoint for the
    # first brick (and ensure that all the others have freejoints).
    fixed_indices = [] if self._moveable_base else [self._desired_order[0]]
    _add_or_remove_freejoints(attachment_frames=self._brick_frames,
                              fixed_indices=fixed_indices)

    # We need to define the prop initializer for the bricks here rather than in
    # the `__init__`, since `PropPlacer` looks for freejoints on instantiation.
    self._brick_placer = initializers.PropPlacer(
        props=self._bricks,
        position=distributions.Uniform(*self._prop_bbox),
        quaternion=workspaces.uniform_z_rotation,
        settle_physics=True)

  def initialize_episode(self, physics, random_state):
    self._brick_placer(physics, random_state)
    self._hand.set_grasp(physics, close_factors=random_state.uniform())
    self._tcp_initializer(physics, random_state)
    # Arrange the goal hint bricks in the desired stack order.
    _build_stack(physics,
                 bricks=self._goal_hint_bricks,
                 base_pos=self._goal_hint_pos,
                 base_quat=self._goal_hint_quat,
                 order=self._desired_order,
                 random_state=random_state)

  def get_reward(self, physics):
    # One reward term per consecutive (bottom, top) pair in the desired order.
    pairs = list(zip(self._desired_order[:-1], self._desired_order[1:]))
    pairwise_rewards = _get_pairwise_stacking_rewards(
        physics=physics, bricks=self._bricks, pairs=pairs)
    # The final reward is an average over the pairwise rewards.
    return np.mean(pairwise_rewards)
class Reassemble(_Common):
  """Disassemble a stack of Duplo bricks and reassemble it in another order."""

  def __init__(self,
               arena,
               arm,
               hand,
               num_bricks,
               randomize_initial_order,
               randomize_desired_order,
               obs_settings,
               workspace,
               control_timestep):
    """Initializes a new `Reassemble` task.

    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      num_bricks: The total number of bricks; must be between 2 and 6.
      randomize_initial_order: Boolean specifying whether to randomize the
        initial order of bricks in the stack at the start of each episode.
      randomize_desired_order: Boolean specifying whether to independently
        randomize the desired order of bricks in the stack at the start of each
        episode. By default the desired order will be the reverse of the initial
        order, with the exception of the base brick which is always the same as
        in the initial order since it is welded in place.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: A `_BrickWorkspace` instance.
      control_timestep: Float specifying the control timestep in seconds.

    Raises:
      ValueError: If `num_bricks` is not between 2 and 6.
    """
    super().__init__(
        arena=arena,
        arm=arm,
        hand=hand,
        num_bricks=num_bricks,
        obs_settings=obs_settings,
        workspace=workspace,
        control_timestep=control_timestep)
    self._randomize_initial_order = randomize_initial_order
    self._randomize_desired_order = randomize_desired_order

    # Bottom-to-top brick indices of the initial stack; shuffled in place each
    # episode when `randomize_initial_order` is True.
    self._initial_order = np.arange(num_bricks)
    # Default goal: same base brick, remaining bricks reversed. Recomputed (and
    # optionally shuffled above the base) every episode.
    self._desired_order = np.concatenate(
        (self._initial_order[:1], self._initial_order[:0:-1]))

    # When the goal order is random, tell the agent about it through an extra
    # observable configured like the `prop_pose` observables.
    if randomize_desired_order:
      order_obs = observable.Generic(self._get_desired_order)
      order_obs.configure(**obs_settings.prop_pose._asdict())
      self._task_observables['desired_order'] = order_obs

    # Distributions of positions and orientations for the base of the stack.
    self._base_pos = distributions.Uniform(*workspace.prop_bbox)
    self._base_quat = workspaces.uniform_z_rotation

  def _get_desired_order(self, physics):
    """Observable callback returning the goal order as a float array."""
    del physics  # Unused
    return self._desired_order.astype(np.double)

  def initialize_episode_mjcf(self, random_state):
    if self._randomize_initial_order:
      random_state.shuffle(self._initial_order)

    goal = self._desired_order
    # The base brick gets welded to the table, so it has to stay put; the rest
    # default to the reverse of their initial order.
    goal[0] = self._initial_order[0]
    goal[1:] = self._initial_order[:0:-1]
    if self._randomize_desired_order:
      random_state.shuffle(goal[1:])
    logging.info('Desired stack order (from bottom to top): [%s]',
                 ' '.join(_COLOR_NAMES[i] for i in goal))

    # Weld the bottom brick of the initial stack by removing its freejoint.
    base_index = self._initial_order[0]
    _add_or_remove_freejoints(attachment_frames=self._brick_frames,
                              fixed_indices=[base_index])

  def initialize_episode(self, physics, random_state):
    # First the real stack in its initial order, then the goal hint stack in
    # the desired order.
    stack_specs = (
        (self._bricks, self._base_pos, self._base_quat, self._initial_order),
        (self._goal_hint_bricks, self._goal_hint_pos, self._goal_hint_quat,
         self._desired_order),
    )
    for bricks, pos, quat, order in stack_specs:
      _build_stack(physics,
                   bricks=bricks,
                   base_pos=pos,
                   base_quat=quat,
                   order=order,
                   random_state=random_state)
    self._hand.set_grasp(physics, close_factors=random_state.uniform())
    self._tcp_initializer(physics, random_state)

  def get_reward(self, physics):
    goal = self._desired_order
    pairs = list(zip(goal[:-1], goal[1:]))
    # The coarse "get close" shaping term is disabled (`close_coef=0.`): it
    # would strongly discourage the agent from breaking up the initial stack.
    per_pair = _get_pairwise_stacking_rewards(
        physics=physics,
        bricks=self._bricks,
        pairs=pairs,
        close_coef=0.)
    # Average the per-pair rewards into a single scalar.
    return np.mean(per_pair)
def _distance(pos1, pos2):
diff = pos1 - pos2
return sum(np.sqrt((diff * diff).sum(1)))
def _min_stud_to_hole_distance(physics, bottom_brick, top_brick):
  """Returns how far `top_brick`'s holes are from `bottom_brick`'s studs.

  Two opposite corners are compared: the top-left and bottom-right studs of
  `bottom_brick` against the top-left and bottom-right holes of `top_brick`.
  Because a brick looks the same after a 180-degree turn, both the direct
  pairing and the swapped pairing are scored, and the smaller is returned.
  """
  corners = ([0, -1], [0, -1])
  studs = physics.bind(bottom_brick.studs[corners]).xpos
  holes = physics.bind(top_brick.holes[corners]).xpos
  direct = _distance(studs, holes)
  swapped = _distance(studs[::-1], holes)
  return min(direct, swapped)
def _get_pairwise_stacking_rewards(physics, bricks, pairs, close_coef=0.1):
  """Returns a vector of shaping reward components based on pairwise distances.

  Args:
    physics: An `mjcf.Physics` instance.
    bricks: A list of `composer.Entity` instances corresponding to bricks.
    pairs: A list of `(bottom_idx, top_idx)` tuples specifying which pairs of
      bricks should be measured.
    close_coef: Float specifying the relative weight given to the coarse-
      tolerance shaping component for getting the bricks close to one another
      (as opposed to the fine-tolerance component for clicking them together).

  Returns:
    A numpy array of size `len(pairs)` containing values in (0, 1], where
    1 corresponds to a stacked pair of bricks.
  """
  distances = np.hstack([
      _min_stud_to_hole_distance(physics, bricks[lower], bricks[upper])
      for lower, upper in pairs
  ])
  # Broad term: rewards bringing the holes near the studs.
  near = rewards.tolerance(distances,
                           bounds=(0, _CLOSE_THRESHOLD),
                           margin=_CLOSE_THRESHOLD * 10)
  # Narrow term: rewards actually clicking the bricks together.
  snapped = rewards.tolerance(distances,
                              bounds=(0, _CLICK_THRESHOLD),
                              margin=_CLICK_THRESHOLD)
  # Per-pair weighted blend of the two terms.
  return np.average([near, snapped], weights=[close_coef, 1.], axis=0)
def _build_stack(physics, bricks, base_pos, base_quat, order, random_state):
  """Builds a stack of bricks.

  The brick `bricks[order[0]]` is placed at `base_pos` / `base_quat`. Each
  subsequent brick in `order` is then given the orientation of the brick below
  it, rotated 180 degrees about z with probability 0.5. It is then translated
  so that one of its corner holes coincides with the top-left stud of the brick
  below.

  Args:
    physics: Instance of `mjcf.Physics`.
    bricks: Sequence of `composer.Entity` instances corresponding to bricks.
    base_pos: Position of the base brick in the stack; either a (3,) numpy array
      or a `variation.Variation` that yields such arrays.
    base_quat: Quaternion of the base brick in the stack; either a (4,) numpy
      array or a `variation.Variation` that yields such arrays.
    order: Sequence of indices specifying the order in which to stack the
      bricks.
    random_state: An `np.random.RandomState` instance.
  """
  base_pos = variation.evaluate(base_pos, random_state=random_state)
  base_quat = variation.evaluate(base_quat, random_state=random_state)
  bricks[order[0]].set_pose(physics, position=base_pos, quaternion=base_quat)
  for bottom_idx, top_idx in zip(order[:-1], order[1:]):
    bottom = bricks[bottom_idx]
    top = bricks[top_idx]
    stud_pos = physics.bind(bottom.studs[0, 0]).xpos
    _, quat = bottom.get_pose(physics)
    # The reward function treats top left -> top left and top left -> bottom
    # right configurations as identical, so the orientations of the bricks are
    # randomized so that 50% of the time the top brick is rotated 180 degrees
    # relative to the brick below.
    if random_state.rand() < 0.5:
      quat = quat.copy()
      axis = np.array([0., 0., 1.])
      angle = np.pi
      mjlib.mju_quatIntegrate(quat, axis, angle)
      hole_idx = (-1, -1)
    else:
      hole_idx = (0, 0)
    top.set_pose(physics, quaternion=quat)
    # Set the position of the top brick so that its holes line up with the
    # studs of the brick below. `xpos` is a world-frame position that includes
    # wherever the top brick currently happens to be (e.g. left over from a
    # previous episode), so subtract the brick's current position to get the
    # hole's displacement relative to the brick's own origin. Without this the
    # placement is only correct when the brick starts at the world origin.
    current_pos, _ = top.get_pose(physics)
    offset = physics.bind(top.holes[hole_idx]).xpos - current_pos
    top_pos = stud_pos - offset
    top.set_pose(physics, position=top_pos)
def _add_or_remove_freejoints(attachment_frames, fixed_indices):
"""Adds or removes freejoints from props.
Args:
attachment_frames: A list of `mjcf.Elements` corresponding to attachment
frames.
fixed_indices: A list of indices of attachment frames that should be fixed
to the world (i.e. have their freejoints removed). Freejoints will be
added to all other elements in `attachment_frames` if they do not already
possess them.
"""
for i, frame in enumerate(attachment_frames):
if i in fixed_indices:
if frame.freejoint:
frame.freejoint.remove()
elif not frame.freejoint:
frame.add('freejoint')
def _replace_alpha(rgba, alpha=0.3):
new_rgba = rgba.copy()
new_rgba[3] = alpha
return new_rgba
def _hintify(entity, alpha=None):
  """Turns `entity` into a purely visual "hint" in place.

  Every body in the entity and its sub-entities is marked as a mocap body, and
  every geom has `contype`/`conaffinity` set to 0 so that it takes part in no
  collisions. When `alpha` is given, it replaces the alpha channel of the
  top-level default geom RGBA and of every geom that sets its own RGBA, which
  makes the entity render as translucent.

  Args:
    entity: A `composer.Entity`, modified in place.
    alpha: Optional float between 0 and 1, used to override the alpha values for
      all of the geoms in this entity.
  """
  override_alpha = alpha is not None
  for subentity in entity.iter_entities():
    model = subentity.mjcf_model
    # TODO(b/112084359): This assumes that all geoms either define explicit RGBA
    #                    values, or inherit from the top-level default. It will
    #                    not correctly handle more complicated hierarchies of
    #                    default classes.
    if override_alpha:
      default_geom = model.default.geom
      if default_geom is not None and default_geom.rgba is not None:
        default_geom.rgba = _replace_alpha(default_geom.rgba, alpha=alpha)
    for body in model.find_all('body'):
      body.mocap = 'true'
    for geom in model.find_all('geom'):
      if override_alpha and geom.rgba is not None:
        geom.rgba = _replace_alpha(geom.rgba, alpha=alpha)
      geom.contype = 0
      geom.conaffinity = 0
def _stack(obs_settings, num_bricks, moveable_base, randomize_order,
           target_height=None):
  """Configure and instantiate a Stack task.

  Args:
    obs_settings: `observations.ObservationSettings` instance.
    num_bricks: The total number of bricks; must be between 2 and 6.
    moveable_base: Boolean specifying whether or not the bottom brick should
      be moveable.
    randomize_order: Boolean specifying whether to randomize the desired order
      of bricks in the stack at the start of each episode.
    target_height: The target number of bricks in the stack in order to get
      maximum reward. Must be between 2 and `num_bricks`. Defaults to
      `num_bricks`.

  Returns:
    An instance of `Stack`.
  """
  height = num_bricks if target_height is None else target_height
  arena = arenas.Standard()
  arm = robots.make_arm(obs_settings=obs_settings)
  hand = robots.make_hand(obs_settings=obs_settings)
  task_options = dict(
      num_bricks=num_bricks,
      target_height=height,
      moveable_base=moveable_base,
      randomize_order=randomize_order,
      obs_settings=obs_settings,
      workspace=_WORKSPACE,
      control_timestep=constants.CONTROL_TIMESTEP,
  )
  return Stack(arena=arena, arm=arm, hand=hand, **task_options)
@registry.add(tags.FEATURES)
def stack_2_bricks_features():
  """Stack 2 bricks on a fixed base in a fixed order; feature observations."""
  return _stack(num_bricks=2, obs_settings=observations.PERFECT_FEATURES,
                randomize_order=False, moveable_base=False)


@registry.add(tags.VISION)
def stack_2_bricks_vision():
  """Stack 2 bricks on a fixed base in a fixed order; vision observations."""
  return _stack(num_bricks=2, obs_settings=observations.VISION,
                randomize_order=False, moveable_base=False)


@registry.add(tags.FEATURES)
def stack_2_bricks_moveable_base_features():
  """Stack 2 bricks with a moveable base brick; feature observations."""
  return _stack(num_bricks=2, obs_settings=observations.PERFECT_FEATURES,
                randomize_order=False, moveable_base=True)


@registry.add(tags.VISION)
def stack_2_bricks_moveable_base_vision():
  """Stack 2 bricks with a moveable base brick; vision observations."""
  return _stack(num_bricks=2, obs_settings=observations.VISION,
                randomize_order=False, moveable_base=True)


@registry.add(tags.FEATURES)
def stack_3_bricks_features():
  """Stack 3 bricks on a fixed base in a fixed order; feature observations."""
  return _stack(num_bricks=3, obs_settings=observations.PERFECT_FEATURES,
                randomize_order=False, moveable_base=False)


@registry.add(tags.VISION)
def stack_3_bricks_vision():
  """Stack 3 bricks on a fixed base in a fixed order; vision observations."""
  return _stack(num_bricks=3, obs_settings=observations.VISION,
                randomize_order=False, moveable_base=False)


@registry.add(tags.FEATURES)
def stack_3_bricks_random_order_features():
  """Stack 3 bricks in a per-episode random order; feature observations."""
  return _stack(num_bricks=3, obs_settings=observations.PERFECT_FEATURES,
                randomize_order=True, moveable_base=False)


@registry.add(tags.FEATURES)
def stack_2_of_3_bricks_random_order_features():
  """Stack a random 2 of 3 bricks in random order; feature observations."""
  return _stack(num_bricks=3, obs_settings=observations.PERFECT_FEATURES,
                randomize_order=True, moveable_base=False, target_height=2)


@registry.add(tags.VISION)
def stack_2_of_3_bricks_random_order_vision():
  """Stack a random 2 of 3 bricks in random order; vision observations."""
  return _stack(num_bricks=3, obs_settings=observations.VISION,
                randomize_order=True, moveable_base=False, target_height=2)
def _reassemble(obs_settings, num_bricks, randomize_initial_order,
                randomize_desired_order):
  """Configure and instantiate a `Reassemble` task.

  Args:
    obs_settings: `observations.ObservationSettings` instance.
    num_bricks: The total number of bricks; must be between 2 and 6.
    randomize_initial_order: Boolean specifying whether to randomize the
      initial order of bricks in the stack at the start of each episode.
    randomize_desired_order: Boolean specifying whether to independently
      randomize the desired order of bricks in the stack at the start of each
      episode. By default the desired order will be the reverse of the initial
      order, with the exception of the base brick which is always the same as
      in the initial order since it is welded in place.

  Returns:
    An instance of `Reassemble`.
  """
  arena = arenas.Standard()
  arm = robots.make_arm(obs_settings=obs_settings)
  hand = robots.make_hand(obs_settings=obs_settings)
  task_options = dict(
      num_bricks=num_bricks,
      randomize_initial_order=randomize_initial_order,
      randomize_desired_order=randomize_desired_order,
      obs_settings=obs_settings,
      workspace=_WORKSPACE,
      control_timestep=constants.CONTROL_TIMESTEP,
  )
  return Reassemble(arena=arena, arm=arm, hand=hand, **task_options)
@registry.add(tags.FEATURES)
def reassemble_3_bricks_fixed_order_features():
  """Reassemble 3 bricks with fixed initial/desired orders; features."""
  return _reassemble(num_bricks=3, obs_settings=observations.PERFECT_FEATURES,
                     randomize_desired_order=False,
                     randomize_initial_order=False)


@registry.add(tags.VISION)
def reassemble_3_bricks_fixed_order_vision():
  """Reassemble 3 bricks with fixed initial/desired orders; vision."""
  return _reassemble(num_bricks=3, obs_settings=observations.VISION,
                     randomize_desired_order=False,
                     randomize_initial_order=False)


@registry.add(tags.FEATURES)
def reassemble_5_bricks_random_order_features():
  """Reassemble 5 bricks with random initial/desired orders; features."""
  return _reassemble(num_bricks=5, obs_settings=observations.PERFECT_FEATURES,
                     randomize_desired_order=True,
                     randomize_initial_order=True)


@registry.add(tags.VISION)
def reassemble_5_bricks_random_order_vision():
  """Reassemble 5 bricks with random initial/desired orders; vision."""
  return _reassemble(num_bricks=5, obs_settings=observations.VISION,
                     randomize_desired_order=True,
                     randomize_initial_order=True)