-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtraining_util.py
41 lines (36 loc) · 1.2 KB
/
training_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 17 23:45:26 2018
borrowed from https://github.com/JeremyCCHsu/Gumbel-Softmax-VAE-in-tensorflow
"""
import os
import sys
import tensorflow as tf
def save(saver, sess, logdir, step):
    """Save a model to ``logdir/model.ckpt-[step]``.

    Args:
        saver: object exposing ``save(sess, path, global_step=...)``
            (e.g. a ``tf.train.Saver``).
        sess: session whose variables are saved.
        logdir: output directory; created (including parents) if missing.
            An empty string means the current working directory.
        step: value forwarded to ``saver.save`` as ``global_step``.
    """
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)
    # Presumably so any pending console output is shown before a possibly
    # slow save.
    sys.stdout.flush()
    # exist_ok=True removes the exists()/makedirs() race; '' is skipped
    # because os.makedirs('') raises, while the cwd already exists.
    if logdir:
        os.makedirs(logdir, exist_ok=True)
    saver.save(sess, checkpoint_path, global_step=step)
def load(saver, sess, logdir):
    """Try to restore the newest checkpoint found in ``logdir`` into ``sess``.

    The checkpoint is located with ``tf.train.get_checkpoint_state`` and the
    global step is parsed from the integer after the last ``-`` in the
    checkpoint file name (e.g. ``model.ckpt-1200`` -> 1200).

    Args:
        saver: object exposing ``restore(sess, path)`` (e.g. a
            ``tf.train.Saver``).
        sess: session to restore variables into.
        logdir: directory to search for checkpoints.

    Returns:
        The parsed global step (int) if a checkpoint was restored, otherwise
        None.

    Raises:
        ValueError: if the checkpoint file name has no integer step suffix.
    """
    print('Trying to restore checkpoints from {} ...'.format(logdir),
          end="")
    ckpt = tf.train.get_checkpoint_state(logdir)
    # A state object that names no checkpoint is treated as "nothing found"
    # rather than crashing later on int('').
    if not (ckpt and ckpt.model_checkpoint_path):
        print('No checkpoint found')
        return None

    ckpt_path = ckpt.model_checkpoint_path
    print(' Checkpoint found: {}'.format(ckpt_path))
    # Only the file name carries the step; rsplit ignores hyphens that may
    # appear earlier in the name.
    step_suffix = os.path.basename(ckpt_path).rsplit('-', 1)[-1]
    try:
        global_step = int(step_suffix)
    except ValueError as err:
        raise ValueError(
            'Cannot parse global step from checkpoint path {!r}; expected a '
            '"-<step>" suffix.'.format(ckpt_path)) from err
    print(' Global step: {}'.format(global_step))
    print(' Restoring...', end="")
    saver.restore(sess, ckpt_path)
    # Terminate the progress line left open by end="" above.
    print(' Done.')
    return global_step