Skip to content

Commit deb9895

Browse files
committed
Update checkpoint save to fix old hard-link + fuse issue I ran into again... fix #340
1 parent c4fb98f commit deb9895

File tree

1 file changed

+75
-44
lines changed

1 file changed

+75
-44
lines changed

timm/utils/checkpoint_saver.py

+75-44
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
"""
77

88
import glob
9+
import logging
910
import operator
1011
import os
11-
import logging
12+
import shutil
1213

1314
import torch
1415

@@ -32,7 +33,8 @@ def __init__(
3233
recovery_dir='',
3334
decreasing=False,
3435
max_history=10,
35-
unwrap_fn=unwrap_model):
36+
unwrap_fn=unwrap_model
37+
):
3638

3739
# objects to save state_dicts of
3840
self.model = model
@@ -46,7 +48,8 @@ def __init__(
4648
self.best_epoch = None
4749
self.best_metric = None
4850
self.curr_recovery_file = ''
49-
self.last_recovery_file = ''
51+
self.prev_recovery_file = ''
52+
self.can_hardlink = True
5053

5154
# config
5255
self.checkpoint_dir = checkpoint_dir
@@ -60,41 +63,26 @@ def __init__(
6063
self.unwrap_fn = unwrap_fn
6164
assert self.max_history >= 1
6265

63-
def save_checkpoint(self, epoch, metric=None):
64-
assert epoch >= 0
65-
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
66-
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
67-
self._save(tmp_save_path, epoch, metric)
68-
if os.path.exists(last_save_path):
69-
os.unlink(last_save_path) # required for Windows support.
70-
os.rename(tmp_save_path, last_save_path)
71-
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
72-
if (len(self.checkpoint_files) < self.max_history
73-
or metric is None or self.cmp(metric, worst_file[1])):
74-
if len(self.checkpoint_files) >= self.max_history:
75-
self._cleanup_checkpoints(1)
76-
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
77-
save_path = os.path.join(self.checkpoint_dir, filename)
78-
os.link(last_save_path, save_path)
79-
self.checkpoint_files.append((save_path, metric))
80-
self.checkpoint_files = sorted(
81-
self.checkpoint_files, key=lambda x: x[1],
82-
reverse=not self.decreasing) # sort in descending order if a lower metric is not better
83-
84-
checkpoints_str = "Current checkpoints:\n"
85-
for c in self.checkpoint_files:
86-
checkpoints_str += ' {}\n'.format(c)
87-
_logger.info(checkpoints_str)
88-
89-
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
90-
self.best_epoch = epoch
91-
self.best_metric = metric
92-
best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
93-
if os.path.exists(best_save_path):
94-
os.unlink(best_save_path)
95-
os.link(last_save_path, best_save_path)
96-
97-
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
66+
def _replace(self, src, dst):
    """Move ``src`` onto ``dst``, overwriting ``dst`` if it already exists.

    While ``self.can_hardlink`` is still set, any existing ``dst`` is
    unlinked before the move (the original code notes this is required for
    Windows support). If that unlink fails for a reason other than the
    file being absent, ``self.can_hardlink`` is cleared and the move is
    still attempted.

    Args:
        src: Path of the freshly written file to move.
        dst: Final path; replaced if present.

    Raises:
        OSError: Propagated from ``os.replace`` if the move itself fails.
    """
    if self.can_hardlink:
        try:
            # EAFP: attempt the removal directly instead of exists()+unlink(),
            # so a file vanishing in between cannot be mistaken for a
            # filesystem limitation.
            os.unlink(dst)
        except FileNotFoundError:
            # Nothing to remove; this says nothing about hard-link support.
            pass
        except (OSError, NotImplementedError):
            # NOTE(review): presumably signals a filesystem with restricted
            # link/unlink semantics (e.g. FUSE) - confirm against issue #340.
            self.can_hardlink = False
    os.replace(src, dst)
74+
75+
def _duplicate(self, src, dst):
    """Make ``dst`` a duplicate of ``src``, leaving ``src`` in place.

    A hard link is tried first while ``self.can_hardlink`` is set. If
    removing the old ``dst`` or creating the link fails,
    ``self.can_hardlink`` is cleared and this call (and all later ones)
    falls back to a full copy with ``shutil.copy2``.

    Args:
        src: Existing file to duplicate.
        dst: Path of the duplicate; any existing entry is replaced.

    Raises:
        OSError: Propagated from ``shutil.copy2`` if the fallback copy fails.
    """
    if self.can_hardlink:
        try:
            try:
                # Remove a stale dst first (the original marks this as
                # needed for Windows). A missing dst is not a failure.
                os.unlink(dst)
            except FileNotFoundError:
                pass
            os.link(src, dst)
            return
        except (OSError, NotImplementedError):
            self.can_hardlink = False

    # Copy fallback. An existing dst may still be a hard link created by an
    # earlier call; copying over it in place would write through the shared
    # inode and clobber the other checkpoint. Drop the directory entry first
    # (best effort - if it cannot be removed, copy2 behaves as before).
    try:
        os.unlink(dst)
    except OSError:
        pass
    shutil.copy2(src, dst)
9886

9987
def _save(self, save_path, epoch, metric=None):
10088
save_state = {
@@ -129,18 +117,61 @@ def _cleanup_checkpoints(self, trim=0):
129117
_logger.error("Exception '{}' while deleting checkpoint".format(e))
130118
self.checkpoint_files = self.checkpoint_files[:delete_index]
131119

120+
def save_checkpoint(self, epoch, metric=None):
    """Write the latest checkpoint and update the retained history.

    The state is first written to a ``tmp`` file and moved onto ``last``.
    If the history has room, or ``metric`` is missing, or ``metric`` beats
    the worst retained entry according to ``self.cmp``, ``last`` is
    duplicated to ``<save_prefix>-<epoch>`` and recorded in
    ``self.checkpoint_files``. When ``metric`` also beats ``self.best_metric``
    the checkpoint is duplicated to ``model_best``.

    ``self.checkpoint_files`` is kept ordered best-to-worst. Entries saved
    without a metric cannot be compared, so they are kept after all scored
    entries (newest first) and are therefore the first to be trimmed.

    Args:
        epoch: Non-negative epoch number used in the checkpoint filename.
        metric: Optional score for this epoch; ``None`` means "unscored".

    Returns:
        ``(best_metric, best_epoch)``, or ``(None, None)`` if no scored
        checkpoint has been recorded yet.
    """
    assert epoch >= 0
    tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
    last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
    self._save(tmp_save_path, epoch, metric)
    self._replace(tmp_save_path, last_save_path)

    worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
    if (
        len(self.checkpoint_files) < self.max_history
        or metric is None
        # An unscored worst entry cannot be passed to cmp(); any scored
        # checkpoint is allowed to displace it.
        or worst_file[1] is None
        or self.cmp(metric, worst_file[1])
    ):
        if len(self.checkpoint_files) >= self.max_history:
            self._cleanup_checkpoints(1)
        filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
        save_path = os.path.join(self.checkpoint_dir, filename)
        self._duplicate(last_save_path, save_path)

        # Sorting a list that contains None metrics raises TypeError, so
        # scored and unscored entries are ordered separately.
        scored = [c for c in self.checkpoint_files if c[1] is not None]
        unscored = [c for c in self.checkpoint_files if c[1] is None]
        if metric is None:
            unscored.insert(0, (save_path, metric))  # newest unscored first
        else:
            scored.append((save_path, metric))
        scored.sort(
            key=lambda x: x[1],
            reverse=not self.decreasing  # sort in descending order if a lower metric is not better
        )
        self.checkpoint_files = scored + unscored

        checkpoints_str = "Current checkpoints:\n" + ''.join(
            ' {}\n'.format(c) for c in self.checkpoint_files
        )
        _logger.info(checkpoints_str)

        if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
            self.best_epoch = epoch
            self.best_metric = metric
            best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
            self._duplicate(last_save_path, best_save_path)

    return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
158+
132159
def save_recovery(self, epoch, batch_idx=0):
    """Write a recovery checkpoint and rotate the previously kept ones.

    The state is written to a ``recovery_tmp`` file and then moved to
    ``<recovery_prefix>-<epoch>-<batch_idx>``. The file tracked as
    ``self.prev_recovery_file`` is deleted (best effort, failures are only
    logged), after which the current file becomes the previous one and the
    new file becomes current.

    Args:
        epoch: Non-negative epoch number used in the filename.
        batch_idx: Batch index used in the filename.
    """
    assert epoch >= 0
    tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
    self._save(tmp_save_path, epoch)

    filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
    save_path = os.path.join(self.recovery_dir, filename)
    self._replace(tmp_save_path, save_path)

    # If the new file reuses the path tracked as "previous", removing it
    # here would delete the recovery state that was just written.
    if self.prev_recovery_file != save_path and os.path.exists(self.prev_recovery_file):
        try:
            _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
            os.remove(self.prev_recovery_file)
        except Exception as e:
            # Cleanup is best effort; a leftover file must not abort training.
            _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
    self.prev_recovery_file = self.curr_recovery_file
    self.curr_recovery_file = save_path
145176

146177
def find_recovery(self):

0 commit comments

Comments
 (0)