Skip to content

Commit 033abed

Browse files
author
Anantaa Kotal
committed
second commit
1 parent 5fe8833 commit 033abed

File tree

8 files changed

+794
-14
lines changed

8 files changed

+794
-14
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
data_half/
2+
debug_imagery/
3+
models/
4+
15
# Byte-compiled / optimized / DLL files
26
__pycache__/
37
*.py[cod]

generated_imagery/new_image.jpg

46.7 KB
Loading

main.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,27 @@
2525
from torch.utils.data import Dataset
2626
from torch.utils.data import DataLoader
2727
from torch.utils.tensorboard import SummaryWriter
28+
import gc
29+
30+
import torch
31+
from GPUtil import showUtilization as gpu_usage
32+
from numba import cuda
33+
2834

2935
DRIVE_PATH = os.getcwd()
3036

3137
BINARIES_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries')
3238
CHECKPOINTS_PATH = os.path.join(DRIVE_PATH, 'models', 'checkpoints')
3339
MODEL_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries', 'NIH_CXR.pth')
34-
DATA_DIR_PATH = os.path.join(DRIVE_PATH, 'data')
40+
#DATA_DIR_PATH = os.path.join(DRIVE_PATH, 'data_half/images')
41+
DATA_DIR_PATH = "/nfs/ada/joshi/users/anantak1/data/NIH_CXR_data/images"
3542
DEBUG_IMAGERY_PATH = os.path.join(DRIVE_PATH, 'debug_imagery')
3643
GENERATED_IMAGES_PATH = os.path.join(DRIVE_PATH, 'generated_imagery')
3744

3845
IMG_SIZE = 256
39-
BATCH_SIZE = 32
46+
BATCH_SIZE = 8
47+
48+
#free_gpu_cache()
4049

4150
transform = transforms.Compose([
4251
# you can add other transformations in this list
@@ -47,12 +56,7 @@
4756

4857
img_dataset = datasets.ImageFolder(DATA_DIR_PATH, transform=transform)
4958

50-
evens = list(range(0, len(img_dataset), 16))
51-
odds = list(range(1, len(img_dataset), 2))
52-
trainset_1 = torch.utils.data.Subset(img_dataset, evens)
53-
trainset_2 = torch.utils.data.Subset(img_dataset, odds)
54-
55-
img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=BATCH_SIZE, shuffle=True)
59+
img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)
5660

5761

5862

@@ -76,6 +80,7 @@
7680
# Size of the generator's input vector.
7781
LATENT_SPACE_DIM = 100
7882

83+
#free_gpu_cache()
7984

8085
# This one will produce a batch of those vectors
8186
def get_gaussian_latent_batch(batch_size, device):
@@ -93,7 +98,7 @@ class GeneratorNet(torch.nn.Module):
9398
def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
9499
super().__init__()
95100
self.generated_img_shape = img_shape
96-
num_neurons_per_layer = [LATENT_SPACE_DIM, 256, 512, 1024, img_shape[0] * img_shape[1]]
101+
num_neurons_per_layer = [LATENT_SPACE_DIM, 512, 1024, 4096, img_shape[0] * img_shape[1]]
97102

98103
self.net = nn.Sequential(
99104
*vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1]),
@@ -116,19 +121,19 @@ def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
116121
self.net = nn.Sequential(
117122
*vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1], normalize=False),
118123
*vanilla_block(num_neurons_per_layer[1], num_neurons_per_layer[2], normalize=False),
119-
*vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3], normalize=False, activation=nn.Sigmoid())
124+
*vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3], normalize=False, activation=nn.Sigmoid())
120125
)
121126

122127
def forward(self, img_batch):
123128
img_batch_flattened = img_batch.view(img_batch.shape[0], -1) # flatten from (N,1,H,W) into (N, HxW)
124129
return self.net(img_batch_flattened)
125130

126131
def get_optimizers(d_net, g_net):
127-
d_opt = Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
128-
g_opt = Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
132+
d_opt = Adam(d_net.parameters(), lr=0.0001, betas=(0.5, 0.999))
133+
g_opt = Adam(g_net.parameters(), lr=0.0001, betas=(0.5, 0.999))
129134
return d_opt, g_opt
130135

131-
torch.cuda.empty_cache()
136+
#free_gpu_cache()
132137

