Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Anantaa Kotal authored and Anantaa Kotal committed Jan 19, 2023
1 parent 37cf673 commit 5fe8833
Show file tree
Hide file tree
Showing 170 changed files with 246 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
Binary file added data/.DS_Store
Binary file not shown.
Binary file added data/NIH_CXR/00000001_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000001_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000001_002.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000002_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_002.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_003.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_004.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_005.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_006.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000003_007.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000004_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_002.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_003.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_004.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_005.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_006.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000005_007.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000006_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000007_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000008_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000008_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/NIH_CXR/00000008_002.png
Binary file added data/NIH_CXR/00000009_000.png
Binary file added data/NIH_CXR/00000010_000.png
Binary file added data/NIH_CXR/00000011_000.png
Binary file added data/NIH_CXR/00000011_001.png
Binary file added data/NIH_CXR/00000011_002.png
Binary file added data/NIH_CXR/00000011_003.png
Binary file added data/NIH_CXR/00000011_004.png
Binary file added data/NIH_CXR/00000011_005.png
Binary file added data/NIH_CXR/00000011_006.png
Binary file added data/NIH_CXR/00000011_007.png
Binary file added data/NIH_CXR/00000011_008.png
Binary file added data/NIH_CXR/00000012_000.png
Binary file added data/NIH_CXR/00000013_000.png
Binary file added data/NIH_CXR/00000013_001.png
Binary file added data/NIH_CXR/00000013_002.png
Binary file added data/NIH_CXR/00000013_003.png
Binary file added data/NIH_CXR/00000013_004.png
Binary file added data/NIH_CXR/00000013_005.png
Binary file added data/NIH_CXR/00000013_006.png
Binary file added data/NIH_CXR/00000013_007.png
Binary file added data/NIH_CXR/00000013_008.png
Binary file added data/NIH_CXR/00000013_009.png
Binary file added data/NIH_CXR/00000013_010.png
Binary file added data/NIH_CXR/00000013_011.png
Binary file added data/NIH_CXR/00000013_012.png
Binary file added data/NIH_CXR/00000013_013.png
Binary file added data/NIH_CXR/00000013_014.png
Binary file added data/NIH_CXR/00000013_015.png
Binary file added data/NIH_CXR/00000013_016.png
Binary file added data/NIH_CXR/00000013_017.png
Binary file added data/NIH_CXR/00000013_018.png
Binary file added data/NIH_CXR/00000013_019.png
Binary file added data/NIH_CXR/00000013_020.png
Binary file added data/NIH_CXR/00000013_021.png
Binary file added data/NIH_CXR/00000013_022.png
Binary file added data/NIH_CXR/00000013_023.png
Binary file added data/NIH_CXR/00000013_024.png
Binary file added data/NIH_CXR/00000013_025.png
Binary file added data/NIH_CXR/00000013_026.png
Binary file added data/NIH_CXR/00000013_027.png
Binary file added data/NIH_CXR/00000013_028.png
Binary file added data/NIH_CXR/00000013_029.png
Binary file added data/NIH_CXR/00000013_030.png
Binary file added data/NIH_CXR/00000013_031.png
Binary file added data/NIH_CXR/00000013_032.png
Binary file added data/NIH_CXR/00000013_033.png
Binary file added data/NIH_CXR/00000013_034.png
Binary file added data/NIH_CXR/00000013_035.png
Binary file added data/NIH_CXR/00000013_036.png
Binary file added data/NIH_CXR/00000013_037.png
Binary file added data/NIH_CXR/00000013_038.png
Binary file added data/NIH_CXR/00000013_039.png
Binary file added data/NIH_CXR/00000013_040.png
Binary file added data/NIH_CXR/00000013_041.png
Binary file added data/NIH_CXR/00000013_042.png
Binary file added data/NIH_CXR/00000013_043.png
Binary file added data/NIH_CXR/00000013_044.png
Binary file added data/NIH_CXR/00000013_045.png
Binary file added data/NIH_CXR/00000013_046.png
Binary file added data/NIH_CXR/00000014_000.png
Binary file added data/NIH_CXR/00000015_000.png
Binary file added data/NIH_CXR/00000016_000.png
Binary file added data/NIH_CXR/00000017_000.png
Binary file added data/NIH_CXR/00000017_001.png
Binary file added data/NIH_CXR/00000017_002.png
Binary file added data/NIH_CXR/00000018_000.png
Binary file added data/NIH_CXR/00000019_000.png
Binary file added data/NIH_CXR/00000020_000.png
Binary file added data/NIH_CXR/00000020_001.png
Binary file added data/NIH_CXR/00000020_002.png
Binary file added data/NIH_CXR/00000021_000.png
Binary file added data/NIH_CXR/00000021_001.png
Binary file added data/NIH_CXR/00000022_000.png
Binary file added data/NIH_CXR/00000022_001.png
Binary file added data/NIH_CXR/00000023_000.png
Binary file added data/NIH_CXR/00000023_001.png
Binary file added data/NIH_CXR/00000023_002.png
Binary file added data/NIH_CXR/00000023_003.png
Binary file added data/NIH_CXR/00000023_004.png
Binary file added data/NIH_CXR/00000024_000.png
Binary file added data/NIH_CXR/00000025_000.png
Binary file added data/NIH_CXR/00000026_000.png
Binary file added data/NIH_CXR/00000027_000.png
Binary file added data/NIH_CXR/00000028_000.png
Binary file added data/NIH_CXR/00000029_000.png
Binary file added data/NIH_CXR/00000030_000.png
Binary file added data/NIH_CXR/00000030_001.png
Binary file added data/NIH_CXR/00000031_000.png
Binary file added data/NIH_CXR/00000032_000.png
Binary file added data/NIH_CXR/00000032_001.png
Binary file added data/NIH_CXR/00000032_002.png
Binary file added data/NIH_CXR/00000032_003.png
Binary file added data/NIH_CXR/00000032_004.png
Binary file added data/NIH_CXR/00000032_005.png
Binary file added data/NIH_CXR/00000032_006.png
Binary file added data/NIH_CXR/00000032_007.png
Binary file added data/NIH_CXR/00000032_008.png
Binary file added data/NIH_CXR/00000032_009.png
Binary file added data/NIH_CXR/00000032_010.png
Binary file added data/NIH_CXR/00000032_011.png
Binary file added data/NIH_CXR/00000032_012.png
Binary file added data/NIH_CXR/00000032_013.png
Binary file added data/NIH_CXR/00000032_014.png
Binary file added data/NIH_CXR/00000032_015.png
Binary file added data/NIH_CXR/00000032_016.png
Binary file added data/NIH_CXR/00000032_017.png
Binary file added data/NIH_CXR/00000032_018.png
Binary file added data/NIH_CXR/00000032_019.png
Binary file added data/NIH_CXR/00000032_020.png
Binary file added data/NIH_CXR/00000032_021.png
Binary file added data/NIH_CXR/00000032_022.png
Binary file added data/NIH_CXR/00000032_023.png
Binary file added data/NIH_CXR/00000032_024.png
Binary file added data/NIH_CXR/00000032_025.png
Binary file added data/NIH_CXR/00000032_026.png
Binary file added data/NIH_CXR/00000032_027.png
Binary file added data/NIH_CXR/00000032_028.png
Binary file added data/NIH_CXR/00000032_029.png
Binary file added data/NIH_CXR/00000032_030.png
Binary file added data/NIH_CXR/00000032_031.png
Binary file added data/NIH_CXR/00000032_032.png
Binary file added data/NIH_CXR/00000032_033.png
Binary file added data/NIH_CXR/00000032_034.png
Binary file added data/NIH_CXR/00000032_035.png
Binary file added data/NIH_CXR/00000099_007.png
Binary file added data/NIH_CXR/00000099_008.png
Binary file added data/NIH_CXR/00000099_009.png
Binary file added data/NIH_CXR/00000099_010.png
Binary file added data/NIH_CXR/00000099_011.png
Binary file added data/NIH_CXR/00000099_012.png
Binary file added data/NIH_CXR/00000099_013.png
Binary file added data/NIH_CXR/00000100_000.png
Binary file added data/NIH_CXR/00000100_001.png
Binary file added data/NIH_CXR/00000101_000.png
Binary file added debug_imagery/.DS_Store
Binary file not shown.
Binary file added generated_imagery/.DS_Store
Binary file not shown.
Empty file added main.ipynb
Empty file.
246 changes: 246 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
"""main.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1dB_Dwq4_Kp_B_ON1mZaFb72e5-X93DTX
"""

