@@ -243,6 +243,8 @@ def shutdown_persistent_pool():
243
243
This is called automatically interpreter shutdown, but can also be called manually.
244
244
""" # fmt: skip
245
245
global persistent_rollout
246
+ if persistent_rollout is not None :
247
+ persistent_rollout .close ()
246
248
persistent_rollout = None
247
249
atexit .register (shutdown_persistent_pool )
248
250
@@ -308,29 +310,33 @@ def rollout(
308
310
309
311
# Use a persistent thread pool if requested
310
312
if persistent_pool :
311
- global persistent_rollout
312
313
# Create or restart persistent threadpool
313
- if persistent_rollout is None or persistent_rollout .nthread != nthread :
314
+ global persistent_rollout
315
+ if persistent_rollout is None :
316
+ persistent_rollout = Rollout (nthread = nthread )
317
+ if persistent_rollout .nthread != nthread :
318
+ persistent_rollout .close ()
314
319
persistent_rollout = Rollout (nthread = nthread )
315
320
rollout = persistent_rollout
316
321
else :
317
322
rollout = Rollout (nthread = nthread )
318
323
319
- ret = rollout .rollout (
320
- model ,
321
- data ,
322
- initial_state ,
323
- control ,
324
- control_spec = control_spec ,
325
- skip_checks = skip_checks ,
326
- nstep = nstep ,
327
- initial_warmstart = initial_warmstart ,
328
- state = state ,
329
- sensordata = sensordata ,
330
- chunk_size = chunk_size )
331
-
332
- if not persistent_pool :
333
- rollout .close ()
324
+ try :
325
+ ret = rollout .rollout (
326
+ model ,
327
+ data ,
328
+ initial_state ,
329
+ control ,
330
+ control_spec = control_spec ,
331
+ skip_checks = skip_checks ,
332
+ nstep = nstep ,
333
+ initial_warmstart = initial_warmstart ,
334
+ state = state ,
335
+ sensordata = sensordata ,
336
+ chunk_size = chunk_size )
337
+ finally :
338
+ if not persistent_pool :
339
+ rollout .close ()
334
340
335
341
# return outputs
336
342
return ret
0 commit comments