Skip to content

Commit

Permalink
[TRAINING] Adding 32x32 image training and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Hussem Ben Belgacem committed Jan 20, 2021
1 parent eba8738 commit ef37a0e
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 21 deletions.
16 changes: 12 additions & 4 deletions colorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from skimage.transform import resize
from torch import from_numpy

from src.models import Generator
from src.models import Generator32, Generator256


def preprocess(image):
Expand All @@ -34,21 +34,29 @@ def postprocess(in_image, prediction):
return predicted_image


def getModel(image_size):
if image_size == 32:
return Generator32()
return Generator256()


def main(config):
image_path = config['image_path']
generator_path = config['generator_path']
save_path = config['save_path']
height, width = config['image_size'][0], config['image_size'][1]
image_size = config['image_size']

assert image_size == 32 or image_size == 256, "image_size should be equal to 32 or 256 for the training :("

image = io.imread(image_path)
if len(image.shape) == 2:
image = color.gray2rgb(image)

image = image[:, :, :3]
original_image = resize(image, (height, width))
original_image = resize(image, (image_size, image_size))
image = preprocess(original_image)

model = Generator().double()
model = getModel(image_size).double()
model.load_state_dict(
torch.load(
generator_path,
Expand Down
2 changes: 1 addition & 1 deletion config/colorize.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"image_path": "",
"generator_path": "",
"save_path": "",
"image_size": [256, 256]
"image_size": 256
}
2 changes: 1 addition & 1 deletion config/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"betas": [0.5, 0.999],
"epochs": 200,
"lambda": 100,
"image_size": [256, 256],
"image_size": 256,
"save_path": "",
"train_percentage": 0.8,
"test_percentage": 0.1,
Expand Down
2 changes: 1 addition & 1 deletion src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .dcgan import Generator, Discriminator
from .dcgan import Discriminator32, Discriminator256, Generator32, Generator256
77 changes: 73 additions & 4 deletions src/models/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,78 @@ def _conv(in_channels, out_channels, stride=1, kernel_size=3, padding=1):
)


class Generator(nn.Module):
class Generator32(nn.Module):
def __init__(self):
super(Generator, self).__init__()
super(Generator32, self).__init__()

self.down_sample1 = _down_sample(1, 64, kernel_size=3, stride=1)
self.down_sample2 = _down_sample(64, 128)
self.down_sample3 = _down_sample(128, 256)
self.down_sample4 = _down_sample(256, 512)
self.down_sample5 = _down_sample(512, 512)

self.up_sample1 = _up_sample(512, 512)
self.up_sample2 = _up_sample(512, 256)
self.up_sample3 = _up_sample(256, 128)
self.up_sample4 = _up_sample(128, 64)

self.conv1 = _conv(1024, 512)
self.conv2 = _conv(512, 256)
self.conv3 = _conv(256, 128)
self.conv4 = _conv(128, 64)

self.output = nn.Sequential(
nn.Conv2d(64, 2, 1, 1),
nn.Tanh()
)

def forward(self, x):
x1 = self.down_sample1(x)
x2 = self.down_sample2(x1)
x3 = self.down_sample3(x2)
x4 = self.down_sample4(x3)
x5 = self.down_sample5(x4)

x = self.conv1(torch.cat([x4, self.up_sample1(x5)], 1))
x = self.conv2(torch.cat([x3, self.up_sample2(x)], 1))
x = self.conv3(torch.cat([x2, self.up_sample3(x)], 1))
x = self.conv4(torch.cat([x1, self.up_sample4(x)], 1))

x = self.output(x)

return x


class Discriminator32(nn.Module):
def __init__(self):
super(Discriminator32, self).__init__()

self.down_sample1 = _down_sample(3, 64)
self.down_sample2 = _down_sample(64, 128)
self.down_sample3 = _down_sample(128, 256)
self.down_sample4 = _down_sample(256, 512, stride=1, kernel_size=3)