133138
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134139

@@ -144,13 +149,20 @@ def get_optimizers(d_net, g_net):
144149
checkpoint_freq = 2
145150
console_log_freq = 50
146151

152+
debug_imagery_log_freq = 50
153+
154+
ref_batch_size = 16
155+
ref_noise_batch = get_gaussian_latent_batch(ref_batch_size, device) # Track G's quality during training on fixed noise vectors
156+
img_cnt = 0
157+
147158
num_epochs = 5
148159

149160
ts = time.time()
150161

151162
def train_GAN():
152163
for epoch in range(num_epochs):
153164
for batch_idx, (real_images, _) in enumerate(img_dataloader):
165+
global img_cnt
154166

155167
real_images = real_images.to(device)
156168

@@ -173,7 +185,16 @@ def train_GAN():
173185

174186
generator_loss.backward()
175187
generator_opt.step()
176-
188+
189+
# Save intermediate generator images (more convenient like this than through tensorboard)
190+
if batch_idx % debug_imagery_log_freq == 0:
191+
with torch.no_grad():
192+
log_generated_images = generator_net(ref_noise_batch)
193+
log_generated_images_resized = nn.Upsample(scale_factor=2.5, mode='nearest')(log_generated_images)
194+
out_path = os.path.join(DEBUG_IMAGERY_PATH, f'{str(img_cnt).zfill(6)}.jpg')
195+
save_image(log_generated_images_resized, out_path, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
196+
img_cnt += 1
197+
177198
if batch_idx % console_log_freq == 0:
178199
prefix = 'GAN training: time elapsed'
179200
print(

main_1.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# -*- coding: utf-8 -*-
2+
"""main.ipynb
3+
4+
Automatically generated by Colaboratory.
5+
6+
Original file is located at
7+
https://colab.research.google.com/drive/1dB_Dwq4_Kp_B_ON1mZaFb72e5-X93DTX
8+
"""
9+
10+
import os
11+
import re
12+
import time
13+
import enum
14+
15+
16+
import cv2 as cv
17+
import numpy as np
18+
import matplotlib.pyplot as plt
19+
20+
import torch
21+
from torch import nn
22+
from torch.optim import Adam
23+
from torchvision import transforms, datasets
24+
from torchvision.utils import make_grid, save_image
25+
from torch.utils.data import Dataset
26+
from torch.utils.data import DataLoader
27+
from torch.utils.tensorboard import SummaryWriter
28+
import gc
29+
30+
import torch
31+
from GPUtil import showUtilization as gpu_usage
32+
from numba import cuda
33+
34+
35+
DRIVE_PATH = os.getcwd()

import os
# os.environ['CUDA_VISIBLE_DEVICES']='2, 3'
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"

# Where trained weights are written, relative to the current working directory.
BINARIES_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries')
CHECKPOINTS_PATH = os.path.join(DRIVE_PATH, 'models', 'checkpoints')
MODEL_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries', 'NIH_CXR.pth')

# Root folder handed to ImageFolder below.
#DATA_DIR_PATH = os.path.join(DRIVE_PATH, 'data_half/images')
DATA_DIR_PATH = "/nfs/ada/joshi/users/anantak1/data/NIH_CXR_data/images"

# Output folders for intermediate and final generated images.
DEBUG_IMAGERY_PATH = os.path.join(DRIVE_PATH, 'debug_imagery')
GENERATED_IMAGES_PATH = os.path.join(DRIVE_PATH, 'generated_imagery')

IMG_SIZE = 256  # images are fed to the networks as IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32

transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.Grayscale(),
    # Resize(int) would only fix the shorter edge and keep the aspect ratio, so a
    # non-square input would not come out IMG_SIZE x IMG_SIZE. Passing (h, w)
    # forces the exact size the fully connected networks are built for.
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

img_dataset = datasets.ImageFolder(DATA_DIR_PATH, transform=transform)

# drop_last=True discards a trailing partial batch, so every batch has exactly
# BATCH_SIZE images.
img_dataloader = torch.utils.data.DataLoader(
    img_dataset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    shuffle=True,
)
64+
65+
66+
67+
# Visualize the data
68+
69+
# Report how many images ImageFolder indexed under DATA_DIR_PATH.
print(f'Dataset size: {len(img_dataset)} images.')
70+
71+
"""num_imgs_to_visualize = 1
72+
batch = next(iter(img_dataloader))
73+
img_batch = batch[0]
74+
img_batch_subset = img_batch[:num_imgs_to_visualize]
75+
76+
print(f'Image shape {img_batch_subset.shape[1:]}')
77+
grid = make_grid(img_batch_subset, nrow=int(np.sqrt(num_imgs_to_visualize)), normalize=True, pad_value=1.)
78+
grid = np.moveaxis(grid.numpy(), 0, 2) # from CHW -> HWC format that's what matplotlib expects! Get used to this.
79+
plt.figure(figsize=(6, 6))
80+
plt.title("Samples from the NIH_CXR dataset")
81+
plt.imshow(grid)
82+
plt.show()"""
83+
84+
# Size of the generator's input vector: each latent (noise) sample drawn by
# get_gaussian_latent_batch has this many components.
LATENT_SPACE_DIM = 100
86+
87+
#free_gpu_cache()
88+
89+
# This one will produce a batch of those vectors
90+
def get_gaussian_latent_batch(batch_size, device):
    """Draw a batch of latent vectors from a standard normal distribution.

    Returns a tensor of shape (batch_size, LATENT_SPACE_DIM) created on `device`.
    """
    latent_shape = (batch_size, LATENT_SPACE_DIM)
    return torch.randn(latent_shape, device=device)
92+
93+
94+
def vanilla_block(in_feat, out_feat, normalize=True, activation=None):
    """Assemble the layers of one fully connected block.

    The block is Linear(in_feat, out_feat), optionally followed by
    BatchNorm1d(out_feat) when `normalize` is true, and closed by `activation`
    (LeakyReLU with slope 0.2 when no activation is given).

    Returns the layers as a list, ready to be unpacked into nn.Sequential.
    """
    block = [nn.Linear(in_feat, out_feat)]
    if normalize:
        block += [nn.BatchNorm1d(out_feat)]
    if activation is None:
        activation = nn.LeakyReLU(0.2)
    block.append(activation)
    return block
100+
101+
class GeneratorNet(torch.nn.Module):
    """Fully connected generator: latent vector in, single-channel image out."""

    def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
        super().__init__()
        self.generated_img_shape = img_shape
        layer_widths = [LATENT_SPACE_DIM, 256, 512, 1024, img_shape[0] * img_shape[1]]

        # Hidden blocks use vanilla_block's defaults (BatchNorm1d + LeakyReLU).
        stack = []
        for fan_in, fan_out in zip(layer_widths[:-2], layer_widths[1:-1]):
            stack.extend(vanilla_block(fan_in, fan_out))

        # Output block: no normalization, Tanh keeps every pixel in [-1, 1].
        stack.extend(
            vanilla_block(layer_widths[-2], layer_widths[-1], normalize=False, activation=nn.Tanh())
        )

        self.net = nn.Sequential(*stack)

    def forward(self, latent_vector_batch):
        """Map (N, LATENT_SPACE_DIM) latent vectors to images shaped (N, 1, *img_shape)."""
        flat_images = self.net(latent_vector_batch)
        num_samples = flat_images.shape[0]
        return flat_images.view(num_samples, 1, *self.generated_img_shape)
117+
118+
class DiscriminatorNet(torch.nn.Module):
    """Fully connected discriminator: flattened image in, one score in (0, 1) out."""

    def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
        super().__init__()
        layer_widths = [img_shape[0] * img_shape[1], 1024, 512, 256, 1]

        # None of the blocks use batch normalization.
        stack = []
        for fan_in, fan_out in zip(layer_widths[:-2], layer_widths[1:-1]):
            stack.extend(vanilla_block(fan_in, fan_out, normalize=False))

        # The last layer is a Sigmoid: the discriminator should output 1. for real
        # images and 0. for fake ones, and sigmoid is bounded to that interval.
        stack.extend(
            vanilla_block(layer_widths[-2], layer_widths[-1], normalize=False, activation=nn.Sigmoid())
        )

        self.net = nn.Sequential(*stack)

    def forward(self, img_batch):
        """Score a batch of images; returns a tensor with one value per image."""
        num_samples = img_batch.shape[0]
        flat_images = img_batch.view(num_samples, -1)  # (N, 1, H, W) -> (N, H*W)
        return self.net(flat_images)
135+
136+
def get_optimizers(d_net, g_net, lr=0.001, betas=(0.5, 0.999)):
    """Create one Adam optimizer for the discriminator and one for the generator.

    Args:
        d_net: discriminator module whose parameters the first optimizer updates.
        g_net: generator module whose parameters the second optimizer updates.
        lr: learning rate shared by both optimizers. Defaults to the value that
            used to be hard-coded here.
        betas: Adam (beta1, beta2) coefficients shared by both optimizers.

    Returns:
        Tuple (d_opt, g_opt).
    """
    d_opt = Adam(d_net.parameters(), lr=lr, betas=betas)
    g_opt = Adam(g_net.parameters(), lr=lr, betas=betas)
    return d_opt, g_opt
140+
141+
#free_gpu_cache()
142+
143+
# Train on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Discriminator is built before the generator; both are put in training mode.
discriminator_net = DiscriminatorNet().to(device).train()
generator_net = GeneratorNet().to(device).train()

discriminator_opt, generator_opt = get_optimizers(discriminator_net, generator_net)

# Binary cross-entropy against constant targets: 1 for "real", 0 for "fake".
adversarial_loss = nn.BCELoss()
real_images_gt = torch.ones(BATCH_SIZE, 1, device=device)
fake_images_gt = torch.zeros_like(real_images_gt)

checkpoint_freq = 2    # used by train_GAN to decide on which epochs to checkpoint
console_log_freq = 50  # used by train_GAN to decide on which batches to print progress

num_epochs = 10

# Reference timestamp for the elapsed-time figure printed during training.
ts = time.time()
160+
161+
def train_GAN():
    """Train the generator and discriminator, then save the generator weights.

    For every batch the discriminator is updated first (real images labelled 1,
    detached generator output labelled 0), then the generator is updated on a
    fresh batch of generated images labelled 1. Progress is printed every
    `console_log_freq` batches. On the first batch of every `checkpoint_freq`-th
    epoch the generator state dict is written to CHECKPOINTS_PATH, and after the
    last epoch it is written to MODEL_PATH.

    Relies on the module-level networks, optimizers, loss, data loader and
    configuration constants; takes no arguments and returns nothing.
    """
    # torch.save does not create missing directories; make them up front so a
    # long run cannot fail at the first checkpoint or at the final save.
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    num_batches = len(img_dataloader)

    for epoch in range(num_epochs):
        for batch_idx, (real_images, _) in enumerate(img_dataloader):

            real_images = real_images.to(device)

            # ---- Discriminator update ----
            discriminator_opt.zero_grad()

            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

            fake_images = generator_net(get_gaussian_latent_batch(BATCH_SIZE, device))
            # detach() stops this loss from back-propagating into the generator.
            fake_images_predictions = discriminator_net(fake_images.detach())
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()
            discriminator_opt.step()

            # ---- Generator update ----
            generator_opt.zero_grad()
            generated_images_predictions = discriminator_net(
                generator_net(get_gaussian_latent_batch(BATCH_SIZE, device))
            )
            # The generator is rewarded when the discriminator answers "real".
            generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

            generator_loss.backward()
            generator_opt.step()

            if batch_idx % console_log_freq == 0:
                prefix = 'GAN training: time elapsed'
                print(
                    f'{prefix} = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{num_batches}]')

            # Save generator checkpoint
            if (epoch + 1) % checkpoint_freq == 0 and batch_idx == 0:
                ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(generator_net.state_dict(), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(generator_net.state_dict(), MODEL_PATH)
200+
201+
# Run training as soon as the script reaches this point; train_GAN writes the
# generator weights to MODEL_PATH.
train_GAN()
202+
203+
def postprocess_generated_img(generated_img_tensor):
    """Convert the first image of a generator output batch into an HxWx3 array.

    The first sample of the (N, 1, H, W) batch is moved to channel-last layout,
    its single channel is repeated three times, and the values are min-max
    scaled to [0, 1].

    Args:
        generated_img_tensor: torch.Tensor of shape (N, 1, H, W).

    Returns:
        numpy array of shape (H, W, 3) with values in [0, 1]. A constant image
        comes back as all zeros.

    Raises:
        AssertionError: if the input is not a torch.Tensor.
    """
    assert isinstance(generated_img_tensor,
                      torch.Tensor), f'Expected PyTorch tensor but got {type(generated_img_tensor)}.'

    # detach() lets this work on tensors that still carry autograd history;
    # .numpy() refuses those otherwise.
    first_img = generated_img_tensor.detach().to('cpu').numpy()[0]
    generated_img = np.moveaxis(first_img, 0, 2)  # CHW -> HWC

    # np.repeat returns a new array, so the in-place ops below never touch the tensor.
    generated_img = np.repeat(generated_img, 3, axis=2)

    generated_img -= np.min(generated_img)
    peak = np.max(generated_img)
    # A constant image has peak == 0 after the shift; dividing would yield NaNs.
    if peak > 0:
        generated_img /= peak

    return generated_img
215+
216+
def generate_from_random_latent_vector(generator):
    """Generate one image from a freshly sampled latent vector.

    The latent vector is created on the same device as the generator's
    parameters, and the generator output is passed through
    postprocess_generated_img before being returned.
    """
    target_device = next(generator.parameters()).device

    # Inference only: no_grad avoids building the autograd graph and its memory cost.
    with torch.no_grad():
        # A batch holding a single random latent vector
        noise = get_gaussian_latent_batch(1, target_device)
        raw_output = generator(noise)
        # Post-process the raw generator output into a displayable image
        image = postprocess_generated_img(raw_output)

    return image
226+
227+
def save_and_maybe_display_image(dump_img, out_res=(256, 256), should_display=False):
    """Write an image to GENERATED_IMAGES_PATH/new_image.jpg and optionally show it.

    Args:
        dump_img: HxWx3 numpy array, either uint8 or floating point in [0, 1].
            Floating point input is clipped to [0, 1] and scaled to uint8.
        out_res: size passed to cv.resize for the saved file.
        should_display: when true, also show the (unresized) image with matplotlib.

    Raises:
        AssertionError: if `dump_img` is not a numpy array.
        OSError: if OpenCV reports that the file could not be written.
    """
    assert isinstance(dump_img, np.ndarray), f'Expected numpy array got {type(dump_img)}.'

    os.makedirs(GENERATED_IMAGES_PATH, exist_ok=True)

    dump_img_name = "new_image.jpg"
    out_path = os.path.join(GENERATED_IMAGES_PATH, dump_img_name)

    if dump_img.dtype != np.uint8:
        # Clip first: values outside [0, 1] would wrap around in the uint8 cast.
        dump_img = (np.clip(dump_img, 0.0, 1.0) * 255).astype(np.uint8)

    # Reverse the channel axis (OpenCV writes BGR) and make the view contiguous.
    bgr_img = np.ascontiguousarray(dump_img[:, :, ::-1])
    resized_img = cv.resize(bgr_img, out_res, interpolation=cv.INTER_NEAREST)

    # cv.imwrite signals failure through its return value instead of raising.
    if not cv.imwrite(out_path, resized_img):
        raise OSError(f'Failed to write image to {out_path}.')

    if should_display:
        plt.imshow(dump_img)
        plt.show()
243+
244+
def generate_sample_image():
    """Load the trained generator from MODEL_PATH, then generate and save one image.

    The weights are mapped onto whichever device is available, so a model
    trained on a GPU can also be loaded on a CPU-only machine. The image is
    saved and displayed through save_and_maybe_display_image.

    Raises:
        AssertionError: if MODEL_PATH does not exist.
    """
    assert os.path.exists(MODEL_PATH), f'Could not find the model {MODEL_PATH}. You first need to train your generator.'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = GeneratorNet().to(device)

    # map_location remaps tensors saved on another device (e.g. CUDA) onto `device`;
    # without it, loading GPU-saved weights on a CPU-only host fails.
    generator.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    generator.eval()  # BatchNorm layers switch to their running statistics

    print('Generating new images!')
    generated_img = generate_from_random_latent_vector(generator)
    save_and_maybe_display_image(generated_img, should_display=True)
256+
257+
# After training, produce and display one sample image from the saved generator.
generate_sample_image()

0 commit comments

Comments
 (0)