import os
import re
import time
import enum


import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.optim import Adam
from torchvision import transforms, datasets
from torchvision.utils import make_grid, save_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

DRIVE_PATH = os.getcwd()

BINARIES_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries')
CHECKPOINTS_PATH = os.path.join(DRIVE_PATH, 'models', 'checkpoints')
MODEL_PATH = os.path.join(DRIVE_PATH, 'models', 'binaries', 'NIH_CXR.pth')
DATA_DIR_PATH = os.path.join(DRIVE_PATH, 'data')
DEBUG_IMAGERY_PATH = os.path.join(DRIVE_PATH, 'debug_imagery')
GENERATED_IMAGES_PATH = os.path.join(DRIVE_PATH, 'generated_imagery')

IMG_SIZE = 256
BATCH_SIZE = 32

transform = transforms.Compose([
# you can add other transformations in this list
transforms.Grayscale(),
transforms.Resize(IMG_SIZE),
transforms.ToTensor()
])

img_dataset = datasets.ImageFolder(DATA_DIR_PATH, transform=transform)

evens = list(range(0, len(img_dataset), 16))
odds = list(range(1, len(img_dataset), 2))
trainset_1 = torch.utils.data.Subset(img_dataset, evens)
trainset_2 = torch.utils.data.Subset(img_dataset, odds)

