-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtrain.py
30 lines (25 loc) · 1.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
import numpy as np
import util
import config
from model import *
from dataset import *
# Training script: builds the DCGAN, then alternates optimisation steps with
# periodic loss reporting, sample dumps and checkpoints.

# Schedule constants (previously inline literals).
NUM_EPOCHS = 50      # epochs are numbered 1..NUM_EPOCHS
LOG_INTERVAL = 100   # report losses and dump samples every N batches
NUM_SAMPLES = 5      # number of generated samples written per report
SAVE_INTERVAL = 10   # write a checkpoint every N epochs

model = DCGAN(config.nz, config.nsf, config.nvx, config.batch_size, config.learning_rate)
dataset = Dataset(config.dataset_path)

# Explicit floor division: only whole batches are run, and the result is an
# int on both Python 2 and Python 3 so it can be handed to range().
total_batch = dataset.num_examples // config.batch_size

try:
    for epoch in range(1, NUM_EPOCHS + 1):
        for batch in range(total_batch):
            # Latent input: uniform in [-1, 1), one row of size nz per example.
            z = np.random.uniform(-1, 1, [config.batch_size, config.nz]).astype(np.float32)
            x = np.array(dataset.next_batch(config.batch_size))

            # Multi-GPU mode (disabled): split both inputs in two first, i.e.
            #   z = np.split(z, 2)
            #   x = np.split(x, 2)

            model.optimize(z, x)

            if batch % LOG_INTERVAL == 0:
                # Losses are evaluated on the same (z, x) pair that was just
                # used for the optimisation step.
                lossD, lossG = model.get_errors(z, x)
                x_g = model.generate(z)

                # The loop variable is deliberately not called `x`, so the
                # real-data batch above is not clobbered.
                # NOTE: the file name depends only on (epoch, i), so each
                # report within an epoch overwrites the previous one.
                for i, sample in enumerate(x_g[:NUM_SAMPLES]):
                    # Threshold at 0.5 into a boolean array before saving.
                    util.save_binvox("./out/{0}-{1}.binvox".format(epoch, i), sample > 0.5)

                # Single-argument parenthesised print produces identical
                # output under the Python 2 statement and the Python 3 function.
                print("{0:>2}, {1:>5}, {2:.8f}, {3:.8f}".format(epoch, batch, lossD, lossG))

        if epoch % SAVE_INTERVAL == 0:
            model.save("./params/epoch-{0}.ckpt".format(epoch))
finally:
    # Always release the model, even if training is interrupted or fails.
    model.close()