6
6
"""
7
7
8
8
import glob
9
+ import logging
9
10
import operator
10
11
import os
11
- import logging
12
+ import shutil
12
13
13
14
import torch
14
15
@@ -32,7 +33,8 @@ def __init__(
32
33
recovery_dir = '' ,
33
34
decreasing = False ,
34
35
max_history = 10 ,
35
- unwrap_fn = unwrap_model ):
36
+ unwrap_fn = unwrap_model
37
+ ):
36
38
37
39
# objects to save state_dicts of
38
40
self .model = model
@@ -46,7 +48,8 @@ def __init__(
46
48
self .best_epoch = None
47
49
self .best_metric = None
48
50
self .curr_recovery_file = ''
49
- self .last_recovery_file = ''
51
+ self .prev_recovery_file = ''
52
+ self .can_hardlink = True
50
53
51
54
# config
52
55
self .checkpoint_dir = checkpoint_dir
@@ -60,41 +63,26 @@ def __init__(
60
63
self .unwrap_fn = unwrap_fn
61
64
assert self .max_history >= 1
62
65
63
- def save_checkpoint (self , epoch , metric = None ):
64
- assert epoch >= 0
65
- tmp_save_path = os .path .join (self .checkpoint_dir , 'tmp' + self .extension )
66
- last_save_path = os .path .join (self .checkpoint_dir , 'last' + self .extension )
67
- self ._save (tmp_save_path , epoch , metric )
68
- if os .path .exists (last_save_path ):
69
- os .unlink (last_save_path ) # required for Windows support.
70
- os .rename (tmp_save_path , last_save_path )
71
- worst_file = self .checkpoint_files [- 1 ] if self .checkpoint_files else None
72
- if (len (self .checkpoint_files ) < self .max_history
73
- or metric is None or self .cmp (metric , worst_file [1 ])):
74
- if len (self .checkpoint_files ) >= self .max_history :
75
- self ._cleanup_checkpoints (1 )
76
- filename = '-' .join ([self .save_prefix , str (epoch )]) + self .extension
77
- save_path = os .path .join (self .checkpoint_dir , filename )
78
- os .link (last_save_path , save_path )
79
- self .checkpoint_files .append ((save_path , metric ))
80
- self .checkpoint_files = sorted (
81
- self .checkpoint_files , key = lambda x : x [1 ],
82
- reverse = not self .decreasing ) # sort in descending order if a lower metric is not better
83
-
84
- checkpoints_str = "Current checkpoints:\n "
85
- for c in self .checkpoint_files :
86
- checkpoints_str += ' {}\n ' .format (c )
87
- _logger .info (checkpoints_str )
88
-
89
- if metric is not None and (self .best_metric is None or self .cmp (metric , self .best_metric )):
90
- self .best_epoch = epoch
91
- self .best_metric = metric
92
- best_save_path = os .path .join (self .checkpoint_dir , 'model_best' + self .extension )
93
- if os .path .exists (best_save_path ):
94
- os .unlink (best_save_path )
95
- os .link (last_save_path , best_save_path )
96
-
97
- return (None , None ) if self .best_metric is None else (self .best_metric , self .best_epoch )
66
def _replace(self, src, dst):
    """Move ``src`` onto ``dst``, overwriting ``dst`` if it already exists.

    ``os.replace`` overwrites an existing destination on every platform,
    Windows included, so ``dst`` is not unlinked first. Dropping the separate
    unlink step means ``dst`` is never missing between two operations, and a
    failed unlink can no longer clear ``self.can_hardlink`` (that flag
    describes hard-link support, which a rename does not use).

    Args:
        src: Path of the existing file to move.
        dst: Path the file should end up at.

    Raises:
        OSError: Propagated unchanged from ``os.replace``.
    """
    os.replace(src, dst)
74
+
75
def _duplicate(self, src, dst):
    """Make ``dst`` a duplicate of ``src``, hard-linking when possible.

    While ``self.can_hardlink`` is set, any existing ``dst`` is unlinked and a
    hard link to ``src`` is created. If that raises ``OSError`` or
    ``NotImplementedError``, ``self.can_hardlink`` is cleared so later calls
    skip the attempt, and the file is copied with ``shutil.copy2`` instead.

    Args:
        src: Path of the existing file.
        dst: Path of the duplicate to create (replaced if present).
    """
    linked = False
    if self.can_hardlink:
        try:
            # An existing destination must go first: os.link refuses to
            # overwrite (the original comment notes this is for Windows).
            if os.path.exists(dst):
                os.unlink(dst)
            os.link(src, dst)
            linked = True
        except (OSError, NotImplementedError):
            # Remember the failure so later calls go straight to copying.
            self.can_hardlink = False
    if not linked:
        shutil.copy2(src, dst)
98
86
99
87
def _save (self , save_path , epoch , metric = None ):
100
88
save_state = {
@@ -129,18 +117,61 @@ def _cleanup_checkpoints(self, trim=0):
129
117
_logger .error ("Exception '{}' while deleting checkpoint" .format (e ))
130
118
self .checkpoint_files = self .checkpoint_files [:delete_index ]
131
119
120
def save_checkpoint(self, epoch, metric=None):
    """Write the 'last' checkpoint and update the ranked history and best copy.

    The state is saved to a temporary file and then moved onto
    ``last<extension>``. The checkpoint is additionally duplicated into the
    ranked history when the history is not full, when ``metric`` is None, when
    the current worst entry has no metric, or when ``self.cmp`` reports
    ``metric`` as better than the worst entry's metric. When ``metric`` beats
    ``self.best_metric`` the checkpoint is also duplicated to
    ``model_best<extension>``.

    Entries saved without a metric cannot be ordered against numeric metrics
    (comparing ``None`` raises ``TypeError``), so they are kept after all
    scored entries, newest first. Scored entries are ordered as before.

    Args:
        epoch: Non-negative epoch index, used in the history file name.
        metric: Optional score for this checkpoint.

    Returns:
        ``(best_metric, best_epoch)``, or ``(None, None)`` if no metric has
        been recorded yet.
    """
    assert epoch >= 0
    tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
    last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
    # Save under a temporary name first so 'last' is only ever replaced by a
    # completely written file.
    self._save(tmp_save_path, epoch, metric)
    self._replace(tmp_save_path, last_save_path)

    worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
    if (
        len(self.checkpoint_files) < self.max_history
        or metric is None
        or worst_file[1] is None  # a metric-less entry is always outranked
        or self.cmp(metric, worst_file[1])
    ):
        if len(self.checkpoint_files) >= self.max_history:
            # presumably drops the lowest-ranked entry to make room; verify
            # against _cleanup_checkpoints
            self._cleanup_checkpoints(1)
        filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
        save_path = os.path.join(self.checkpoint_dir, filename)
        self._duplicate(last_save_path, save_path)

        entry = (save_path, metric)
        scored = [c for c in self.checkpoint_files if c[1] is not None]
        unscored = [c for c in self.checkpoint_files if c[1] is None]
        if metric is None:
            # newest metric-less entry first, so the oldest sits at the tail
            unscored.insert(0, entry)
        else:
            scored.append(entry)
        # sort in descending order if a lower metric is not better
        scored.sort(key=lambda c: c[1], reverse=not self.decreasing)
        self.checkpoint_files = scored + unscored

        checkpoints_str = "Current checkpoints:\n " + ''.join(
            ' {}\n '.format(c) for c in self.checkpoint_files
        )
        _logger.info(checkpoints_str)

    if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
        self.best_epoch = epoch
        self.best_metric = metric
        best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
        self._duplicate(last_save_path, best_save_path)

    return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
158
+
132
159
def save_recovery(self, epoch, batch_idx=0):
    """Write a recovery checkpoint and delete the one from two saves ago.

    The state is saved to ``recovery_tmp<extension>`` in ``self.recovery_dir``
    and then moved to its final name built from ``self.recovery_prefix``,
    ``epoch`` and ``batch_idx``. Afterwards the file recorded in
    ``self.prev_recovery_file`` is removed if it exists (failures are logged,
    not raised), and the prev/curr bookkeeping is shifted forward by one.

    Args:
        epoch: Non-negative epoch index.
        batch_idx: Batch index within the epoch, used in the file name.
    """
    assert epoch >= 0
    staging_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
    self._save(staging_path, epoch)

    recovery_name = '-'.join((self.recovery_prefix, str(epoch), str(batch_idx))) + self.extension
    save_path = os.path.join(self.recovery_dir, recovery_name)
    self._replace(staging_path, save_path)

    stale_path = self.prev_recovery_file
    if os.path.exists(stale_path):
        try:
            _logger.debug("Cleaning recovery: {}".format(stale_path))
            os.remove(stale_path)
        except Exception as e:
            # Best effort: a leftover recovery file must not abort the save.
            _logger.error("Exception '{}' while removing {}".format(e, stale_path))

    # Shift the window: the current file becomes the previous one.
    self.prev_recovery_file, self.curr_recovery_file = self.curr_recovery_file, save_path
145
176
146
177
def find_recovery (self ):
0 commit comments