Skip to content

Commit 8e48c30

Browse files
committed
rollout use size_t for array size check and pointer arithmetic
1 parent 19624ae commit 8e48c30

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

python/mujoco/rollout.cc

+15-13
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,11 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
7575
const mjtNum* state0, const mjtNum* warmstart0,
7676
const mjtNum* control, mjtNum* state, mjtNum* sensordata) {
7777
// sizes
78-
int nstate = mj_stateSize(m[0], mjSTATE_FULLPHYSICS);
79-
int ncontrol = mj_stateSize(m[0], control_spec);
80-
int nv = m[0]->nv, nbody = m[0]->nbody, neq = m[0]->neq;
81-
int nsensordata = m[0]->nsensordata;
78+
size_t nstate = static_cast<size_t>(mj_stateSize(m[0], mjSTATE_FULLPHYSICS));
79+
size_t ncontrol = static_cast<size_t>(mj_stateSize(m[0], control_spec));
80+
size_t nv = static_cast<size_t>(m[0]->nv);
81+
int nbody = m[0]->nbody, neq = m[0]->neq;
82+
size_t nsensordata = static_cast<size_t>(m[0]->nsensordata);
8283

8384
// clear user inputs if unspecified
8485
if (!(control_spec & mjSTATE_CTRL)) {
@@ -92,7 +93,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
9293
}
9394

9495
// loop over rollouts
95-
for (int r = start_roll; r < end_roll; r++) {
96+
for (size_t r = start_roll; r < end_roll; r++) {
9697
// clear user inputs if unspecified
9798
if (!(control_spec & mjSTATE_MOCAP_POS)) {
9899
for (int i = 0; i < nbody; i++) {
@@ -117,7 +118,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
117118

118119
// set warmstart accelerations
119120
if (warmstart0) {
120-
mju_copy(d->qacc_warmstart, warmstart0 + r*nv, nv);
121+
mju_copy(d->qacc_warmstart, warmstart0 + r * nv, nv);
121122
} else {
122123
mju_zero(d->qacc_warmstart, nv);
123124
}
@@ -128,7 +129,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
128129
}
129130

130131
// roll out trajectory
131-
for (int t = 0; t < nstep; t++) {
132+
for (size_t t = 0; t < nstep; t++) {
132133
// check for warnings
133134
bool nwarning = false;
134135
for (int i = 0; i < mjNWARNING; i++) {
@@ -141,7 +142,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
141142
// if any warnings, fill remaining outputs with current outputs, break
142143
if (nwarning) {
143144
for (; t < nstep; t++) {
144-
int step = r*nstep + t;
145+
size_t step = r*static_cast<size_t>(nstep) + t;
145146
if (state) {
146147
mj_getState(m[r], d, state + step*nstate, mjSTATE_FULLPHYSICS);
147148
}
@@ -152,24 +153,24 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
152153
break;
153154
}
154155

155-
int step = r*nstep + t;
156+
size_t step = r*static_cast<size_t>(nstep) + t;
156157

157158
// controls
158159
if (control) {
159-
mj_setState(m[r], d, control + step*ncontrol, control_spec);
160+
mj_setState(m[r], d, control + step * ncontrol, control_spec);
160161
}
161162

162163
// step
163164
mj_step(m[r], d);
164165

165166
// copy out new state
166167
if (state) {
167-
mj_getState(m[r], d, state + step*nstate, mjSTATE_FULLPHYSICS);
168+
mj_getState(m[r], d, state + step * nstate, mjSTATE_FULLPHYSICS);
168169
}
169170

170171
// copy out sensor values
171172
if (sensordata) {
172-
mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata);
173+
mju_copy(sensordata + step * nsensordata, d->sensordata, nsensordata);
173174
}
174175
}
175176
}
@@ -226,7 +227,8 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
226227
py::buffer_info info = arg->request();
227228

228229
// check size
229-
int expected_size = nbatch * nstep * dim;
230+
size_t expected_size =
231+
static_cast<size_t>(nbatch) * static_cast<size_t>(nstep) * static_cast<size_t>(dim);
230232
if (info.size != expected_size) {
231233
std::ostringstream msg;
232234
msg << name << ".size should be " << expected_size << ", got " << info.size;

0 commit comments

Comments
 (0)