diff --git a/train.py b/train.py index fbf481e8..2e1259d7 100644 --- a/train.py +++ b/train.py @@ -171,8 +171,8 @@ def main(hparams): checkpoint_callback = \ ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}', '{epoch:d}'), - monitor='val/loss', - mode='min', + monitor='val/psnr', + mode='max', save_top_k=5) logger = TestTubeLogger(save_dir="logs",