@@ -37,7 +37,7 @@ const auto rollout_doc = R"(
37
37
Roll out open-loop trajectories from initial states, get resulting states and sensor values.
38
38
39
39
input arguments (required):
40
- model instance of MjModel
40
+ model list of MjModel instances of length nroll
41
41
data associated instance of MjData
42
42
nstep integer, number of steps to be taken for each trajectory
43
43
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
@@ -54,18 +54,18 @@ Roll out open-loop trajectories from initial states, get resulting states and se
54
54
// C-style rollout function, assumes all arguments are valid
55
55
// all input fields of d are initialised, contents at call time do not matter
56
56
// after returning, d will contain the last step of the last rollout
57
- void _unsafe_rollout (const mjModel* m, mjData* d, int nroll, int nstep, unsigned int control_spec,
57
+ void _unsafe_rollout (const mjModel** m, mjData* d, int nroll, int nstep, unsigned int control_spec,
58
58
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
59
59
mjtNum* state, mjtNum* sensordata) {
60
60
// sizes
61
- int nstate = mj_stateSize (m, mjSTATE_FULLPHYSICS);
62
- int ncontrol = mj_stateSize (m, control_spec);
63
- int nv = m->nv , nbody = m->nbody , neq = m->neq ;
64
- int nsensordata = m->nsensordata ;
61
+ int nstate = mj_stateSize (m[ 0 ] , mjSTATE_FULLPHYSICS);
62
+ int ncontrol = mj_stateSize (m[ 0 ] , control_spec);
63
+ int nv = m[ 0 ] ->nv , nbody = m[ 0 ] ->nbody , neq = m[ 0 ] ->neq ;
64
+ int nsensordata = m[ 0 ] ->nsensordata ;
65
65
66
66
// clear user inputs if unspecified
67
67
if (!(control_spec & mjSTATE_CTRL)) {
68
- mju_zero (d->ctrl , m->nu );
68
+ mju_zero (d->ctrl , m[ 0 ] ->nu );
69
69
}
70
70
if (!(control_spec & mjSTATE_QFRC_APPLIED)) {
71
71
mju_zero (d->qfrc_applied , nv);
@@ -75,26 +75,26 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
75
75
}
76
76
if (!(control_spec & mjSTATE_MOCAP_POS)) {
77
77
for (int i = 0 ; i < nbody; i++) {
78
- int id = m->body_mocapid [i];
79
- if (id >= 0 ) mju_copy3 (d->mocap_pos +3 *id, m->body_pos +3 *i);
78
+ int id = m[ 0 ] ->body_mocapid [i];
79
+ if (id >= 0 ) mju_copy3 (d->mocap_pos +3 *id, m[ 0 ] ->body_pos +3 *i);
80
80
}
81
81
}
82
82
if (!(control_spec & mjSTATE_MOCAP_QUAT)) {
83
83
for (int i = 0 ; i < nbody; i++) {
84
- int id = m->body_mocapid [i];
85
- if (id >= 0 ) mju_copy4 (d->mocap_quat +4 *id, m->body_quat +4 *i);
84
+ int id = m[ 0 ] ->body_mocapid [i];
85
+ if (id >= 0 ) mju_copy4 (d->mocap_quat +4 *id, m[ 0 ] ->body_quat +4 *i);
86
86
}
87
87
}
88
88
if (!(control_spec & mjSTATE_EQ_ACTIVE)) {
89
89
for (int i = 0 ; i < neq; i++) {
90
- d->eq_active [i] = m->eq_active0 [i];
90
+ d->eq_active [i] = m[ 0 ] ->eq_active0 [i];
91
91
}
92
92
}
93
93
94
94
// loop over rollouts
95
95
for (int r = 0 ; r < nroll; r++) {
96
96
// set initial state
97
- mj_setState (m, d, state0 + r*nstate, mjSTATE_FULLPHYSICS);
97
+ mj_setState (m[r] , d, state0 + r*nstate, mjSTATE_FULLPHYSICS);
98
98
99
99
// set warmstart accelerations
100
100
if (warmstart0) {
@@ -124,7 +124,7 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
124
124
for (; t < nstep; t++) {
125
125
int step = r*nstep + t;
126
126
if (state) {
127
- mj_getState (m, d, state + step*nstate, mjSTATE_FULLPHYSICS);
127
+ mj_getState (m[r] , d, state + step*nstate, mjSTATE_FULLPHYSICS);
128
128
}
129
129
if (sensordata) {
130
130
mju_copy (sensordata + step*nsensordata, d->sensordata , nsensordata);
@@ -137,15 +137,15 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
137
137
138
138
// controls
139
139
if (control) {
140
- mj_setState (m, d, control + step*ncontrol, control_spec);
140
+ mj_setState (m[r] , d, control + step*ncontrol, control_spec);
141
141
}
142
142
143
143
// step
144
- mj_step (m, d);
144
+ mj_step (m[r] , d);
145
145
146
146
// copy out new state
147
147
if (state) {
148
- mj_getState (m, d, state + step*nstate, mjSTATE_FULLPHYSICS);
148
+ mj_getState (m[r] , d, state + step*nstate, mjSTATE_FULLPHYSICS);
149
149
}
150
150
151
151
// copy out sensor values
@@ -188,15 +188,20 @@ PYBIND11_MODULE(_rollout, pymodule) {
188
188
// get subsequent states and corresponding sensor values
189
189
pymodule.def (
190
190
" rollout" ,
191
- [](const MjModelWrapper& m, MjDataWrapper& d,
191
+ [](py::list m, MjDataWrapper& d,
192
192
int nstep, unsigned int control_spec,
193
193
const PyCArray state0,
194
194
std::optional<const PyCArray> warmstart0,
195
195
std::optional<const PyCArray> control,
196
196
std::optional<const PyCArray> state,
197
197
std::optional<const PyCArray> sensordata
198
198
) {
199
- const raw::MjModel* model = m.get ();
199
+ // get raw pointers
200
+ int nroll = state0.shape (0 );
201
+ const raw::MjModel* model_ptrs[nroll];
202
+ for (int r = 0 ; r < nroll; r++) {
203
+ model_ptrs[r] = m[r].cast <const MjModelWrapper*>()->get ();
204
+ }
200
205
raw::MjData* data = d.get ();
201
206
202
207
// check that some steps need to be taken, return if not
@@ -205,19 +210,17 @@ PYBIND11_MODULE(_rollout, pymodule) {
205
210
}
206
211
207
212
// get sizes
208
- int nstate = mj_stateSize (model, mjSTATE_FULLPHYSICS);
209
- int ncontrol = mj_stateSize (model, control_spec);
210
- int nroll = state0.shape (0 );
213
+ int nstate = mj_stateSize (model_ptrs[0 ], mjSTATE_FULLPHYSICS);
214
+ int ncontrol = mj_stateSize (model_ptrs[0 ], control_spec);
211
215
212
- // get raw pointers
213
216
mjtNum* state0_ptr = get_array_ptr (state0, " state0" , nroll, 1 , nstate);
214
217
mjtNum* warmstart0_ptr = get_array_ptr (warmstart0, " warmstart0" , nroll,
215
- 1 , model ->nv );
218
+ 1 , model_ptrs[ 0 ] ->nv );
216
219
mjtNum* control_ptr = get_array_ptr (control, " control" , nroll,
217
220
nstep, ncontrol);
218
221
mjtNum* state_ptr = get_array_ptr (state, " state" , nroll, nstep, nstate);
219
222
mjtNum* sensordata_ptr = get_array_ptr (sensordata, " sensordata" , nroll,
220
- nstep, model ->nsensordata );
223
+ nstep, model_ptrs[ 0 ] ->nsensordata );
221
224
222
225
// perform rollouts
223
226
{
@@ -226,7 +229,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
226
229
227
230
// call unsafe rollout function
228
231
InterceptMjErrors (_unsafe_rollout)(
229
- model , data, nroll, nstep, control_spec, state0_ptr,
232
+ model_ptrs , data, nroll, nstep, control_spec, state0_ptr,
230
233
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
231
234
}
232
235
},
0 commit comments