From c49855c522f9cd78673803962cc1ae1d941198a5 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Tue, 11 Jun 2024 21:51:03 +0800 Subject: [PATCH] fix checkpoint --- models/nn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/nn.py b/models/nn.py index b3398c8..50f1336 100755 --- a/models/nn.py +++ b/models/nn.py @@ -127,10 +127,12 @@ def timestep_embedding(timesteps, dim, max_period=10000): return embedding -def torch_checkpoint(func, args, flag, preserve_rng_state=False): +def torch_checkpoint(func, args, flag, preserve_rng_state=True): # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 if flag: return torch.utils.checkpoint.checkpoint( - func, *args, preserve_rng_state=preserve_rng_state) + func, *args, + use_reentrant=True, + preserve_rng_state=preserve_rng_state) else: return func(*args)