img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=BATCH_SIZE, shuffle=True)



# Visualize the data

print(f'Dataset size: {len(img_dataset)} images.')

"""num_imgs_to_visualize = 1
batch = next(iter(img_dataloader))
img_batch = batch[0]
img_batch_subset = img_batch[:num_imgs_to_visualize]
print(f'Image shape {img_batch_subset.shape[1:]}')
grid = make_grid(img_batch_subset, nrow=int(np.sqrt(num_imgs_to_visualize)), normalize=True, pad_value=1.)
grid = np.moveaxis(grid.numpy(), 0, 2) # from CHW -> HWC format that's what matplotlib expects! Get used to this.
plt.figure(figsize=(6, 6))
plt.title("Samples from the NIH_CXR dataset")
plt.imshow(grid)
plt.show()"""

# Size of the generator's input vector.
LATENT_SPACE_DIM = 100


# This one will produce a batch of those vectors
def get_gaussian_latent_batch(batch_size, device):
return torch.randn((batch_size, LATENT_SPACE_DIM), device=device)


def vanilla_block(in_feat, out_feat, normalize=True, activation=None):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat))
layers.append(nn.LeakyReLU(0.2) if activation is None else activation)
return layers

class GeneratorNet(torch.nn.Module):
def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
super().__init__()
self.generated_img_shape = img_shape
num_neurons_per_layer = [LATENT_SPACE_DIM, 256, 512, 1024, img_shape[0] * img_shape[1]]

self.net = nn.Sequential(
*vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1]),
*vanilla_block(num_neurons_per_layer[1], num_neurons_per_layer[2]),
*vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3]),
*vanilla_block(num_neurons_per_layer[3], num_neurons_per_layer[4], normalize=False, activation=nn.Tanh())
)

def forward(self, latent_vector_batch):
img_batch_flattened = self.net(latent_vector_batch)
return img_batch_flattened.view(img_batch_flattened.shape[0], 1, *self.generated_img_shape)

class DiscriminatorNet(torch.nn.Module):
def __init__(self, img_shape=(IMG_SIZE, IMG_SIZE)):
super().__init__()
num_neurons_per_layer = [img_shape[0] * img_shape[1], 512, 256, 1]

# Last layer is Sigmoid function - basically the goal of the discriminator is to output 1.
# for real images and 0. for fake images and sigmoid is clamped between 0 and 1 so it's perfect.
self.net = nn.Sequential(
*vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1], normalize=False),
*vanilla_block(num_neurons_per_layer[1], num_neurons_per_layer[2], normalize=False),
*vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3], normalize=False, activation=nn.Sigmoid())
)

