Skip to content

Commit b52c36d

Browse files
committed
rollout python versus native benchmark
1 parent 22baa4a commit b52c36d

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

python/mujoco/rollout_benchmark.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,15 @@ def call_rollout(self, model_list, initial_state, nstep):
4242

4343
def run(self, model_list, initial_state, nstep):
4444
nroll = len(model_list)
45+
46+
# Divide jobs evenly across threads (as in test)
47+
# better for very wide rollouts
48+
# chunk_size = max(1, nroll // self.num_workers)
49+
50+
# Divide jobs across threads with a chunksize 1/10th of the even amount
51+
# new strategy, helps with load balancing
4552
chunk_size = max(1, nroll // (10 * self.num_workers))
53+
4654
nfulljobs = nroll // chunk_size;
4755
chunk_remainder = nroll % chunk_size;
4856
njobs = nfulljobs
@@ -65,7 +73,8 @@ def run(self, model_list, initial_state, nstep):
6573
future.result()
6674

6775
def benchmark_rollout(model_file, nthread=os.cpu_count()):
68-
print('\n', model_file)
76+
print()
77+
print(model_file)
6978
bench_steps = int(1e4) # Run approximately bench_steps per thread
7079

7180
# A grid search
@@ -122,9 +131,19 @@ def benchmark_rollout(model_file, nthread=os.cpu_count()):
122131
nt_stats[0] / pt_stats[0], nt_stats[2] / pt_stats[2]))
123132

124133
if __name__ == '__main__':
134+
print('============================================================')
135+
print('small to medium models')
136+
print('============================================================')
137+
125138
benchmark_rollout(model_file='../../../dm_control/dm_control/suite/hopper.xml')
126139
benchmark_rollout(model_file='../../../mujoco_menagerie/unitree_go2/scene.xml')
127140
benchmark_rollout(model_file='../../model/humanoid/humanoid.xml')
141+
142+
print()
143+
print('============================================================')
144+
print('very large models')
145+
print('============================================================')
146+
benchmark_rollout(model_file='../../model/cards/cards.xml')
128147
benchmark_rollout(model_file='../../model/humanoid/humanoid100.xml')
129-
# benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
148+
benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
130149
# benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml')

0 commit comments

Comments
 (0)