self.conv = nn.Conv2d(512, 1, 1, 1)

self.output = nn.Linear(4 * 4, 1)

def forward(self, x):
x = self.down_sample1(x)
x = self.down_sample2(x)
x = self.down_sample3(x)
x = self.down_sample4(x)

x = self.conv(x)
x = x.view(-1, 4 * 4)
x = self.output(x)
x = x.squeeze(-1)

return x


class Generator256(nn.Module):
def __init__(self):
super(Generator256, self).__init__()

self.down_sample1 = _down_sample(1, 64, kernel_size=3, stride=1)
self.down_sample2 = _down_sample(64, 64)
Expand Down Expand Up @@ -102,9 +171,9 @@ def forward(self, x):
return x


class Discriminator(nn.Module):
class Discriminator256(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
super(Discriminator256, self).__init__()

self.down_sample1 = _down_sample(3, 64)
self.down_sample2 = _down_sample(64, 128)
Expand Down
2 changes: 0 additions & 2 deletions src/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def __init__(self, g_model, d_model, g_optimizer, d_optimizer, config,
self.epochs = config['epochs']
self.l1_lambda = config['lambda']
self.save_path = config['save_path']
self.height = config['image_size'][0]
self.width = config['image_size'][1]
self.early_stop_patience = config['early_stop_patience']
self.log_loss_d = []
self.log_loss_g = []
Expand Down
23 changes: 15 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from torch.utils.data import DataLoader, dataloader

from src.datasets import ImageColorizationDataset
from src.models import Discriminator, Generator
from src.models import (Discriminator32, Discriminator256, Generator32,
Generator256)
from src.tester import DCGANTester
from src.trainer import DCGANTrainer
from src.utils import RGB2LAB, NormalizeImage, Resize, ToTensor
Expand All @@ -33,9 +34,6 @@ def my_collate(batch):


def splitData(data_path, save_path, train, test, shuffle):
assert os.path.exists(data_path), "data_path given to splitData doesn't exists :("
assert train + test < 1, "train percentage and test percentage should summup to be < 1 to keep some data for validation :("

data = glob.glob(os.path.join(data_path, '*'))

if shuffle is True:
Expand All @@ -56,18 +54,28 @@ def splitData(data_path, save_path, train, test, shuffle):
return train_data, test_data, validation_data


def getModel(image_size, device):
if image_size == 32:
return Generator32().to(device), Discriminator32().to(device)
return Generator256().to(device), Discriminator256().to(device)


def main(config):
batch_size = config['batch_size']
betas = config['betas']
data_path = config['data_path']
height, width = config['image_size'][0], config['image_size'][1]
image_size = config['image_size']
learning_rate = config['learning_rate']
save_path = config['save_path']
shuffle_data = config['shuffle_data']
test_model = config['test_model']
train_percentage = config['train_percentage']
test_percentage = config['test_percentage']

assert os.path.exists(data_path), "data_path given to splitData doesn't exists :("
assert train_percentage + test_percentage < 1, "train percentage and test percentage should summup to be < 1 to keep some data for validation :("
assert image_size == 32 or image_size == 256, "image_size should be equal to 32 or 256 for the training :("

train_data, test_data, validation_data = splitData(
data_path=data_path,
save_path=save_path,
Expand All @@ -77,7 +85,7 @@ def main(config):
)

transforms = torchvision.transforms.Compose([
Resize(size=(height, width)),
Resize(size=(image_size, image_size)),
RGB2LAB(),
NormalizeImage(),
ToTensor()
Expand Down Expand Up @@ -113,8 +121,7 @@ def main(config):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

g_model = Generator().to(device)
d_model = Discriminator().to(device)
g_model, d_model = getModel(image_size, device)

g_optimizer = Adam(
params=list(g_model.parameters()),
Expand Down

0 comments on commit ef37a0e

Please sign in to comment.