-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathant_maze.py
319 lines (261 loc) · 10.6 KB
/
ant_maze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os
from typing import Tuple
from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
from jax import numpy as jnp
import mujoco
import xml.etree.ElementTree as ET
# This is based on original Ant environment from Brax
# https://github.com/google/brax/blob/main/brax/envs/ant.py
# Maze creation adapted from: https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/locomotion/maze_env.py
# Cell markers used in the maze layouts below:
#   1 -> wall cell (make_maze adds a box geom for every cell equal to 1)
#   R -> cell collected by find_starts as a possible start position
#   G -> cell collected by find_goals as a possible goal position
#   0 -> cell that is neither a wall, a start nor a goal
RESET = R = "r"
GOAL = G = "g"
# U-shaped corridor; every non-wall cell other than the start is a goal.
U_MAZE = [
[1, 1, 1, 1, 1],
[1, R, G, G, 1],
[1, 1, 1, G, 1],
[1, G, G, G, 1],
[1, 1, 1, 1, 1],
]
# Same walls as U_MAZE, with goals only in the bottom corridor.
U_MAZE_EVAL = [
[1, 1, 1, 1, 1],
[1, R, 0, 0, 1],
[1, 1, 1, 0, 1],
[1, G, G, G, 1],
[1, 1, 1, 1, 1],
]
# 8x8 maze; every non-wall cell other than the start is a goal.
BIG_MAZE = [
[1, 1, 1, 1, 1, 1, 1, 1],
[1, R, G, 1, 1, G, G, 1],
[1, G, G, 1, G, G, G, 1],
[1, 1, G, G, G, 1, 1, 1],
[1, G, G, 1, G, G, G, 1],
[1, G, 1, G, G, 1, G, 1],
[1, G, G, G, 1, G, G, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
]
# Same walls as BIG_MAZE, with goals restricted to a subset of the free cells.
BIG_MAZE_EVAL = [
[1, 1, 1, 1, 1, 1, 1, 1],
[1, R, 0, 1, 1, G, G, 1],
[1, 0, 0, 1, 0, G, G, 1],
[1, 1, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 1, 0, 0, 0, 1],
[1, 0, 1, G, 0, 1, G, 1],
[1, 0, G, G, 1, G, G, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
]
# 9x12 maze; every non-wall cell other than the start is a goal.
HARDEST_MAZE = [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, R, G, G, G, 1, G, G, G, G, G, 1],
[1, G, 1, 1, G, 1, G, 1, G, 1, G, 1],
[1, G, G, G, G, G, G, 1, G, G, G, 1],
[1, G, 1, 1, 1, 1, G, 1, 1, 1, G, 1],
[1, G, G, 1, G, 1, G, G, G, G, G, 1],
[1, 1, G, 1, G, 1, G, 1, G, 1, 1, 1],
[1, G, G, 1, G, G, G, 1, G, G, G, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
# make_maze uses MAZE_HEIGHT / 2 * maze_size_scaling as both the z position
# and the z size of every wall block.
MAZE_HEIGHT = 0.5
def find_starts(structure, size_scaling):
    """Collect the scaled coordinates of every start cell in a maze layout.

    Args:
        structure: Maze layout as a list of rows. Only the first
            ``len(structure[0])`` columns of each row are inspected.
        size_scaling: Factor applied to both the row and the column index.

    Returns:
        A ``jnp`` array built from one ``[row * size_scaling,
        col * size_scaling]`` pair per cell equal to ``RESET``, in row-major
        order.
    """
    start_coords = [
        [row_idx * size_scaling, col_idx * size_scaling]
        for row_idx, row in enumerate(structure)
        # Column bound comes from the first row, as the layout is rectangular.
        for col_idx in range(len(structure[0]))
        if row[col_idx] == RESET
    ]
    return jnp.array(start_coords)
def find_goals(structure, size_scaling):
    """Collect the scaled coordinates of every goal cell in a maze layout.

    Args:
        structure: Maze layout as a list of rows. Only the first
            ``len(structure[0])`` columns of each row are inspected.
        size_scaling: Factor applied to both the row and the column index.

    Returns:
        A ``jnp`` array built from one ``[row * size_scaling,
        col * size_scaling]`` pair per cell equal to ``GOAL``, in row-major
        order.
    """
    goal_cells = (
        (row_idx, col_idx)
        for row_idx, row in enumerate(structure)
        # Column bound comes from the first row, as the layout is rectangular.
        for col_idx in range(len(structure[0]))
        if row[col_idx] == GOAL
    )
    scaled = [[r * size_scaling, c * size_scaling] for r, c in goal_cells]
    return jnp.array(scaled)
# Create an XML string with the maze, plus the lists of possible start and goal positions
def make_maze(maze_layout_name, maze_size_scaling):
    """Build the maze XML and collect its possible start and goal positions.

    Loads ``assets/ant_maze.xml`` from the directory containing this file and
    appends one box ``geom`` to its ``worldbody`` element for every layout
    cell equal to ``1``.

    Args:
        maze_layout_name: One of ``"u_maze"``, ``"u_maze_eval"``,
            ``"big_maze"``, ``"big_maze_eval"`` or ``"hardest_maze"``.
        maze_size_scaling: Factor that converts cell indices into world
            coordinates and scales the size of each wall block.

    Returns:
        A tuple ``(xml_string, possible_starts, possible_goals)`` where
        ``xml_string`` is the ``ET.tostring`` serialization of the modified
        XML root and the other two entries are the results of
        ``find_starts`` and ``find_goals`` for the chosen layout.

    Raises:
        ValueError: If ``maze_layout_name`` is not a known layout name.
    """
    layouts_by_name = {
        "u_maze": U_MAZE,
        "u_maze_eval": U_MAZE_EVAL,
        "big_maze": BIG_MAZE,
        "big_maze_eval": BIG_MAZE_EVAL,
        "hardest_maze": HARDEST_MAZE,
    }
    if maze_layout_name not in layouts_by_name:
        raise ValueError(f"Unknown maze layout: {maze_layout_name}")
    layout = layouts_by_name[maze_layout_name]

    module_dir = os.path.dirname(os.path.realpath(__file__))
    xml_path = os.path.join(module_dir, "assets", "ant_maze.xml")

    possible_starts = find_starts(layout, maze_size_scaling)
    possible_goals = find_goals(layout, maze_size_scaling)

    tree = ET.parse(xml_path)
    worldbody = tree.find(".//worldbody")

    # These block dimensions are the same for every wall cell.
    half_width = 0.5 * maze_size_scaling
    half_height = MAZE_HEIGHT / 2 * maze_size_scaling

    for row_idx, row in enumerate(layout):
        for col_idx in range(len(layout[0])):
            if row[col_idx] != 1:
                continue  # only wall cells become geoms
            block_pos = (row_idx * maze_size_scaling, col_idx * maze_size_scaling, half_height)
            block_size = (half_width, half_width, half_height)
            ET.SubElement(
                worldbody,
                "geom",
                name="block_%d_%d" % (row_idx, col_idx),
                pos="%f %f %f" % block_pos,
                size="%f %f %f" % block_size,
                type="box",
                material="",
                contype="1",
                conaffinity="1",
                rgba="0.7 0.5 0.3 1.0",
            )

    root = tree.getroot()
    return ET.tostring(root), possible_starts, possible_goals
class AntMaze(PipelineEnv):
    """Ant locomotion environment inside a block maze with a target to reach.

    The maze geometry is generated by ``make_maze``. On every reset the ant is
    placed at a start cell and the target at a goal cell, both drawn at random
    from the positions defined by the chosen layout.

    The observation is the concatenation of:
      * ``q`` without its last two entries (and without its first two entries
        when ``exclude_current_positions_from_observation`` is set),
      * ``qd`` without its last two entries,
      * the x/y position of the last body (the target).
    """

    def __init__(
        self,
        ctrl_cost_weight=0.5,
        use_contact_forces=False,
        contact_cost_weight=5e-4,
        healthy_reward=1.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(0.2, 1.0),
        contact_force_range=(-1.0, 1.0),
        reset_noise_scale=0.1,
        exclude_current_positions_from_observation=False,
        backend="generalized",
        maze_layout_name="u_maze",
        maze_size_scaling=4.0,
        dense_reward: bool = False,
        **kwargs,
    ):
        """Build the maze system and configure the reward.

        Args:
            ctrl_cost_weight: Weight of the squared-action penalty.
            use_contact_forces: Must be falsy; contact forces are not
                implemented.
            contact_cost_weight: Stored but unused (contact cost is 0).
            healthy_reward: Survival bonus used in the dense reward.
            terminate_when_unhealthy: If true, ``done`` is set once the torso
                height leaves ``healthy_z_range``.
            healthy_z_range: ``(min_z, max_z)`` bounds for the torso height.
            contact_force_range: Stored but unused.
            reset_noise_scale: Half-width of the uniform noise added to the
                initial ``q`` and scale of the normal noise used for ``qd``.
            exclude_current_positions_from_observation: If true, drop the
                first two ``q`` entries from the observation.
            backend: Physics backend name forwarded to ``PipelineEnv``.
            maze_layout_name: Layout name understood by ``make_maze``.
            maze_size_scaling: Scaling factor forwarded to ``make_maze``.
            dense_reward: If true, use the shaped reward; otherwise the reward
                is the binary success indicator.
            **kwargs: Forwarded to ``PipelineEnv.__init__``.

        Raises:
            ValueError: If ``maze_layout_name`` is unknown (from ``make_maze``).
            NotImplementedError: If ``use_contact_forces`` is truthy.
        """
        xml_string, possible_starts, possible_goals = make_maze(maze_layout_name, maze_size_scaling)
        sys = mjcf.loads(xml_string)
        self.possible_starts = possible_starts
        self.possible_goals = possible_goals

        n_frames = 5

        if backend in ["spring", "positional"]:
            sys = sys.tree_replace({"opt.timestep": 0.005})
            n_frames = 10

        if backend == "mjx":
            sys = sys.tree_replace(
                {
                    "opt.solver": mujoco.mjtSolver.mjSOL_NEWTON,
                    "opt.disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
                    "opt.iterations": 1,
                    "opt.ls_iterations": 4,
                }
            )

        if backend == "positional":
            # TODO: does the same actuator strength work as in spring
            sys = sys.replace(actuator=sys.actuator.replace(gear=200 * jnp.ones_like(sys.actuator.gear)))

        # A caller-supplied n_frames wins over the backend default.
        kwargs["n_frames"] = kwargs.get("n_frames", n_frames)

        super().__init__(sys=sys, backend=backend, **kwargs)

        self._ctrl_cost_weight = ctrl_cost_weight
        self._use_contact_forces = use_contact_forces
        self._contact_cost_weight = contact_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._contact_force_range = contact_force_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
        self.dense_reward = dense_reward
        # NOTE(review): state_dim and goal_indices are hard-coded; presumably
        # they describe the observation layout when positions are included,
        # and are not adjusted for
        # exclude_current_positions_from_observation=True - TODO confirm.
        self.state_dim = 29
        self.goal_indices = jnp.array([0, 1])
        self.goal_reach_thresh = 0.5

        if self._use_contact_forces:
            raise NotImplementedError("use_contact_forces not implemented.")

    def reset(self, rng: jax.Array) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        q = self.sys.init_q + jax.random.uniform(rng, (self.sys.q_size(),), minval=low, maxval=hi)
        qd = hi * jax.random.normal(rng1, (self.sys.qd_size(),))

        # set the start and target q, qd
        # The first two q entries hold the ant x/y; the last two q/qd entries
        # belong to the target, which starts at rest.
        start = self._random_start(rng2)
        q = q.at[:2].set(start)

        target = self._random_target(rng3)
        q = q.at[-2:].set(target)
        qd = qd.at[-2:].set(0)

        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)

        reward, done, zero = jnp.zeros(3)
        metrics = {
            "reward_forward": zero,
            "reward_survive": zero,
            "reward_ctrl": zero,
            "reward_contact": zero,
            "x_position": zero,
            "y_position": zero,
            "distance_from_origin": zero,
            "x_velocity": zero,
            "y_velocity": zero,
            "forward_reward": zero,
            "dist": zero,
            "success": zero,
            "success_easy": zero,
        }
        state = State(pipeline_state, obs, reward, done, metrics)
        return state

    def step(self, state: State, action: jax.Array) -> State:
        """Run one timestep of the environment's dynamics."""
        pipeline_state0 = state.pipeline_state
        pipeline_state = self.pipeline_step(pipeline_state0, action)

        # Finite-difference velocity of the first body over one env step.
        velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
        forward_reward = velocity[0]

        min_z, max_z = self._healthy_z_range
        is_healthy = jnp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
        is_healthy = jnp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)
        if self._terminate_when_unhealthy:
            healthy_reward = self._healthy_reward
        else:
            healthy_reward = self._healthy_reward * is_healthy
        ctrl_cost = self._ctrl_cost_weight * jnp.sum(jnp.square(action))
        contact_cost = 0.0

        # Distances are taken from the pipeline state rather than from the
        # observation: with exclude_current_positions_from_observation=True
        # the first two observation entries are no longer the ant x/y.
        old_dist = self._distance_to_target(pipeline_state0)
        dist = self._distance_to_target(pipeline_state)
        obs = self._get_obs(pipeline_state)

        # Positive when the ant moved closer to the target during this step.
        vel_to_target = (old_dist - dist) / self.dt
        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
        success_easy = jnp.array(dist < 2.0, dtype=float)

        if self.dense_reward:
            reward = 10 * vel_to_target + healthy_reward - ctrl_cost - contact_cost
        else:
            reward = success

        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

        state.metrics.update(
            reward_forward=forward_reward,
            reward_survive=healthy_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            x_position=pipeline_state.x.pos[0, 0],
            y_position=pipeline_state.x.pos[0, 1],
            distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]),
            x_velocity=velocity[0],
            y_velocity=velocity[1],
            forward_reward=forward_reward,
            dist=dist,
            success=success,
            success_easy=success_easy,
        )
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)

    def _distance_to_target(self, pipeline_state: base.State) -> jax.Array:
        """Euclidean x/y distance between the ant and the target.

        Uses the same quantities the observation exposes when positions are
        included: the first two ``q`` entries and the x/y position of the
        last body.
        """
        ant_xy = pipeline_state.q[:2]
        target_xy = pipeline_state.x.pos[-1][:2]
        return jnp.linalg.norm(ant_xy - target_xy)

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Observe ant body position and velocities."""
        # The last two q/qd entries belong to the target, not the ant.
        qpos = pipeline_state.q[:-2]
        qvel = pipeline_state.qd[:-2]

        target_pos = pipeline_state.x.pos[-1][:2]

        if self._exclude_current_positions_from_observation:
            qpos = qpos[2:]

        return jnp.concatenate([qpos] + [qvel] + [target_pos])

    def _random_target(self, rng: jax.Array) -> jax.Array:
        """Returns a random target location chosen from possibilities specified in the maze layout."""
        idx = jax.random.randint(rng, (1,), 0, len(self.possible_goals))
        return jnp.array(self.possible_goals[idx])[0]

    def _random_start(self, rng: jax.Array) -> jax.Array:
        """Returns a random start location chosen from possibilities specified in the maze layout."""
        idx = jax.random.randint(rng, (1,), 0, len(self.possible_starts))
        return jnp.array(self.possible_starts[idx])[0]