@@ -42,7 +42,15 @@ def call_rollout(self, model_list, initial_state, nstep):
42
42
43
43
def run (self , model_list , initial_state , nstep ):
44
44
nroll = len (model_list )
45
+
46
+ # Divide jobs evenly across threads (as in test)
47
+ # better for very wide rollouts
48
+ # chunk_size = max(1, nroll // self.num_workers)
49
+
50
+ # Divide jobs across threads with a chunksize 1/10th of the even amount
51
+ # new strategy, helps with load balancing
45
52
chunk_size = max (1 , nroll // (10 * self .num_workers ))
53
+
46
54
nfulljobs = nroll // chunk_size ;
47
55
chunk_remainder = nroll % chunk_size ;
48
56
njobs = nfulljobs
@@ -65,7 +73,8 @@ def run(self, model_list, initial_state, nstep):
65
73
future .result ()
66
74
67
75
def benchmark_rollout (model_file , nthread = os .cpu_count ()):
68
- print ('\n ' , model_file )
76
+ print ()
77
+ print (model_file )
69
78
bench_steps = int (1e4 ) # Run approximately bench_steps per thread
70
79
71
80
# A grid search
@@ -122,9 +131,19 @@ def benchmark_rollout(model_file, nthread=os.cpu_count()):
122
131
nt_stats [0 ] / pt_stats [0 ], nt_stats [2 ] / pt_stats [2 ]))
123
132
124
133
if __name__ == '__main__' :
134
+ print ('============================================================' )
135
+ print ('small to medium models' )
136
+ print ('============================================================' )
137
+
125
138
benchmark_rollout (model_file = '../../../dm_control/dm_control/suite/hopper.xml' )
126
139
benchmark_rollout (model_file = '../../../mujoco_menagerie/unitree_go2/scene.xml' )
127
140
benchmark_rollout (model_file = '../../model/humanoid/humanoid.xml' )
141
+
142
+ print ()
143
+ print ('============================================================' )
144
+ print ('very large models' )
145
+ print ('============================================================' )
146
+ benchmark_rollout (model_file = '../../model/cards/cards.xml' )
128
147
benchmark_rollout (model_file = '../../model/humanoid/humanoid100.xml' )
129
- # benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
148
+ benchmark_rollout (model_file = '../../test/benchmark/testdata/humanoid200.xml' )
130
149
# benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml')
0 commit comments