def forward(self, img_batch):
img_batch_flattened = img_batch.view(img_batch.shape[0], -1) # flatten from (N,1,H,W) into (N, HxW)
return self.net(img_batch_flattened)

def get_optimizers(d_net, g_net):
d_opt = Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_opt = Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
return d_opt, g_opt

torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

discriminator_net = DiscriminatorNet().train().to(device)
generator_net = GeneratorNet().train().to(device)

discriminator_opt, generator_opt = get_optimizers(discriminator_net, generator_net)

adversarial_loss = nn.BCELoss()
real_images_gt = torch.ones((BATCH_SIZE, 1), device=device)
fake_images_gt = torch.zeros((BATCH_SIZE, 1), device=device)

checkpoint_freq = 2
console_log_freq = 50

num_epochs = 5

ts = time.time()

def train_GAN():
for epoch in range(num_epochs):
for batch_idx, (real_images, _) in enumerate(img_dataloader):

real_images = real_images.to(device)

discriminator_opt.zero_grad()

real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

fake_images = generator_net(get_gaussian_latent_batch(BATCH_SIZE, device))
fake_images_predictions = discriminator_net(fake_images.detach())
fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

discriminator_loss = real_discriminator_loss + fake_discriminator_loss
discriminator_loss.backward()
discriminator_opt.step()


generator_opt.zero_grad()
generated_images_predictions = discriminator_net(generator_net(get_gaussian_latent_batch(BATCH_SIZE, device)))
generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

generator_loss.backward()
generator_opt.step()

if batch_idx % console_log_freq == 0:
prefix = 'GAN training: time elapsed'
print(
f'{prefix} = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(img_dataloader)}]')

# Save generator checkpoint
if (epoch + 1) % checkpoint_freq == 0 and batch_idx == 0:
ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
torch.save(generator_net.state_dict(), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

# Save the latest generator in the binaries directory
torch.save(generator_net.state_dict(), MODEL_PATH)

train_GAN()

def postprocess_generated_img(generated_img_tensor):
assert isinstance(generated_img_tensor,
torch.Tensor), f'Expected PyTorch tensor but got {type(generated_img_tensor)}.'

generated_img = np.moveaxis(generated_img_tensor.to('cpu').numpy()[0], 0, 2)

generated_img = np.repeat(generated_img, 3, axis=2)

generated_img -= np.min(generated_img)
generated_img /= np.max(generated_img)

return generated_img

def generate_from_random_latent_vector(generator):
with torch.no_grad(): # Tells PyTorch not to compute gradients which would have huge memory footprint

# Generate a single random (latent) vector
latent_vector = get_gaussian_latent_batch(1, next(generator.parameters()).device)

# Post process generator output (as it's in the [-1, 1] range, remember?)
generated_img = postprocess_generated_img(generator(latent_vector))

return generated_img

def save_and_maybe_display_image(dump_img, out_res=(256, 256), should_display=False):
assert isinstance(dump_img, np.ndarray), f'Expected numpy array got {type(dump_img)}.'

os.makedirs(GENERATED_IMAGES_PATH, exist_ok=True)

dump_img_name = "new_image.jpg"

if dump_img.dtype != np.uint8:
dump_img = (dump_img * 255).astype(np.uint8)

cv.imwrite(os.path.join(GENERATED_IMAGES_PATH, dump_img_name),
cv.resize(dump_img[:, :, ::-1], out_res, interpolation=cv.INTER_NEAREST))

if should_display:
plt.imshow(dump_img)
plt.show()

def generate_sample_image():
assert os.path.exists(MODEL_PATH), f'Could not find the model {MODEL_PATH}. You first need to train your generator.'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = GeneratorNet().to(device)

generator.load_state_dict(torch.load(MODEL_PATH))
generator.eval()

print('Generating new images!')
generated_img = generate_from_random_latent_vector(generator)
save_and_maybe_display_image(generated_img, should_display=True)

generate_sample_image()
Binary file added models/.DS_Store
Binary file not shown.
Binary file added models/binaries/.DS_Store
Binary file not shown.
Binary file added models/checkpoints/.DS_Store
Binary file not shown.
Binary file added runs/.DS_Store
Binary file not shown.

0 comments on commit 5fe8833

Please sign in to comment.