diff --git a/.gitignore b/.gitignore index d13b64e..bc88db3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ # Edit at https://www.toptal.com/developers/gitignore?templates=python,opencv ### Custom ### +weight/ *.mp4 ### OpenCV ### diff --git a/README.md b/README.md index ef68615..4ccb56e 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,22 @@ For this project, FGVC(Flow-edge Guided Video Completion) deep learning model wa | scipy | 1.6.2 | +# Usage + +```sh +# Remove __pycache__. +$ find . | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf + +# Run video inpainting. +$ python run_inpainting.py \ +>> --path 'D:/_data/f250_color' \ +>> --path_mask 'D:/_data/f250_mask' \ +>> --outroot 'D:/_data/f250_result_final' \ +>> --merge \ +>> --run +``` + + # Update History ### v1.2 @@ -58,4 +74,4 @@ For this project, FGVC(Flow-edge Guided Video Completion) deep learning model wa
--- -**Updated :** 2021-11-20 22:42 +**Updated :** 2021-11-27 03:25 diff --git a/modules/DeepFill/DeepFill.py b/modules/DeepFill/DeepFill.py new file mode 100644 index 0000000..81a3039 --- /dev/null +++ b/modules/DeepFill/DeepFill.py @@ -0,0 +1,84 @@ +from .ops import * + + +class Generator(nn.Module): + def __init__(self, first_dim=32, isCheck=False, device=None): + super(Generator, self).__init__() + self.isCheck = isCheck + self.device = device + self.stage_1 = CoarseNet(5, first_dim, device=device) + self.stage_2 = RefinementNet(5, first_dim, device=device) + + def forward(self, masked_img, mask, small_mask): # mask : 1 x 1 x H x W + + # border, maybe + mask = mask.expand(masked_img.size(0),1,masked_img.size(2),masked_img.size(3)) + small_mask = small_mask.expand(masked_img.size(0), 1, masked_img.size(2) // 8, masked_img.size(3) // 8) + if self.device: + ones = to_var(torch.ones(mask.size()), device=self.device) + else: + ones = to_var(torch.ones(mask.size())) + # stage1 + stage1_input = torch.cat([masked_img, ones, ones*mask], dim=1) + stage1_output, resized_mask = self.stage_1(stage1_input, mask) + # stage2 + new_masked_img = stage1_output*mask.clone() + masked_img.clone()*(1.-mask.clone()) + stage2_input = torch.cat([new_masked_img, ones.clone(), ones.clone()*mask.clone()], dim=1) + stage2_output, offset_flow = self.stage_2(stage2_input, small_mask) + + return stage1_output, stage2_output, offset_flow + + +class CoarseNet(nn.Module): + ''' + # input: B x 5 x W x H + # after down: B x 128(32*4) x W/4 x H/4 + # after atrous: same with the output size of the down module + # after up : same with the input size + ''' + def __init__(self, in_ch, out_ch, device=None): + super(CoarseNet,self).__init__() + self.down = Down_Module(in_ch, out_ch) + self.atrous = Dilation_Module(out_ch*4, out_ch*4) + self.up = Up_Module(out_ch*4, 3) + self.device=device + + def forward(self, x, mask): + x = self.down(x) + resized_mask = down_sample(mask, scale_factor=0.25, mode='nearest', device=self.device) + x = self.atrous(x) + x = self.up(x) + + return x, resized_mask + + +class RefinementNet(nn.Module): + ''' + # input: B x 5 x W x H + # after down: B x 128(32*4) x W/4 x H/4 + # after atrous: same with the output size of the down module + # after up : same with the input size + ''' + def __init__(self, in_ch, out_ch, device=None): + super(RefinementNet,self).__init__() + self.down_conv_branch = Down_Module(in_ch, out_ch, isRefine=True) + self.down_attn_branch = Down_Module(in_ch, out_ch, activation=nn.ReLU(), isRefine=True, isAttn=True) + self.atrous = Dilation_Module(out_ch*4, out_ch*4) + self.CAttn = Contextual_Attention_Module(out_ch*4, out_ch*4, device=device) + self.up = Up_Module(out_ch*8, 3, isRefine=True) + + def forward(self, x, resized_mask): + # conv branch + conv_x = self.down_conv_branch(x) + conv_x = self.atrous(conv_x) + + # attention branch + attn_x = self.down_attn_branch(x) + + attn_x, offset_flow = self.CAttn(attn_x, attn_x, mask=resized_mask) + + # concat two branches + deconv_x = torch.cat([conv_x, attn_x], dim=1) # deconv_x => B x 256 x W/4 x H/4 + x = self.up(deconv_x) + + return x, offset_flow diff --git a/modules/DeepFill/__init__.py b/modules/DeepFill/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/DeepFill/ops.py b/modules/DeepFill/ops.py new file mode 100644 index 0000000..bf13b76 --- /dev/null +++ b/modules/DeepFill/ops.py @@ -0,0 +1,431 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from 
torch.autograd import Variable + + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find( + 'Linear') == 0) and hasattr(m, 'weight'): + if init_type == 'gaussian': + nn.init.normal_(m.weight, 0.0, 0.02) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + return init_fun + + +class Conv(nn.Module): + def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU(), isGated=False): + super(Conv, self).__init__() + if activation is not None: + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D), + activation + ) + else: + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D) + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.apply(weights_init('kaiming')) + + def forward(self, x): + x = self.conv(x) + return x + + +class Conv_Downsample(nn.Module): + def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU()): + super(Conv_Downsample, self).__init__() + + PaddingLayer = torch.nn.ZeroPad2d((0, (K-1)//2, 0, (K-1)//2)) + + if activation is not None: + self.conv = nn.Sequential( + PaddingLayer, + nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D), + activation + ) + else: + self.conv = nn.Sequential( + PaddingLayer, + nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D) + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.apply(weights_init('kaiming')) + + def forward(self, x): + x = self.conv(x) + return x + + +class Down_Module(nn.Module): + def __init__(self, in_ch, out_ch, activation=nn.ELU(), isRefine=False, + isAttn=False, ): + super(Down_Module, self).__init__() + layers = [] + layers.append(Conv(in_ch, out_ch, K=5, P=2)) + # curr_dim = out_ch + # layers.append(Conv_Downsample(curr_dim, curr_dim * 2, K=3, S=2, isGated=isGated)) + + curr_dim = out_ch + if isRefine: + if isAttn: + layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2)) + layers.append(Conv(curr_dim, 2*curr_dim, K=3, S=1)) + layers.append(Conv_Downsample(2*curr_dim, 4*curr_dim, K=3, S=2)) + layers.append(Conv(4 * curr_dim, 4 * curr_dim, K=3, S=1)) + curr_dim *= 4 + else: + for i in range(2): + layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2)) + layers.append(Conv(curr_dim, curr_dim*2)) + curr_dim *= 2 + else: + for i in range(2): + layers.append(Conv_Downsample(curr_dim, curr_dim*2, K=3, S=2)) + layers.append(Conv(curr_dim * 2, curr_dim * 2)) + curr_dim *= 2 + + layers.append(Conv(curr_dim, curr_dim, activation=activation)) + + self.out = nn.Sequential(*layers) + + def forward(self, x): + return self.out(x) + + +class Dilation_Module(nn.Module): + def __init__(self, in_ch, out_ch): + super(Dilation_Module, self).__init__() + layers = [] + dilation = 1 + for i in range(4): + dilation *= 2 + layers.append(Conv(in_ch, out_ch, D=dilation, P=dilation)) + self.out = nn.Sequential(*layers) + + def forward(self, x): + return self.out(x) + + +class Up_Module(nn.Module): + def __init__(self, in_ch, out_ch, isRefine=False): + super(Up_Module, 
self).__init__() + layers = [] + curr_dim = in_ch + if isRefine: + layers.append(Conv(curr_dim, curr_dim//2)) + curr_dim //= 2 + else: + layers.append(Conv(curr_dim, curr_dim)) + + # conv 12~15 + for i in range(2): + layers.append(Conv(curr_dim, curr_dim)) + layers.append(nn.Upsample(scale_factor=2, mode='nearest')) + layers.append(Conv(curr_dim, curr_dim//2)) + curr_dim //= 2 + + layers.append(Conv(curr_dim, curr_dim//2)) + layers.append(Conv(curr_dim//2, out_ch, activation=None)) + + self.out = nn.Sequential(*layers) + + def forward(self, x): + output = self.out(x) + return torch.clamp(output, min=-1., max=1.) + + +class Up_Module_CNet(nn.Module): + def __init__(self, in_ch, out_ch, isRefine=False, isGated=False): + super(Up_Module_CNet, self).__init__() + layers = [] + curr_dim = in_ch + if isRefine: + layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) + curr_dim //= 2 + else: + layers.append(Conv(curr_dim, curr_dim, isGated=isGated)) + + # conv 12~15 + for i in range(2): + layers.append(Conv(curr_dim, curr_dim, isGated=isGated)) + layers.append(nn.Upsample(scale_factor=2, mode='nearest')) + layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) + curr_dim //= 2 + + layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) + layers.append(Conv(curr_dim//2, out_ch, activation=None, isGated=isGated)) + + self.out = nn.Sequential(*layers) + + def forward(self, x): + output = self.out(x) + return output + + +class Flatten_Module(nn.Module): + def __init__(self, in_ch, out_ch, isLocal=True): + super(Flatten_Module, self).__init__() + layers = [] + layers.append(Conv(in_ch, out_ch, K=5, S=2, P=2, activation=nn.LeakyReLU())) + curr_dim = out_ch + + for i in range(2): + layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU())) + curr_dim *= 2 + + if isLocal: + layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU())) + else: + layers.append(Conv(curr_dim, curr_dim, K=5, S=2, P=2, activation=nn.LeakyReLU())) + + self.out = nn.Sequential(*layers) + + def forward(self, x): + x = self.out(x) + return x.view(x.size(0),-1) # 2B x 256*(256 or 512); front 256:16*16 + + +class Contextual_Attention_Module(nn.Module): + def __init__(self, in_ch, out_ch, rate=2, stride=1, isCheck=False, device=None): + super(Contextual_Attention_Module, self).__init__() + self.rate = rate + self.padding = nn.ZeroPad2d(1) + self.up_sample = nn.Upsample(scale_factor=self.rate, mode='nearest') + layers = [] + for i in range(2): + layers.append(Conv(in_ch, out_ch)) + self.out = nn.Sequential(*layers) + self.isCheck = isCheck + self.device = device + + def forward(self, f, b, mask=None, ksize=3, stride=1, + fuse_k=3, softmax_scale=10., training=True, fuse=True): + + """ Contextual attention layer implementation. + + Contextual attention is first introduced in publication: + Generative Image Inpainting with Contextual Attention, Yu et al. + + Args: + f: Input feature to match (foreground). + b: Input feature for match (background). + mask: Input mask for b, indicating patches not available. + ksize: Kernel size for contextual attention. + stride: Stride for extracting patches from b. + rate: Dilation for matching. + softmax_scale: Scaled softmax for attention. + training: Indicating if current graph is training or inference. 
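+            fuse_k: Kernel size of the convolution that fuses neighbouring attention scores.
+            fuse: If True, fuse scores with an identity kernel to encourage copying large, coherent patches.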
+ + Returns: + tf.Tensor: output + + """ + + # get shapes + raw_fs = f.size() # B x 128 x 64 x 64 + raw_int_fs = list(f.size()) + raw_int_bs = list(b.size()) + + # extract patches from background with stride and rate + kernel = 2*self.rate + raw_w = self.extract_patches(b, kernel=kernel, stride=self.rate) + raw_w = raw_w.permute(0, 2, 3, 4, 5, 1) + raw_w = raw_w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) + raw_w = raw_w.contiguous().view(raw_int_bs[0], -1, kernel, kernel, raw_int_bs[1]) + raw_w = raw_w.permute(0, 1, 4, 2, 3) + + f = down_sample(f, scale_factor=1/self.rate, mode='nearest', device=self.device) + b = down_sample(b, scale_factor=1/self.rate, mode='nearest', device=self.device) + + fs = f.size() # B x 128 x 32 x 32 + int_fs = list(f.size()) + f_groups = torch.split(f, 1, dim=0) # Split tensors by batch dimension; tuple is returned + + # from b(B*H*W*C) to w(b*k*k*c*h*w) + bs = b.size() # B x 128 x 32 x 32 + int_bs = list(b.size()) + w = self.extract_patches(b) + w = w.permute(0, 2, 3, 4, 5, 1) + w = w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) + w = w.contiguous().view(raw_int_bs[0], -1, ksize, ksize, raw_int_bs[1]) + w = w.permute(0, 1, 4, 2, 3) + # process mask + mask = mask.clone() + if mask is not None: + if mask.size(2) != b.size(2): + mask = down_sample(mask, scale_factor=1./self.rate, mode='nearest', device=self.device) + else: + mask = torch.zeros([1, 1, bs[2], bs[3]]) + + m = self.extract_patches(mask) + + m = m.permute(0, 2, 3, 4, 5, 1) + m = m.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) + m = m.contiguous().view(raw_int_bs[0], -1, ksize, ksize, 1) + m = m.permute(0, 4, 1, 2, 3) + + m = m[0] # (1, 32*32, 3, 3) + m = reduce_mean(m) # smoothing, maybe + mm = m.eq(0.).float() # (1, 32*32, 1, 1) + + w_groups = torch.split(w, 1, dim=0) # Split tensors by batch dimension; tuple is returned + raw_w_groups = torch.split(raw_w, 1, dim=0) # Split tensors by batch dimension; tuple is returned + y = [] + offsets = [] + k = fuse_k + scale = softmax_scale + fuse_weight = Variable(torch.eye(k).view(1, 1, k, k)).cuda(self.device) # 1 x 1 x K x K + y_test = [] + for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups): + ''' + O => output channel as a conv filter + I => input channel as a conv filter + xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32) + wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3) + raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4) + ''' + # conv for compare + wi = wi[0] + escape_NaN = Variable(torch.FloatTensor([1e-4])).cuda(self.device) + wi_normed = wi / torch.max(l2_norm(wi), escape_NaN) + yi = F.conv2d(xi, wi_normed, stride=1, padding=1) # yi => (B=1, C=32*32, H=32, W=32) + y_test.append(yi) + # conv implementation for fuse scores to encourage large patches + if fuse: + yi = yi.permute(0, 2, 3, 1) + yi = yi.contiguous().view(1, fs[2] * fs[3], bs[2] * bs[3], 1) + yi = yi.permute(0, 3, 1, 2) # make all of depth to spatial resolution, (B=1, I=1, H=32*32, W=32*32) + yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) # (B=1, C=1, H=32*32, W=32*32) + + yi = yi.permute(0, 2, 3, 1) + yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) + # yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) # (B=1, 32, 32, 32, 32) + yi = yi.permute(0, 2, 1, 4, 3) + yi = yi.contiguous().view(1, fs[2] * fs[3], 
bs[2] * bs[3], 1) + yi = yi.permute(0, 3, 1, 2) + + yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) + yi = yi.permute(0, 2, 3, 1) + yi = yi.contiguous().view(1, fs[3], fs[2], bs[3], bs[2]) + yi = yi.permute(0, 2, 1, 4, 3) + yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3]) + yi = yi.permute(0, 3, 1, 2) + else: + yi = yi.permute(0, 2, 3, 1) + yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3]) + yi = yi.permute(0, 3, 1, 2) # (B=1, C=32*32, H=32, W=32) + # yi = yi.contiguous().view(1, bs[2] * bs[3], fs[2], fs[3]) + + # softmax to match + yi = yi * mm # mm => (1, 32*32, 1, 1) + yi = F.softmax(yi*scale, dim=1) + yi = yi * mm # mask + + _, offset = torch.max(yi, dim=1) # argmax; index + division = torch.true_divide(offset, fs[3]).long() + offset = torch.stack([division, torch.true_divide(offset, fs[3])-division], dim=-1) + + wi_center = raw_wi[0] + + yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64) + y.append(yi) + offsets.append(offset) + + y = torch.cat(y, dim=0) # back to the mini-batch + y.contiguous().view(raw_int_fs) + # wi_patched = y + offsets = torch.cat(offsets, dim=0) + offsets = offsets.view([int_bs[0]] + [2] + int_bs[2:]) + + # case1: visualize optical flow: minus current position + h_add = Variable(torch.arange(0,float(bs[2]))).cuda(self.device).view([1, 1, bs[2], 1]) + h_add = h_add.expand(bs[0], 1, bs[2], bs[3]) + w_add = Variable(torch.arange(0,float(bs[3]))).cuda(self.device).view([1, 1, 1, bs[3]]) + w_add = w_add.expand(bs[0], 1, bs[2], bs[3]) + + offsets = offsets - torch.cat([h_add, w_add], dim=1).long() + + # # case2: visualize which pixels are attended + # flow = torch.from_numpy(highlight_flow((offsets * mask.int()).numpy())) + y = self.out(y) + + return y, offsets + + def extract_patches(self, x, kernel=3, stride=1): + x = self.padding(x) + all_patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride) + + return all_patches + + +def reduce_mean(x): + for i in range(4): + if i==1: continue + x = torch.mean(x, dim=i, keepdim=True) + return x + + +def l2_norm(x): + def reduce_sum(x): + for i in range(4): + if i==0: continue + x = torch.sum(x, dim=i, keepdim=True) + return x + + x = x**2 + x = reduce_sum(x) + return torch.sqrt(x) + + +def down_sample(x, size=None, scale_factor=None, mode='nearest', device=None): + # define size if user has specified scale_factor + if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3))) + # create coordinates + # size_origin = [x.size[2], x.size[3]] + h = torch.true_divide(torch.arange(0, size[0]), (size[0])) * 2 - 1 + w = torch.true_divide(torch.arange(0, size[1]), (size[1])) * 2 - 1 + # create grid + grid = torch.zeros(size[0],size[1],2) + grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1) + grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1) + # expand to match batch size + grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1) + if x.is_cuda: + if device: + grid = Variable(grid).cuda(device) + else: + grid = Variable(grid).cuda() + # do sampling + + return F.grid_sample(x, grid, mode=mode) + + +def to_var(x, volatile=False, device=None): + if torch.cuda.is_available(): + if device: + x = x.cuda(device) + else: + x = x.cuda() + return Variable(x, volatile=volatile) diff --git a/modules/DeepFill/util.py b/modules/DeepFill/util.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/EdgeConnect/__init__.py b/modules/EdgeConnect/__init__.py new file mode 100644 index 0000000..bc63beb --- /dev/null +++ 
b/modules/EdgeConnect/__init__.py @@ -0,0 +1 @@ +# empty \ No newline at end of file diff --git a/modules/EdgeConnect/config.py b/modules/EdgeConnect/config.py new file mode 100644 index 0000000..523699d --- /dev/null +++ b/modules/EdgeConnect/config.py @@ -0,0 +1,64 @@ +import os +import yaml + +class Config(dict): + def __init__(self, config_path): + with open(config_path, 'r') as f: + self._yaml = f.read() + self._dict = yaml.load(self._yaml) + self._dict['PATH'] = os.path.dirname(config_path) + + def __getattr__(self, name): + if self._dict.get(name) is not None: + return self._dict[name] + + if DEFAULT_CONFIG.get(name) is not None: + return DEFAULT_CONFIG[name] + + return None + + def print(self): + print('Model configurations:') + print('---------------------------------') + print(self._yaml) + print('') + print('---------------------------------') + print('') + + +DEFAULT_CONFIG = { + 'MODE': 1, # 1: train, 2: test, 3: eval + 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model + 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) + 'EDGE': 1, # 1: canny, 2: external + 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny + 'SEED': 10, # random seed + 'GPU': [0], # list of gpu ids + 'DEBUG': 0, # turns on debugging mode + 'VERBOSE': 0, # turns on verbose mode in the output console + + 'LR': 0.0001, # learning rate + 'D2G_LR': 0.1, # discriminator/generator learning rate ratio + 'BETA1': 0.0, # adam optimizer beta1 + 'BETA2': 0.9, # adam optimizer beta2 + 'BATCH_SIZE': 8, # input batch size for training + 'INPUT_SIZE': 256, # input image size for training 0 for original size + 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) + 'MAX_ITERS': 2e6, # maximum number of iterations to train the model + + 'EDGE_THRESHOLD': 0.5, # edge detection threshold + 'L1_LOSS_WEIGHT': 1, # l1 loss weight + 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight + 'STYLE_LOSS_WEIGHT': 1, # style loss weight + 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight + 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight + + 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge + 'GAN_POOL_SIZE': 0, # fake images pool size + + 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never) + 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never) + 'SAMPLE_SIZE': 12, # number of images to sample + 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never) + 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never) +} diff --git a/modules/EdgeConnect/dataset.py b/modules/EdgeConnect/dataset.py new file mode 100644 index 0000000..5dfc69f --- /dev/null +++ b/modules/EdgeConnect/dataset.py @@ -0,0 +1,404 @@ +import os +import cv2 +import glob +import scipy +import torch +import random +import numpy as np +import torchvision.transforms.functional as F +from torch.utils.data import DataLoader +from PIL import Image +from scipy.misc import imread +from skimage.feature import canny +from skimage.color import rgb2gray, gray2rgb +from .utils import create_mask +import src.region_fill as rf + +class Dataset(torch.utils.data.Dataset): + def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True): + super(Dataset, self).__init__() + + self.augment = augment + self.training = training + 
self.flo = config.FLO + self.norm = config.NORM + self.data = self.load_flist(flist, self.flo) + self.edge_data = self.load_flist(edge_flist, 0) + self.mask_data = self.load_flist(mask_flist, 0) + + self.input_size = config.INPUT_SIZE + self.sigma = config.SIGMA + self.edge = config.EDGE + self.mask = config.MASK + self.nms = config.NMS + + + + # in test mode, there's a one-to-one relationship between mask and image + # masks are loaded non random + if config.MODE == 2: + self.mask = 6 + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + try: + item = self.load_item(index) + except: + print('loading error: ' + self.data[index]) + item = self.load_item(0) + + return item + + def load_name(self, index): + name = self.data[index] + return os.path.basename(name) + + def load_item(self, index): + size = self.input_size + factor = 1. + if self.flo == 0: + + # load image + img = imread(self.data[index]) + + # gray to rgb + if len(img.shape) < 3: + img = gray2rgb(img) + + # resize/crop if needed + if size != 0: + img = self.resize(img, size[0], size[1]) + + # create grayscale image + img_gray = rgb2gray(img) + + # load mask + mask = self.load_mask(img, index) + + edge = self.load_edge(img_gray, index, mask) + + img_filled = img + + else: + + img = self.readFlow(self.data[index]) + + # resize/crop if needed + if size != 0: + img = self.flow_tf(img, [size[0], size[1]]) + + img_gray = (img[:, :, 0] ** 2 + img[:, :, 1] ** 2) ** 0.5 + + if self.norm == 1: + # normalization + # factor = (np.abs(img[:, :, 0]).max() ** 2 + np.abs(img[:, :, 1]).max() ** 2) ** 0.5 + factor = img_gray.max() + img /= factor + + # load mask + mask = self.load_mask(img, index) + + edge = self.load_edge(img_gray, index, mask) + img_gray = img_gray / img_gray.max() + + img_filled = np.zeros(img.shape) + img_filled[:, :, 0] = rf.regionfill(img[:, :, 0], mask) + img_filled[:, :, 1] = rf.regionfill(img[:, :, 1], mask) + + + # augment data + if self.augment and np.random.binomial(1, 0.5) > 0: + img = img[:, ::-1, ...].copy() + img_filled = img_filled[:, ::-1, ...].copy() + img_gray = img_gray[:, ::-1, ...] + edge = edge[:, ::-1, ...] + mask = mask[:, ::-1, ...] 
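+        # Returned tuple: img (RGB image, or raw 2-channel flow divided by `factor` when NORM == 1),
+        # img_filled (the image itself, or the masked flow completed with region_fill for .flo inputs),
+        # img_gray (grayscale image or normalised flow magnitude used for Canny), edge, mask,
+        # and the normalisation factor (1.0 unless flow normalisation is enabled).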
+ + return self.to_tensor(img), self.to_tensor(img_filled), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask), factor + + def load_edge(self, img, index, mask): + sigma = self.sigma + + # in test mode images are masked (with masked regions), + # using 'mask' parameter prevents canny to detect edges for the masked regions + mask = None if self.training else (1 - mask / 255).astype(np.bool) + + # canny + if self.edge == 1: + # no edge + if sigma == -1: + return np.zeros(img.shape).astype(np.float) + + # random sigma + if sigma == 0: + sigma = random.randint(1, 4) + return canny(img, sigma=sigma, mask=mask).astype(np.float) + + # external + else: + imgh, imgw = img.shape[0:2] + edge = imread(self.edge_data[index]) + edge = self.resize(edge, imgh, imgw) + + # non-max suppression + if self.nms == 1: + edge = edge * canny(img, sigma=sigma, mask=mask) + + return edge + + def load_mask(self, img, index): + imgh, imgw = img.shape[0:2] + mask_type = self.mask + + # external + random block + if mask_type == 4: + mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3 + + # external + random block + half + elif mask_type == 5: + mask_type = np.random.randint(1, 4) + + # random block + if mask_type == 1: + return create_mask(imgw, imgh, imgw // 2, imgh // 2) + + # half + if mask_type == 2: + # randomly choose right or left + return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0) + + # external + if mask_type == 3: + mask_index = random.randint(0, len(self.mask_data) - 1) + mask = imread(self.mask_data[mask_index]) + mask = self.resize(mask, imgh, imgw, centerCrop=False) + mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation + return mask + + # test mode: load mask non random + + if mask_type == 6: + mask = imread(self.mask_data[index]) + mask = self.resize(mask, imgh, imgw, centerCrop=False) + mask = rgb2gray(mask) + mask = (mask > 0).astype(np.uint8) * 255 + return mask + + def to_tensor(self, img): + if (len(img.shape) == 3 and img.shape[2] == 2): + return F.to_tensor(img).float() + img = Image.fromarray(img) + img_t = F.to_tensor(img).float() + return img_t + + def resize(self, img, height, width, centerCrop=True): + imgh, imgw = img.shape[0:2] + + if centerCrop and imgh != imgw: + # center crop + side = np.minimum(imgh, imgw) + j = (imgh - side) // 2 + i = (imgw - side) // 2 + img = img[j:j + side, i:i + side, ...] 
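+        # scipy.misc.imresize/imread were removed in SciPy 1.3, so with the scipy 1.6.2 listed in
+        # the README this call needs an older SciPy or a replacement such as cv2.resize.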
+ + img = scipy.misc.imresize(img, [height, width]) + + return img + + def load_flist(self, flist, flo=0): + if isinstance(flist, list): + return flist + + # flist: image file path, image directory path, text file flist path + if flo == 0: + if isinstance(flist, str): + if os.path.isdir(flist): + flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) + flist.sort() + return flist + + if os.path.isfile(flist): + try: + return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') + except: + return [flist] + else: + if isinstance(flist, str): + if os.path.isdir(flist): + flist = list(glob.glob(flist + '/*.flo')) + flist.sort() + return flist + + if os.path.isfile(flist): + try: + return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') + except: + return [flist] + + return [] + + def create_iterator(self, batch_size): + while True: + sample_loader = DataLoader( + dataset=self, + batch_size=batch_size, + drop_last=True + ) + + for item in sample_loader: + yield item + + def readFlow(self, fn): + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + def flow_to_image(self, flow): + + UNKNOWN_FLOW_THRESH = 1e7 + + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + # print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. 
%.3f" % (maxrad, minu,maxu, minv, maxv)) + + u = u/(maxrad + np.finfo(float).eps) + v = v/(maxrad + np.finfo(float).eps) + + img = self.compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + + def compute_color(self, u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = self.make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u**2+v**2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a+1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols+1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel,1)): + tmp = colorwheel[:, i] + col0 = tmp[k0-1] / 255 + col1 = tmp[k1-1] / 255 + col = (1-f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1-rad[idx]*(1-col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) + + return img + + + def make_color_wheel(self): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) + colorwheel[col:col+YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) + colorwheel[col:col+CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col+MR, 0] = 255 + + return colorwheel + + def flow_tf(self, flow, size): + flow_shape = flow.shape + flow_resized = cv2.resize(flow, (size[1], size[0])) + flow_resized[:, :, 0] *= (float(size[1]) / float(flow_shape[1])) + flow_resized[:, :, 1] *= (float(size[0]) / float(flow_shape[0])) + + return flow_resized diff --git a/modules/EdgeConnect/edge_connect.py b/modules/EdgeConnect/edge_connect.py new file mode 100644 index 0000000..12aefe9 --- /dev/null +++ b/modules/EdgeConnect/edge_connect.py @@ -0,0 +1,715 @@ +import os +import numpy as np +import torch +import torchvision.utils as vutils +from torch.utils.data import DataLoader +from .dataset import Dataset +from .models import EdgeModel, InpaintingModel +from .utils import Progbar, create_dir, stitch_images, imsave +from .metrics import PSNR, EdgeAccuracy +from tensorboardX import SummaryWriter + + +class EdgeConnect(): + def __init__(self, config): + self.config = config + + if config.MODEL == 1: + model_name = 'edge' + elif config.MODEL == 2: + model_name = 'inpaint' + elif config.MODEL == 3: + model_name = 'edge_inpaint' + elif config.MODEL == 4: + model_name = 'joint' + + self.debug = False + self.model_name = model_name + self.edge_model = EdgeModel(config).to(config.DEVICE) + self.inpaint_model = InpaintingModel(config).to(config.DEVICE) + + 
self.psnr = PSNR(255.0).to(config.DEVICE) + self.edgeacc = EdgeAccuracy(config.EDGE_THRESHOLD).to(config.DEVICE) + + # test mode + if self.config.MODE == 2: + self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, config.TEST_MASK_FLIST, augment=False, training=False) + else: + self.train_dataset = Dataset(config, config.TRAIN_FLIST, config.TRAIN_EDGE_FLIST, config.TRAIN_MASK_FLIST, augment=True, training=True) + self.val_dataset = Dataset(config, config.VAL_FLIST, config.VAL_EDGE_FLIST, config.VAL_MASK_FLIST, augment=False, training=True) + self.sample_iterator = self.val_dataset.create_iterator(config.SAMPLE_SIZE) + + self.samples_path = os.path.join(config.PATH, 'samples') + self.results_path = os.path.join(config.PATH, 'results') + + if config.RESULTS is not None: + self.results_path = os.path.join(config.RESULTS) + + if config.DEBUG is not None and config.DEBUG != 0: + self.debug = True + + self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat') + + self.writer = SummaryWriter('/home/gaochen/Project/edge-connect/checkpoints/flow_NORM_FILL_MASK_RES/logs') + + def load(self): + if self.config.MODEL == 1: + self.edge_model.load() + + elif self.config.MODEL == 2: + self.inpaint_model.load() + + else: + self.edge_model.load() + self.inpaint_model.load() + + def save(self): + if self.config.MODEL == 1: + self.edge_model.save() + + elif self.config.MODEL == 2 or self.config.MODEL == 3: + self.inpaint_model.save() + + else: + self.edge_model.save() + self.inpaint_model.save() + + def train(self): + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.BATCH_SIZE, + num_workers=4, + drop_last=True, + shuffle=True + ) + # train_loader = DataLoader( + # dataset=self.train_dataset, + # batch_size=1, + # ) + epoch = 0 + keep_training = True + model = self.config.MODEL + max_iteration = int(float((self.config.MAX_ITERS))) + total = len(self.train_dataset) + + if total == 0: + print('No training data was provided! 
Check \'TRAIN_FLIST\' value in the configuration file.') + return + + while(keep_training): + epoch += 1 + print('\n\nTraining epoch: %d' % epoch) + + progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter']) + + for items in train_loader: + self.edge_model.train() + self.inpaint_model.train() + + images, images_filled, images_gray, edges, masks, factor = self.cuda(*items) + + # edge model + if model == 1: + # train + outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks) + + # metrics + precision, recall = self.edgeacc(edges * masks, outputs * masks) + logs.append(('precision', precision.item())) + logs.append(('recall', recall.item())) + + # backward + self.edge_model.backward(gen_loss, dis_loss) + iteration = self.edge_model.iteration + + # inpaint model + elif model == 2: + # train + outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, images_filled, edges, masks) + # import ipdb; ipdb.set_trace() + # cv2.imwrite('/home/gaochen/test.png', self.flow2img(images * factor, 10)[0]) + # cv2.imwrite('/home/gaochen/images_filled.png', self.flow2img(images_filled * factor, 10)[0]) + # cv2.imwrite('/home/gaochen/outputs.png', self.flow2img(outputs.detach() * factor, 10)[0]) + # cv2.imwrite('/home/gaochen/edges.png', edges.cpu().numpy()[0,0,:,:] * 255) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO == 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images, 10)), torch.from_numpy(self.flow2img(outputs_merged.detach(), 10))) + else: + assert(0) + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(torch.abs(images))).float() + logs.append(('psnr', psnr.item())) + logs.append(('mae', mae.item())) + + # backward + self.inpaint_model.backward(gen_loss, dis_loss) + iteration = self.inpaint_model.iteration + + # inpaint with edge model + elif model == 3: + # train + if True or np.random.binomial(1, 0.5) > 0: + outputs = self.edge_model(images_gray, edges, masks) + outputs = outputs * masks + edges * (1 - masks) + else: + outputs = edges + + outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, images_filled, outputs.detach(), masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO == 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images)), torch.from_numpy(self.flow2img(outputs_merged))) + else: + assert(0) + + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(torch.abs(images))).float() + logs.append(('psnr', psnr.item())) + logs.append(('mae', mae.item())) + + # backward + self.inpaint_model.backward(gen_loss, dis_loss) + iteration = self.inpaint_model.iteration + + + # joint model + else: + # train + e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, 
edges, masks) + e_outputs = e_outputs * masks + edges * (1 - masks) + i_outputs, i_gen_loss, i_dis_loss, i_logs = self.inpaint_model.process(images, images_filled, e_outputs, masks) + + if self.config.NORM == 1: + i_outputs = i_outputs * factor.reshape(-1, 1, 1, 1).type(i_outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(i_outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (i_outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO == 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images)), torch.from_numpy(self.flow2img(outputs_merged))) + else: + assert(0) + + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(torch.abs(images))).float() + precision, recall = self.edgeacc(edges * masks, e_outputs * masks) + e_logs.append(('pre', precision.item())) + e_logs.append(('rec', recall.item())) + i_logs.append(('psnr', psnr.item())) + i_logs.append(('mae', mae.item())) + logs = e_logs + i_logs + + # backward + self.inpaint_model.backward(i_gen_loss, i_dis_loss) + self.edge_model.backward(e_gen_loss, e_dis_loss) + iteration = self.inpaint_model.iteration + + + if iteration >= max_iteration: + keep_training = False + break + + for idx in range(len(logs)): + self.writer.add_scalar(logs[idx][0], logs[idx][1], iteration) + + + logs = [ + ("epoch", epoch), + ("iter", iteration), + ] + logs + + progbar.add(len(images), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')]) + + # log model at checkpoints + if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0: + self.log(logs) + + # sample model at checkpoints + if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0: + self.sample() + + # evaluate model at checkpoints + if self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0: + print('\nstart eval...\n') + self.eval() + + # save model at checkpoints + if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0: + self.save() + + print('\nEnd training....') + + def eval(self): + val_loader = DataLoader( + dataset=self.val_dataset, + batch_size=self.config.BATCH_SIZE, + drop_last=True, + shuffle=True + ) + + model = self.config.MODEL + total = len(self.val_dataset) + + self.edge_model.eval() + self.inpaint_model.eval() + + progbar = Progbar(total, width=20, stateful_metrics=['it']) + iteration = 0 + + for items in val_loader: + iteration += 1 + # images, images_gray, edges, masks = self.cuda(*items) + images, images_filled, images_gray, edges, masks, factor = self.cuda(*items) + + # edge model + if model == 1: + # eval + outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks) + + # metrics + precision, recall = self.edgeacc(edges * masks, outputs * masks) + logs.append(('precision', precision.item())) + logs.append(('recall', recall.item())) + + + # inpaint model + elif model == 2: + # eval + outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, images_filled, edges, masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO 
== 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images)), torch.from_numpy(self.flow2img(outputs_merged))) + else: + assert(0) + + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() + logs.append(('psnr', psnr.item())) + logs.append(('mae', mae.item())) + + + # inpaint with edge model + elif model == 3: + # eval + outputs = self.edge_model(images_gray, edges, masks) + outputs = outputs * masks + edges * (1 - masks) + + outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, images_filled, outputs.detach(), masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO == 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images)), torch.from_numpy(self.flow2img(outputs_merged))) + else: + assert(0) + + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() + logs.append(('psnr', psnr.item())) + logs.append(('mae', mae.item())) + + + # joint model + else: + # eval + e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, edges, masks) + e_outputs = e_outputs * masks + edges * (1 - masks) + i_outputs, i_gen_loss, i_dis_loss, i_logs = self.inpaint_model.process(images, images_filled, e_outputs, masks) + + if self.config.NORM == 1: + i_outputs = i_outputs * factor.reshape(-1, 1, 1, 1).type(i_outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(i_outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (i_outputs * masks) + (images * (1 - masks)) + + # metrics + if self.config.FLO == 0: + psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) + elif self.config.FLO == 1: + psnr = self.psnr(torch.from_numpy(self.flow2img(images)), torch.from_numpy(self.flow2img(outputs_merged))) + else: + assert(0) + + mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() + precision, recall = self.edgeacc(edges * masks, e_outputs * masks) + e_logs.append(('pre', precision.item())) + e_logs.append(('rec', recall.item())) + i_logs.append(('psnr', psnr.item())) + i_logs.append(('mae', mae.item())) + logs = e_logs + i_logs + + + logs = [("it", iteration), ] + logs + progbar.add(len(images), values=logs) + + def test(self): + self.edge_model.eval() + self.inpaint_model.eval() + + model = self.config.MODEL + create_dir(self.results_path) + + test_loader = DataLoader( + dataset=self.test_dataset, + batch_size=1, + ) + + index = 0 + for items in test_loader: + name = self.test_dataset.load_name(index) + images, images_filled, images_gray, edges, masks, factor = self.cuda(*items) + # images, images_gray, edges, masks = self.cuda(*items) + index += 1 + + # edge model + if model == 1: + outputs = self.edge_model(images_gray, edges, masks) + outputs_merged = (outputs * masks) + (edges * (1 - masks)) + + # inpaint model + elif model == 2: + outputs = self.inpaint_model(images, images_filled, edges, masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 
1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + # inpaint with edge model / joint model + else: + edges = self.edge_model(images_gray, edges, masks).detach() + outputs = self.inpaint_model(images, images_filled, edges, masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + + if os.path.splitext(name)[1] == '.flo': + name = os.path.splitext(name)[0] + '.png' + + if self.config.FLO == 0: + output = self.postprocess(outputs_merged)[0] + elif self.config.FLO == 1 and model == 1: + output = self.postprocess(outputs_merged)[0] + elif self.config.FLO == 1: + output = torch.from_numpy(self.flow2img(outputs_merged.detach())) + else: + assert(0) + path = os.path.join(self.results_path, name) + print(index, name) + + imsave(output, path) + + if self.debug: + edges = self.postprocess(1 - edges)[0] + masked = self.postprocess(images * (1 - masks) + masks)[0] + fname, fext = name.split('.') + + imsave(edges, os.path.join(self.results_path, fname + '_edge.' + fext)) + imsave(masked, os.path.join(self.results_path, fname + '_masked.' + fext)) + + print('\nEnd test....') + + def sample(self, it=None): + # do not sample when validation set is empty + if len(self.val_dataset) == 0: + return + + self.edge_model.eval() + self.inpaint_model.eval() + + model = self.config.MODEL + items = next(self.sample_iterator) + # images, images_gray, edges, masks = self.cuda(*items) + images, images_filled, images_gray, edges, masks, factor = self.cuda(*items) + + # cv2.imwrite('/home/gaochen/test.png', images_gray.detach().cpu().numpy()[0].transpose(1,2,0)*255) + # cv2.imwrite('/home/gaochen/test.png', edges.detach().cpu().numpy()[0].transpose(1,2,0)*255) + # edge model + + if model == 1: + iteration = self.edge_model.iteration + inputs = (images_gray * (1 - masks)) + masks + outputs = self.edge_model(images_gray, edges, masks) + outputs_merged = (outputs * masks) + (edges * (1 - masks)) + + # inpaint model + elif model == 2: + iteration = self.inpaint_model.iteration + outputs = self.inpaint_model(images, images_filled, edges, masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).float() + images = images * factor.reshape(-1, 1, 1, 1).float() + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + inputs = (images * (1 - masks)) # + masks + # inpaint with edge model / joint model + else: + iteration = self.inpaint_model.iteration + outputs = self.edge_model(images_gray, edges, masks).detach() + edges = (outputs * masks + edges * (1 - masks)).detach() + outputs = self.inpaint_model(images, images_filled, edges, masks) + + if self.config.NORM == 1: + outputs = outputs * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images = images * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + images_filled = images_filled * factor.reshape(-1, 1, 1, 1).type(outputs.dtype) + + outputs_merged = (outputs * masks) + (images * (1 - masks)) + inputs = (images * (1 - masks)) # + masks + if it is not None: + iteration = it + + image_per_row = 2 + if 
self.config.SAMPLE_SIZE <= 6: + image_per_row = 1 + + if self.config.FLO == 0: + images_ = stitch_images( + self.postprocess(images), + self.postprocess(inputs), + self.postprocess(edges), + self.postprocess(outputs), + self.postprocess(outputs_merged), + img_per_row = image_per_row + ) + elif self.config.FLO == 1: + if self.config.FILL == 1: + images_ = stitch_images( + torch.from_numpy(self.flow2img(images, 10)), + torch.from_numpy(self.flow2img(inputs, 10)), + torch.from_numpy(self.flow2img(images_filled, 10)), + self.postprocess(edges), + torch.from_numpy(self.flow2img(outputs.detach(), 10)), + torch.from_numpy(self.flow2img(outputs_merged.detach(), 10)), + img_per_row = image_per_row + ) + else: + images_ = stitch_images( + torch.from_numpy(self.flow2img(images, 10)), + torch.from_numpy(self.flow2img(inputs, 10)), + self.postprocess(edges), + torch.from_numpy(self.flow2img(outputs.detach(), 10)), + torch.from_numpy(self.flow2img(outputs_merged.detach(), 10)), + img_per_row = image_per_row + ) + else: + assert(0) + # + # self.writer.add_image('images', vutils.make_grid(torch.from_numpy(self.flow2img(images, 10)).permute(0, 3, 1, 2), scale_each=True), iteration) + # self.writer.add_image('inputs', vutils.make_grid(torch.from_numpy(self.flow2img(inputs, 10)).permute(0, 3, 1, 2), scale_each=True), iteration) + # self.writer.add_image('images_filled', vutils.make_grid(torch.from_numpy(self.flow2img(images_filled, 10)).permute(0, 3, 1, 2), scale_each=True), iteration) + # self.writer.add_image('outputs', vutils.make_grid(torch.from_numpy(self.flow2img(outputs.detach(), 10)).permute(0, 3, 1, 2), scale_each=True), iteration) + # self.writer.add_image('outputs_merged', vutils.make_grid(torch.from_numpy(self.flow2img(outputs_merged.detach(), 10)).permute(0, 3, 1, 2), scale_each=True), iteration) + + path = os.path.join(self.samples_path, self.model_name) + name = os.path.join(path, str(iteration).zfill(5) + ".png") + create_dir(path) + print('\nsaving sample ' + name) + images_.save(name) + + def log(self, logs): + with open(self.log_file, 'a') as f: + f.write('%s\n' % ' '.join([str(item[1]) for item in logs])) + + def cuda(self, *args): + return (item.to(self.config.DEVICE) for item in args) + + def postprocess(self, img): + # [0, 1] => [0, 255] + img = img * 255.0 + img = img.permute(0, 2, 3, 1) + return img.int() + + + def flow2img(self, flows, global_max=None): + flows = flows.permute(0, 2, 3, 1) + imgs = np.empty((0, flows.shape[1], flows.shape[2], 3), np.uint8) + for idx in range(len(flows)): + imgs = np.concatenate((imgs, np.expand_dims((self.flow_to_image(flows[idx, :, :, :].cpu().numpy(), global_max)), 0)), axis=0) + return imgs + + + def flow_to_image(self, flow, global_max): + + UNKNOWN_FLOW_THRESH = 1e7 + + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + if global_max != None: + maxrad = global_max + # print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. 
%.3f" % (maxrad, minu,maxu, minv, maxv)) + + u = u/(maxrad + np.finfo(float).eps) + v = v/(maxrad + np.finfo(float).eps) + + img = self.compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + + def compute_color(self, u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = self.make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u**2+v**2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a+1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols+1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel,1)): + tmp = colorwheel[:, i] + col0 = tmp[k0-1] / 255 + col1 = tmp[k1-1] / 255 + col = (1-f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1-rad[idx]*(1-col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) + + return img + + + def make_color_wheel(self): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) + colorwheel[col:col+YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) + colorwheel[col:col+CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col+MR, 0] = 255 + + return colorwheel + + def flow_tf(self, flow, size): + flow_shape = flow.shape + flow_resized = cv2.resize(flow, (size[1], size[0])) + flow_resized[:, :, 0] *= (float(size[1]) / float(flow_shape[1])) + flow_resized[:, :, 1] *= (float(size[0]) / float(flow_shape[0])) + + return flow_resized diff --git a/modules/EdgeConnect/loss.py b/modules/EdgeConnect/loss.py new file mode 100644 index 0000000..f99fd60 --- /dev/null +++ b/modules/EdgeConnect/loss.py @@ -0,0 +1,251 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + + +class TotalVariationalLoss(nn.Module): + def __init__(self): + super(TotalVariationalLoss, self).__init__() + + def _tensor_size(self, x): + return x.size()[1] * x.size()[2] * x.size()[3] + + def __call__(self, x): + + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self._tensor_size(x[:, :, 1:, :]) + count_w = self._tensor_size(x[:, :, :, 1:]) + h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() + w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() + return 2 * (h_tv / count_h + w_tv / count_w) / batch_size + + +class AdversarialLoss(nn.Module): + r""" + Adversarial loss + https://arxiv.org/abs/1711.10337 + """ + + def __init__(self, type='nsgan', 
target_real_label=1.0, target_fake_label=0.0): + r""" + type = nsgan | lsgan | hinge + """ + super(AdversarialLoss, self).__init__() + + self.type = type + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + if type == 'nsgan': + self.criterion = nn.BCELoss() + + elif type == 'lsgan': + self.criterion = nn.MSELoss() + + elif type == 'hinge': + self.criterion = nn.ReLU() + + def __call__(self, outputs, is_real, is_disc=None): + if self.type == 'hinge': + if is_disc: + if is_real: + outputs = -outputs + return self.criterion(1 + outputs).mean() + else: + return (-outputs).mean() + + else: + labels = (self.real_label if is_real else self.fake_label).expand_as(outputs) + loss = self.criterion(outputs, labels) + return loss + + +class StyleLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self): + super(StyleLoss, self).__init__() + self.add_module('vgg', VGG19()) + self.criterion = torch.nn.L1Loss() + + def compute_gram(self, x): + b, ch, h, w = x.size() + f = x.view(b, ch, w * h) + f_T = f.transpose(1, 2) + G = f.bmm(f_T) / (h * w * ch) + + return G + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + # Compute loss + style_loss = 0.0 + style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) + style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) + style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) + style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) + + return style_loss + + + +class PerceptualLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): + super(PerceptualLoss, self).__init__() + self.add_module('vgg', VGG19()) + self.criterion = torch.nn.L1Loss() + self.weights = weights + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + content_loss = 0.0 + content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) + content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) + content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) + content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) + content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) + + + return content_loss + + + +class VGG19(torch.nn.Module): + def __init__(self): + super(VGG19, self).__init__() + features = models.vgg19(pretrained=True).features + self.relu1_1 = torch.nn.Sequential() + self.relu1_2 = torch.nn.Sequential() + + self.relu2_1 = torch.nn.Sequential() + self.relu2_2 = torch.nn.Sequential() + + self.relu3_1 = torch.nn.Sequential() + self.relu3_2 = torch.nn.Sequential() + self.relu3_3 = torch.nn.Sequential() + self.relu3_4 = torch.nn.Sequential() + + self.relu4_1 = torch.nn.Sequential() + self.relu4_2 = torch.nn.Sequential() + self.relu4_3 = torch.nn.Sequential() + self.relu4_4 = torch.nn.Sequential() + + self.relu5_1 = torch.nn.Sequential() + self.relu5_2 = torch.nn.Sequential() + self.relu5_3 = torch.nn.Sequential() + self.relu5_4 
= torch.nn.Sequential() + + for x in range(2): + self.relu1_1.add_module(str(x), features[x]) + + for x in range(2, 4): + self.relu1_2.add_module(str(x), features[x]) + + for x in range(4, 7): + self.relu2_1.add_module(str(x), features[x]) + + for x in range(7, 9): + self.relu2_2.add_module(str(x), features[x]) + + for x in range(9, 12): + self.relu3_1.add_module(str(x), features[x]) + + for x in range(12, 14): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(14, 16): + self.relu3_3.add_module(str(x), features[x]) + + for x in range(16, 18): + self.relu3_4.add_module(str(x), features[x]) + + for x in range(18, 21): + self.relu4_1.add_module(str(x), features[x]) + + for x in range(21, 23): + self.relu4_2.add_module(str(x), features[x]) + + for x in range(23, 25): + self.relu4_3.add_module(str(x), features[x]) + + for x in range(25, 27): + self.relu4_4.add_module(str(x), features[x]) + + for x in range(27, 30): + self.relu5_1.add_module(str(x), features[x]) + + for x in range(30, 32): + self.relu5_2.add_module(str(x), features[x]) + + for x in range(32, 34): + self.relu5_3.add_module(str(x), features[x]) + + for x in range(34, 36): + self.relu5_4.add_module(str(x), features[x]) + + # don't need the gradients, just want the features + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + relu1_1 = self.relu1_1(x) + relu1_2 = self.relu1_2(relu1_1) + + relu2_1 = self.relu2_1(relu1_2) + relu2_2 = self.relu2_2(relu2_1) + + relu3_1 = self.relu3_1(relu2_2) + relu3_2 = self.relu3_2(relu3_1) + relu3_3 = self.relu3_3(relu3_2) + relu3_4 = self.relu3_4(relu3_3) + + relu4_1 = self.relu4_1(relu3_4) + relu4_2 = self.relu4_2(relu4_1) + relu4_3 = self.relu4_3(relu4_2) + relu4_4 = self.relu4_4(relu4_3) + + relu5_1 = self.relu5_1(relu4_4) + relu5_2 = self.relu5_2(relu5_1) + relu5_3 = self.relu5_3(relu5_2) + relu5_4 = self.relu5_4(relu5_3) + + out = { + 'relu1_1': relu1_1, + 'relu1_2': relu1_2, + + 'relu2_1': relu2_1, + 'relu2_2': relu2_2, + + 'relu3_1': relu3_1, + 'relu3_2': relu3_2, + 'relu3_3': relu3_3, + 'relu3_4': relu3_4, + + 'relu4_1': relu4_1, + 'relu4_2': relu4_2, + 'relu4_3': relu4_3, + 'relu4_4': relu4_4, + + 'relu5_1': relu5_1, + 'relu5_2': relu5_2, + 'relu5_3': relu5_3, + 'relu5_4': relu5_4, + } + return out diff --git a/modules/EdgeConnect/metrics.py b/modules/EdgeConnect/metrics.py new file mode 100644 index 0000000..94d161c --- /dev/null +++ b/modules/EdgeConnect/metrics.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + + +class EdgeAccuracy(nn.Module): + """ + Measures the accuracy of the edge map + """ + def __init__(self, threshold=0.5): + super(EdgeAccuracy, self).__init__() + self.threshold = threshold + + def __call__(self, inputs, outputs): + labels = (inputs > self.threshold) + outputs = (outputs > self.threshold) + + relevant = torch.sum(labels.float()) + selected = torch.sum(outputs.float()) + + if relevant == 0 and selected == 0: + return torch.tensor(1), torch.tensor(1) + + true_positive = ((outputs == labels) * labels).float() + recall = torch.sum(true_positive) / (relevant + 1e-8) + precision = torch.sum(true_positive) / (selected + 1e-8) + + return precision, recall + + +class PSNR(nn.Module): + def __init__(self, max_val): + super(PSNR, self).__init__() + + base10 = torch.log(torch.tensor(10.0)) + max_val = torch.tensor(max_val).float() + + self.register_buffer('base10', base10) + self.register_buffer('max_val', 20 * torch.log(max_val) / base10) + + def __call__(self, a, b): + mse = torch.mean((a.float() - b.float()) 
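EdgeAccuracy thresholds both edge maps at 0.5 and reports precision and recall of the predicted edges. A toy sketch of the same bookkeeping with invented values:

```python
# Sketch of the precision/recall computed by EdgeAccuracy on toy edge maps.
import torch

threshold = 0.5
inputs  = torch.tensor([[0.9, 0.1], [0.8, 0.0]])   # ground-truth edge map
outputs = torch.tensor([[0.7, 0.6], [0.2, 0.0]])   # predicted edge map

labels  = inputs  > threshold
predict = outputs > threshold
tp = ((predict == labels) & labels).float().sum()   # 1 true positive here
precision = tp / (predict.float().sum() + 1e-8)     # 0.5
recall    = tp / (labels.float().sum()  + 1e-8)     # 0.5
print(precision.item(), recall.item())
```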
** 2) + + if mse == 0: + return torch.tensor(0) + + return self.max_val - 10 * torch.log(mse) / self.base10 diff --git a/modules/EdgeConnect/models.py b/modules/EdgeConnect/models.py new file mode 100644 index 0000000..1c47274 --- /dev/null +++ b/modules/EdgeConnect/models.py @@ -0,0 +1,316 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from .networks import InpaintGenerator, EdgeGenerator, Discriminator +from .loss import AdversarialLoss, PerceptualLoss, StyleLoss, TotalVariationalLoss + + +class BaseModel(nn.Module): + def __init__(self, name, config): + super(BaseModel, self).__init__() + + self.name = name + self.config = config + self.iteration = 0 + + self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth') + self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth') + + def load(self): + if os.path.exists(self.gen_weights_path): + print('Loading %s generator...' % self.name) + + if torch.cuda.is_available(): + data = torch.load(self.gen_weights_path) + else: + data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage) + + self.generator.load_state_dict(data['generator']) + self.iteration = data['iteration'] + + # load discriminator only when training + if self.config.MODE == 1 and os.path.exists(self.dis_weights_path): + print('Loading %s discriminator...' % self.name) + + if torch.cuda.is_available(): + data = torch.load(self.dis_weights_path) + else: + data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage) + + self.discriminator.load_state_dict(data['discriminator']) + + def save(self): + print('\nsaving %s...\n' % self.name) + torch.save({ + 'iteration': self.iteration, + 'generator': self.generator.state_dict() + }, self.gen_weights_path) + + torch.save({ + 'discriminator': self.discriminator.state_dict() + }, self.dis_weights_path) + + +class EdgeModel(BaseModel): + def __init__(self, config): + super(EdgeModel, self).__init__('EdgeModel', config) + + # generator input: [grayscale(1) + edge(1) + mask(1)] + # discriminator input: (grayscale(1) + edge(1)) + generator = EdgeGenerator(use_spectral_norm=True) + discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') + if len(config.GPU) > 1: + generator = nn.DataParallel(generator, config.GPU) + discriminator = nn.DataParallel(discriminator, config.GPU) + l1_loss = nn.L1Loss() + adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) + + self.add_module('generator', generator) + self.add_module('discriminator', discriminator) + + self.add_module('l1_loss', l1_loss) + self.add_module('adversarial_loss', adversarial_loss) + + self.gen_optimizer = optim.Adam( + params=generator.parameters(), + lr=float(config.LR), + betas=(config.BETA1, config.BETA2) + ) + + self.dis_optimizer = optim.Adam( + params=discriminator.parameters(), + lr=float(config.LR) * float(config.D2G_LR), + betas=(config.BETA1, config.BETA2) + ) + + def process(self, images, edges, masks): + self.iteration += 1 + + + # zero optimizers + self.gen_optimizer.zero_grad() + self.dis_optimizer.zero_grad() + + + # process outputs + outputs = self(images, edges, masks) + gen_loss = 0 + dis_loss = 0 + + + # discriminator loss + dis_input_real = torch.cat((images, edges), dim=1) + dis_input_fake = torch.cat((images, outputs.detach()), dim=1) + dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1)) + dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1)) + dis_real_loss 
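The PSNR module above evaluates to 20*log10(max_val) - 10*log10(MSE). A quick numeric sketch of the same formula (values are illustrative only):

```python
# Sketch: the PSNR module above computes 20*log10(max_val) - 10*log10(mse).
import torch

a = torch.randint(0, 256, (1, 3, 64, 64)).float()        # reference image
b = (a + torch.randn_like(a) * 5).clamp(0, 255)           # noisy copy

mse  = torch.mean((a - b) ** 2)
psnr = 20 * torch.log10(torch.tensor(255.0)) - 10 * torch.log10(mse)
print(psnr.item())   # roughly 34 dB for noise with std 5
```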
= self.adversarial_loss(dis_real, True, True) + dis_fake_loss = self.adversarial_loss(dis_fake, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + + + # generator adversarial loss + gen_input_fake = torch.cat((images, outputs), dim=1) + gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1)) + gen_gan_loss = self.adversarial_loss(gen_fake, True, False) + gen_loss += gen_gan_loss + + + # generator feature matching loss + gen_fm_loss = 0 + for i in range(len(dis_real_feat)): + gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach()) + gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT + gen_loss += gen_fm_loss + + + # create logs + logs = [ + ("l_d1", dis_loss.item()), + ("l_g1", gen_gan_loss.item()), + ("l_fm", gen_fm_loss.item()), + ] + + return outputs, gen_loss, dis_loss, logs + + def forward(self, images, edges, masks): + edges_masked = (edges * (1 - masks)) + images_masked = (images * (1 - masks)) + masks + inputs = torch.cat((images_masked, edges_masked, masks), dim=1) + outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)] + return outputs + + def backward(self, gen_loss=None, dis_loss=None): + if dis_loss is not None: + dis_loss.backward() + self.dis_optimizer.step() + + if gen_loss is not None: + gen_loss.backward() + self.gen_optimizer.step() + + +class InpaintingModel(BaseModel): + def __init__(self, config): + super(InpaintingModel, self).__init__('InpaintingModel', config) + + # generator input: [rgb(3) + edge(1)] + # discriminator input: [rgb(3)] + generator = InpaintGenerator(config) + self.config = config + if config.FLO == 1: + in_channels = 2 + elif config.FLO == 0: + in_channels = 3 + else: + assert(0) + discriminator = Discriminator(in_channels=in_channels, use_sigmoid=config.GAN_LOSS != 'hinge') + if len(config.GPU) > 1: + generator = nn.DataParallel(generator, config.GPU) + discriminator = nn.DataParallel(discriminator , config.GPU) + + l1_loss = nn.L1Loss() + tv_loss = TotalVariationalLoss() + perceptual_loss = PerceptualLoss() + style_loss = StyleLoss() + adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) + + self.add_module('generator', generator) + self.add_module('discriminator', discriminator) + + self.add_module('l1_loss', l1_loss) + self.add_module('tv_loss', tv_loss) + self.add_module('perceptual_loss', perceptual_loss) + self.add_module('style_loss', style_loss) + self.add_module('adversarial_loss', adversarial_loss) + + self.gen_optimizer = optim.Adam( + params=generator.parameters(), + lr=float(config.LR), + betas=(config.BETA1, config.BETA2) + ) + + self.dis_optimizer = optim.Adam( + params=discriminator.parameters(), + lr=float(config.LR) * float(config.D2G_LR), + betas=(config.BETA1, config.BETA2) + ) + + def process(self, images, images_filled, edges, masks): + self.iteration += 1 + + # zero optimizers + self.gen_optimizer.zero_grad() + self.dis_optimizer.zero_grad() + + # process outputs + outputs = self(images, images_filled, edges, masks) + + gen_loss = 0 + dis_loss = 0 + gen_gan_loss = 0 + + if self.config.GAN == 1: + # discriminator loss + dis_input_real = images + dis_input_fake = outputs.detach() + dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)] + dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)] + dis_real_loss = self.adversarial_loss(dis_real, True, True) + dis_fake_loss = self.adversarial_loss(dis_fake, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + + + # generator adversarial loss + gen_input_fake = 
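EdgeModel.process adds a feature-matching term to the adversarial one: an L1 distance between the discriminator features of the generated edge map and the detached features of the real one, summed over layers. A standalone sketch of that summation on random stand-in features:

```python
# Sketch of the feature-matching term: L1 between discriminator features of the
# generated sample and (detached) features of the real sample, summed over layers.
import torch

dis_real_feat = [torch.rand(1, c, 16, 16) for c in (64, 128, 256)]   # stand-in feature maps
gen_fake_feat = [torch.rand(1, c, 16, 16) for c in (64, 128, 256)]

l1 = torch.nn.L1Loss()
fm_loss = sum(l1(g, r.detach()) for g, r in zip(gen_fake_feat, dis_real_feat))
print(fm_loss.item())   # scaled by FM_LOSS_WEIGHT in EdgeModel.process
```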
outputs + gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)] + gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT + gen_loss += gen_gan_loss + + + # generator l1 loss + gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks) + gen_loss += gen_l1_loss + + if self.config.ENFORCE == 1: + gen_l1_masked_loss = self.l1_loss(outputs * masks, images * masks) * 10 * self.config.L1_LOSS_WEIGHT + gen_loss += gen_l1_masked_loss + elif self.config.ENFORCE != 0: + assert(0) + + if self.config.TV == 1: + # generator tv loss + gen_tv_loss = self.tv_loss(outputs) * self.config.TV_LOSS_WEIGHT + gen_loss += gen_tv_loss + + if self.config.FLO != 1: + # generator perceptual loss + gen_content_loss = self.perceptual_loss(outputs, images) + gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT + gen_loss += gen_content_loss + + # generator style loss + gen_style_loss = self.style_loss(outputs * masks, images * masks) + gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT + gen_loss += gen_style_loss + + # create logs + logs = [ + ("l_d2", dis_loss.item()), + ("l_g2", gen_gan_loss.item()), + ("l_l1", gen_l1_loss.item()), + ("l_per", gen_content_loss.item()), + ("l_sty", gen_style_loss.item()), + ] + else: + logs = [] + logs.append(("l_l1", gen_l1_loss.item())) + logs.append(("l_gen", gen_loss.item())) + + if self.config.GAN == 1: + logs.append(("l_d2", dis_loss.item())) + logs.append(("l_g2", gen_gan_loss.item())) + + if self.config.TV == 1: + logs.append(("l_tv", gen_tv_loss.item())) + + if self.config.ENFORCE == 1: + logs.append(("l_masked_l1", gen_l1_masked_loss.item())) + + return outputs, gen_loss, dis_loss, logs + + def forward(self, images, images_filled, edges, masks): + + if self.config.FILL == 1: + images_masked = images_filled + elif self.config.FILL == 0: + images_masked = (images * (1 - masks).float()) # + masks + else: + assert(0) + + if self.config.PASSMASK == 1: + inputs = torch.cat((images_masked, edges, masks), dim=1) + elif self.config.PASSMASK == 0: + inputs = torch.cat((images_masked, edges), dim=1) + else: + assert(0) + + outputs = self.generator(inputs) + # if self.config.RESIDUAL == 1: + # assert(self.config.PASSMASK == 1) + # outputs = self.generator(inputs) + images_filled + # elif self.config.RESIDUAL == 0: + # outputs = self.generator(inputs) + # else: + # assert(0) + + return outputs + + def backward(self, gen_loss=None, dis_loss=None): + + if self.config.GAN == 1: + dis_loss.backward() + self.dis_optimizer.step() + + gen_loss.backward() + self.gen_optimizer.step() diff --git a/modules/EdgeConnect/networks.py b/modules/EdgeConnect/networks.py new file mode 100644 index 0000000..80b9831 --- /dev/null +++ b/modules/EdgeConnect/networks.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + 
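In InpaintingModel.process the image-wide L1 term is divided by torch.mean(masks), which rescales it by the fraction of masked pixels so that small holes are not under-weighted. A toy sketch of that scaling (numbers invented for illustration):

```python
# Sketch: dividing the image-wide L1 by torch.mean(masks) rescales the term by the
# hole area. Toy tensors; a 2x2 hole covers 4 of 64 pixels.
import torch

outputs = torch.zeros(1, 3, 8, 8)
images  = torch.ones(1, 3, 8, 8)
masks   = torch.zeros(1, 1, 8, 8)
masks[..., :2, :2] = 1.0

plain  = torch.nn.L1Loss()(outputs, images)   # 1.0
scaled = plain / torch.mean(masks)            # 1.0 / 0.0625 = 16.0
print(plain.item(), scaled.item())
```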
elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, config, residual_blocks=8, init_weights=True): + super(InpaintGenerator, self).__init__() + self.config = config + if config.FLO == 1: + if config.PASSMASK == 0: + in_channels = 3 + elif config.PASSMASK == 1: + in_channels = 4 + else: + assert(0) + out_channels = 2 + elif config.FLO == 0: + in_channels = 4 + out_channels = 3 + else: + assert(0) + self.encoder = nn.Sequential( + nn.ReflectionPad2d(3), + nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), + nn.InstanceNorm2d(64, track_running_stats=False), + nn.ReLU(True), + + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm2d(128, track_running_stats=False), + nn.ReLU(True), + + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm2d(256, track_running_stats=False), + nn.ReLU(True) + ) + + blocks = [] + for _ in range(residual_blocks): + block = ResnetBlock(256, 2) + blocks.append(block) + + self.middle = nn.Sequential(*blocks) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm2d(128, track_running_stats=False), + nn.ReLU(True), + + nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm2d(64, track_running_stats=False), + nn.ReLU(True), + + nn.ReflectionPad2d(3), + nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0), + ) + + if init_weights: + self.init_weights() + + def forward(self, input): + x = self.encoder(input) + x = self.middle(x) + x = self.decoder(x) + + if self.config.FLO == 0: + x = (torch.tanh(x) + 1) / 2 + elif self.config.FLO == 1 and self.config.NORM == 1: + if self.config.RESIDUAL == 1: + assert(self.config.FILL == 1) + x = torch.tanh(x + input[:, :2, :, :]) + elif self.config.RESIDUAL == 0: + x = torch.tanh(x) + else: + assert(0) + return x + + +class EdgeGenerator_(BaseNetwork): + def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True): + super(EdgeGenerator_, self).__init__() + + self.encoder = nn.Sequential( + nn.ReflectionPad2d(3), + spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm), + nn.InstanceNorm2d(64, track_running_stats=False), + nn.ReLU(True), + + spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), + nn.InstanceNorm2d(128, track_running_stats=False), + nn.ReLU(True), + + spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm), + nn.InstanceNorm2d(256, track_running_stats=False), + nn.ReLU(True) + ) + + blocks = [] + for _ in range(residual_blocks): + block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm) + blocks.append(block) + + self.middle = nn.Sequential(*blocks) + + self.decoder = nn.Sequential( + spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), + 
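InpaintGenerator keeps the input resolution through the 7x7 reflection-padded convolution, halves it twice in the encoder, runs the residual blocks at 1/4 resolution and mirrors the process in the decoder. A standalone shape check that copies only the kernel/stride/padding choices above (normalisation and activations omitted, channel counts as in the FLO == 0 branch):

```python
# Shape sketch of the encoder / residual / decoder layout used above.
import torch
import torch.nn as nn

enc = nn.Sequential(
    nn.ReflectionPad2d(3), nn.Conv2d(4, 64, 7),            # 256 -> 256
    nn.Conv2d(64, 128, 4, stride=2, padding=1),            # 256 -> 128
    nn.Conv2d(128, 256, 4, stride=2, padding=1),           # 128 -> 64
)
dec = nn.Sequential(
    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 64 -> 128
    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 128 -> 256
    nn.ReflectionPad2d(3), nn.Conv2d(64, 3, 7),            # 256 -> 256
)
x = torch.rand(1, 4, 256, 256)          # rgb(3) + mask(1), as in the FLO == 0 branch
print(enc(x).shape, dec(enc(x)).shape)  # (1, 256, 64, 64), (1, 3, 256, 256)
```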
nn.InstanceNorm2d(128, track_running_stats=False), + nn.ReLU(True), + + spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm), + nn.InstanceNorm2d(64, track_running_stats=False), + nn.ReLU(True), + + nn.ReflectionPad2d(3), + nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0), + ) + + if init_weights: + self.init_weights() + + def forward(self, x): + x = self.encoder(x) + x = self.middle(x) + x = self.decoder(x) + x = torch.sigmoid(x) + return x + + +class Discriminator(BaseNetwork): + def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + + self.conv1 = self.features = nn.Sequential( + spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.conv2 = nn.Sequential( + spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.conv3 = nn.Sequential( + spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.conv4 = nn.Sequential( + spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.conv5 = nn.Sequential( + spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), + ) + + if init_weights: + self.init_weights() + + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + conv4 = self.conv4(conv3) + conv5 = self.conv5(conv4) + + outputs = conv5 + if self.use_sigmoid: + outputs = torch.sigmoid(conv5) + + return outputs, [conv1, conv2, conv3, conv4, conv5] + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dilation=1, use_spectral_norm=False): + super(ResnetBlock, self).__init__() + self.conv_block = nn.Sequential( + nn.ReflectionPad2d(dilation), + spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), + nn.InstanceNorm2d(dim, track_running_stats=False), + nn.ReLU(True), + + nn.ReflectionPad2d(1), + spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), + nn.InstanceNorm2d(dim, track_running_stats=False), + ) + + def forward(self, x): + out = x + self.conv_block(x) + + # Remove ReLU at the end of the residual block + # http://torch.ch/blog/2016/02/04/resnets.html + + return out + + +def spectral_norm(module, mode=True): + if mode: + return nn.utils.spectral_norm(module) + + return module diff --git a/modules/EdgeConnect/region_fill.py b/modules/EdgeConnect/region_fill.py new file mode 100644 index 0000000..511e066 --- /dev/null +++ b/modules/EdgeConnect/region_fill.py @@ -0,0 +1,141 @@ +import numpy as np +import cv2 +from scipy import sparse +from scipy.sparse.linalg import spsolve + + +def regionfill(I, mask, factor=1.0): + if np.count_nonzero(mask) == 0: + return I.copy() + resize_mask = cv2.resize( + mask.astype(float), (0, 0), 
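The Discriminator above acts as a patch-style critic: three stride-2 and two stride-1 4x4 convolutions map the input to a small grid of per-patch scores rather than a single scalar. A stripped-down stand-in (spectral norm and LeakyReLU omitted) just to confirm the output size:

```python
# Sketch: output of the five-conv discriminator is a 30x30 patch-score map for a
# 256x256, 3-channel input.
import torch
import torch.nn as nn

d = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1),     # 256 -> 128
    nn.Conv2d(64, 128, 4, stride=2, padding=1),   # 128 -> 64
    nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 64 -> 32
    nn.Conv2d(256, 512, 4, stride=1, padding=1),  # 32 -> 31
    nn.Conv2d(512, 1, 4, stride=1, padding=1),    # 31 -> 30
)
print(d(torch.rand(1, 3, 256, 256)).shape)        # torch.Size([1, 1, 30, 30])
```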
fx=factor, fy=factor) > 0 + resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor) + maskPerimeter = findBoundaryPixels(resize_mask) + regionfillLaplace(resize_I, resize_mask, maskPerimeter) + resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0])) + resize_I[mask == 0] = I[mask == 0] + return resize_I + + +def findBoundaryPixels(mask): + kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) + maskDilated = cv2.dilate(mask.astype(float), kernel) + return (maskDilated > 0) & (mask == 0) + + +def regionfillLaplace(I, mask, maskPerimeter): + height, width = I.shape + rightSide = formRightSide(I, maskPerimeter) + + # Location of mask pixels + maskIdx = np.where(mask) + + # Only keep values for pixels that are in the mask + rightSide = rightSide[maskIdx] + + # Number the mask pixels in a grid matrix + grid = -np.ones((height, width)) + grid[maskIdx] = range(0, maskIdx[0].size) + # Pad with zeros to avoid "index out of bounds" errors in the for loop + grid = padMatrix(grid) + gridIdx = np.where(grid >= 0) + + # Form the connectivity matrix D=sparse(i,j,s) + # Connect each mask pixel to itself + i = np.arange(0, maskIdx[0].size) + j = np.arange(0, maskIdx[0].size) + # The coefficient is the number of neighbors over which we average + numNeighbors = computeNumberOfNeighbors(height, width) + s = numNeighbors[maskIdx] + # Now connect the N,E,S,W neighbors if they exist + for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)): + # Possible neighbors in the current direction + neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]] + # ConDnect mask points to neighbors with -1's + index = (neighbors >= 0) + i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]])) + j = np.concatenate((j, neighbors[index])) + s = np.concatenate((s, -np.ones(np.count_nonzero(index)))) + + D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr() + sol = spsolve(D, rightSide) + I[maskIdx] = sol + return I + + +def formRightSide(I, maskPerimeter): + height, width = I.shape + perimeterValues = np.zeros((height, width)) + perimeterValues[maskPerimeter] = I[maskPerimeter] + rightSide = np.zeros((height, width)) + + rightSide[1:height - 1, 1:width - 1] = ( + perimeterValues[0:height - 2, 1:width - 1] + + perimeterValues[2:height, 1:width - 1] + + perimeterValues[1:height - 1, 0:width - 2] + + perimeterValues[1:height - 1, 2:width]) + + rightSide[1:height - 1, 0] = ( + perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] + + perimeterValues[1:height - 1, 1]) + + rightSide[1:height - 1, width - 1] = ( + perimeterValues[0:height - 2, width - 1] + + perimeterValues[2:height, width - 1] + + perimeterValues[1:height - 1, width - 2]) + + rightSide[0, 1:width - 1] = ( + perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] + + perimeterValues[0, 2:width]) + + rightSide[height - 1, 1:width - 1] = ( + perimeterValues[height - 2, 1:width - 1] + + perimeterValues[height - 1, 0:width - 2] + + perimeterValues[height - 1, 2:width]) + + rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0] + rightSide[0, width - 1] = ( + perimeterValues[0, width - 2] + perimeterValues[1, width - 1]) + rightSide[height - 1, 0] = ( + perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1]) + rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] + + perimeterValues[height - 1, width - 2]) + return rightSide + + +def computeNumberOfNeighbors(height, width): + # Initialize + numNeighbors = np.zeros((height, width)) + # Interior pixels have 4 
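regionfillLaplace asks every masked pixel to equal the mean of its neighbours and solves the resulting sparse system with spsolve. A one-dimensional analogue of that system, small enough to solve densely (values invented; the real code builds the 2-D version with scipy.sparse):

```python
# 1-D toy of the Laplace fill: each unknown pixel equals the mean of its neighbours,
# which in one dimension reduces to linear interpolation between boundary values.
import numpy as np

# known boundary values 10 and 50, three unknowns x1..x3 in between:
#   2*x1 - x2        = 10
#  -x1 + 2*x2 - x3   = 0
#        -x2 + 2*x3  = 50
D = np.array([[ 2, -1,  0],
              [-1,  2, -1],
              [ 0, -1,  2]], dtype=float)
b = np.array([10.0, 0.0, 50.0])
print(np.linalg.solve(D, b))   # [20. 30. 40.]
```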
neighbors + numNeighbors[1:height - 1, 1:width - 1] = 4 + # Border pixels have 3 neighbors + numNeighbors[1:height - 1, (0, width - 1)] = 3 + numNeighbors[(0, height - 1), 1:width - 1] = 3 + # Corner pixels have 2 neighbors + numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0, + width - 1)] = 2 + return numNeighbors + + +def padMatrix(grid): + height, width = grid.shape + gridPadded = -np.ones((height + 2, width + 2)) + gridPadded[1:height + 1, 1:width + 1] = grid + gridPadded = gridPadded.astype(grid.dtype) + return gridPadded + + +if __name__ == '__main__': + import time + x = np.linspace(0, 255, 500) + xv, _ = np.meshgrid(x, x) + image = ((xv + np.transpose(xv)) / 2.0).astype(int) + mask = np.zeros((500, 500)) + mask[100:259, 100:259] = 1 + mask = (mask > 0) + image[mask] = 0 + st = time.time() + inpaint = regionfill(image, mask, 0.5).astype(np.uint8) + print(time.time() - st) + cv2.imshow('img', np.concatenate((image.astype(np.uint8), inpaint))) + cv2.waitKey() diff --git a/modules/EdgeConnect/utils.py b/modules/EdgeConnect/utils.py new file mode 100644 index 0000000..8adf774 --- /dev/null +++ b/modules/EdgeConnect/utils.py @@ -0,0 +1,216 @@ +import os +import sys +import time +import random +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image + + +def create_dir(dir): + if not os.path.exists(dir): + os.makedirs(dir) + + +def create_mask(width, height, mask_width, mask_height, x=None, y=None): + mask = np.zeros((height, width)) + mask_x = x if x is not None else random.randint(0, width - mask_width) + mask_y = y if y is not None else random.randint(0, height - mask_height) + mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1 + return mask + + +def stitch_images(inputs, *outputs, img_per_row=2): + gap = 5 + columns = len(outputs) + 1 + + height, width = inputs[0][:, :, 0].shape + img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row))) + images = [inputs, *outputs] + + for ix in range(len(inputs)): + xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap + yoffset = int(ix / img_per_row) * height + + for cat in range(len(images)): + im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze() + im = Image.fromarray(im) + img.paste(im, (xoffset + cat * width, yoffset)) + + return img + + +def imshow(img, title=''): + fig = plt.gcf() + fig.canvas.set_window_title(title) + plt.axis('off') + plt.imshow(img, interpolation='none') + plt.show() + + +def imsave(img, path): + im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze()) + im.save(path) + + +class Progbar(object): + """Displays a progress bar. + + Arguments: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. + interval: Minimum visual progress update interval (in seconds). 
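Usage sketch for the mask helper defined above; the import path assumes the repository root is on sys.path, and x/y may be omitted to place the hole at a random position:

```python
# Sketch: a 64x64 rectangular hole at a fixed position inside a 256x256 frame.
from modules.EdgeConnect.utils import create_mask

mask = create_mask(width=256, height=256, mask_width=64, mask_height=64, x=96, y=96)
print(mask.shape, mask.sum())   # (256, 256), 4096.0
```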
+ """ + + def __init__(self, target, width=25, verbose=1, interval=0.05, + stateful_metrics=None): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules or + 'posix' in sys.modules) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): + """Updates the progress bar. + + Arguments: + current: Index of current step. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. + """ + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [v * (current - self._seen_so_far), + current - self._seen_so_far] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += (current - self._seen_so_far) + else: + self._values[k] = v + self._seen_so_far = current + + now = time.time() + info = ' - %.0fs' % (now - self._start) + if self.verbose == 1: + if (now - self._last_update < self.interval and + self.target is not None and current < self.target): + return + + prev_total_width = self._total_width + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + + if self.target is not None: + numdigits = int(np.floor(np.log10(self.target))) + 1 + barstr = '%%%dd/%d [' % (numdigits, self.target) + bar = barstr % current + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += ('=' * (prog_width - 1)) + if current < self.target: + bar += '>' + else: + bar += '=' + bar += ('.' 
* (self.width - prog_width)) + bar += ']' + else: + bar = '%7d/Unknown' % current + + self._total_width = len(bar) + sys.stdout.write(bar) + + if current: + time_per_unit = (now - self._start) / current + else: + time_per_unit = 0 + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, + eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format + else: + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) + else: + info += ' %.0fus/step' % (time_per_unit * 1e6) + + for k in self._values_order: + info += ' - %s:' % k + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if abs(avg) > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + else: + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) + + if self.target is not None and current >= self.target: + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + elif self.verbose == 2: + if self.target is None or current >= self.target: + for k in self._values_order: + info += ' - %s:' % k + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if avg > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) diff --git a/modules/RAFT/__init__.py b/modules/RAFT/__init__.py new file mode 100644 index 0000000..e7179ea --- /dev/null +++ b/modules/RAFT/__init__.py @@ -0,0 +1,2 @@ +# from .demo import RAFT_infer +from .raft import RAFT diff --git a/modules/RAFT/corr.py b/modules/RAFT/corr.py new file mode 100644 index 0000000..449dbd9 --- /dev/null +++ b/modules/RAFT/corr.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, 
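The Keras-style Progbar above can be driven directly from a training loop. A usage sketch (import path assumed as below; `epoch` is listed as stateful so it is displayed as-is rather than averaged):

```python
# Usage sketch for the Progbar defined above.
import time
from modules.EdgeConnect.utils import Progbar

progbar = Progbar(100, width=25, stateful_metrics=['epoch'])
for step in range(1, 101):
    time.sleep(0.01)                                        # stand-in for a training step
    progbar.add(1, values=[('epoch', 1), ('l_l1', 0.5 / step)])
```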
dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrLayer(torch.autograd.Function): + @staticmethod + def forward(ctx, fmap1, fmap2, coords, r): + fmap1 = fmap1.contiguous() + fmap2 = fmap2.contiguous() + coords = coords.contiguous() + ctx.save_for_backward(fmap1, fmap2, coords) + ctx.r = r + corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) + return corr + + @staticmethod + def backward(ctx, grad_corr): + fmap1, fmap2, coords = ctx.saved_tensors + grad_corr = grad_corr.contiguous() + fmap1_grad, fmap2_grad, coords_grad = \ + correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) + return fmap1_grad, fmap2_grad, coords_grad, None + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / 16.0 diff --git a/modules/RAFT/datasets.py b/modules/RAFT/datasets.py new file mode 100644 index 0000000..3411fda --- /dev/null +++ b/modules/RAFT/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = 
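CorrBlock.corr builds the all-pairs correlation volume: every feature vector of fmap1 is dotted with every feature vector of fmap2 and the result is scaled by sqrt(dim). A standalone sketch of the same reshaping on random tensors:

```python
# Sketch of the all-pairs correlation in CorrBlock.corr.
import torch

batch, dim, ht, wd = 1, 8, 6, 6
fmap1 = torch.rand(batch, dim, ht, wd)
fmap2 = torch.rand(batch, dim, ht, wd)

corr = torch.matmul(fmap1.view(batch, dim, ht * wd).transpose(1, 2),
                    fmap2.view(batch, dim, ht * wd))
corr = corr.view(batch, ht, wd, 1, ht, wd) / torch.sqrt(torch.tensor(float(dim)))
print(corr.shape)   # torch.Size([1, 6, 6, 1, 6, 6])
```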
frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + 
root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/modules/RAFT/demo.py b/modules/RAFT/demo.py new file mode 100644 index 0000000..096963b --- /dev/null +++ b/modules/RAFT/demo.py @@ -0,0 +1,79 @@ +import sys +import argparse +import os +import cv2 +import glob +import numpy as np +import torch +from PIL import Image + +from .raft import RAFT +from .utils import flow_viz +from .utils.utils import InputPadder + + + +DEVICE = 'cuda' + +def load_image(imfile): + img = np.array(Image.open(imfile)).astype(np.uint8) + img = torch.from_numpy(img).permute(2, 0, 1).float() + return img + + +def load_image_list(image_files): + images = [] + for imfile in sorted(image_files): + images.append(load_image(imfile)) + + images = torch.stack(images, dim=0) + images = images.to(DEVICE) + + padder = InputPadder(images.shape) + return padder.pad(images)[0] + + +def 
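fetch_dataloader mixes datasets with expressions such as 100*sintel_clean + 200*kitti: FlowDataset.__rmul__ repeats the sample lists, and `+` falls back to torch's Dataset.__add__, which concatenates datasets. A toy stand-in showing the effect on the dataset length:

```python
# Sketch: integer multiplication oversamples a dataset, '+' concatenates datasets.
import torch.utils.data as data

class ToyFlowDataset(data.Dataset):
    def __init__(self, n):
        self.image_list = list(range(n))
    def __rmul__(self, v):                     # same trick as FlowDataset.__rmul__
        self.image_list = v * self.image_list
        return self
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, i):
        return self.image_list[i]

mixed = 3 * ToyFlowDataset(2) + ToyFlowDataset(5)
print(len(mixed))   # 3*2 + 5 = 11
```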
viz(img, flo): + img = img[0].permute(1,2,0).cpu().numpy() + flo = flo[0].permute(1,2,0).cpu().numpy() + + # map flow to rgb image + flo = flow_viz.flow_to_image(flo) + # img_flo = np.concatenate([img, flo], axis=0) + img_flo = flo + + cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) + # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) + # cv2.waitKey() + + +def demo(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + with torch.no_grad(): + images = glob.glob(os.path.join(args.path, '*.png')) + \ + glob.glob(os.path.join(args.path, '*.jpg')) + + images = load_image_list(images) + for i in range(images.shape[0]-1): + image1 = images[i,None] + image2 = images[i+1,None] + + flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) + viz(image1, flow_up) + + +def RAFT_infer(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + return model diff --git a/modules/RAFT/extractor.py b/modules/RAFT/extractor.py new file mode 100644 index 0000000..9a9c759 --- /dev/null +++ b/modules/RAFT/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + 
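Usage sketch for RAFT_infer above. The checkpoint path is a placeholder, a CUDA device is required because the module pins DEVICE = 'cuda', and args carries only the attributes RAFT actually reads here (small, mixed_precision, model):

```python
# Sketch: building args for RAFT_infer and estimating flow between two frames.
import argparse
import torch
from modules.RAFT.demo import RAFT_infer

args = argparse.Namespace(model='weight/raft-things.pth',   # placeholder path
                          small=False, mixed_precision=False)
raft = RAFT_infer(args)

frame1 = torch.rand(1, 3, 256, 448).cuda() * 255   # H and W must be divisible by 8
frame2 = torch.rand(1, 3, 256, 448).cuda() * 255
with torch.no_grad():
    _, flow = raft(frame1, frame2, iters=20, test_mode=True)
print(flow.shape)   # (1, 2, 256, 448)
```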
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + 
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modules/RAFT/raft.py b/modules/RAFT/raft.py new file mode 100644 index 0000000..43a59c3 --- /dev/null +++ b/modules/RAFT/raft.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in args._get_kwargs(): + args.dropout = 0 + + if 'alternate_corr' not in args._get_kwargs(): + args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) 
+ coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/modules/RAFT/update.py b/modules/RAFT/update.py new file mode 100644 index 0000000..f940497 --- /dev/null +++ b/modules/RAFT/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = 
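upsample_flow turns the 1/8-resolution flow into a full-resolution field as a convex combination: the predicted mask is softmaxed over each coarse pixel's 3x3 neighbourhood and used to weight the unfolded flow for every position of the corresponding 8x8 block. A standalone copy of the same arithmetic on random tensors, kept only as a shape check:

```python
# Shape sketch of the convex-combination upsampling in upsample_flow.
import torch
import torch.nn.functional as F

N, H, W = 1, 4, 5
flow = torch.rand(N, 2, H, W)            # flow at 1/8 resolution
mask = torch.rand(N, 64 * 9, H, W)       # predicted by the update block

mask = torch.softmax(mask.view(N, 1, 9, 8, 8, H, W), dim=2)
up_flow = F.unfold(8 * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)                  # N x 2 x 8 x 8 x H x W
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, 8 * H, 8 * W)
print(up_flow.shape)   # torch.Size([1, 2, 32, 40])
```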
nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, 
delta_flow + + + diff --git a/modules/RAFT/utils/__init__.py b/modules/RAFT/utils/__init__.py new file mode 100644 index 0000000..0437149 --- /dev/null +++ b/modules/RAFT/utils/__init__.py @@ -0,0 +1,2 @@ +from .flow_viz import flow_to_image +from .frame_utils import writeFlow diff --git a/modules/RAFT/utils/augmentor.py b/modules/RAFT/utils/augmentor.py new file mode 100644 index 0000000..e81c4f2 --- /dev/null +++ b/modules/RAFT/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + 
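+ # a vertical flip reverses the image rows, so the v (vertical) flow component must be negated as well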
flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = 
self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/modules/RAFT/utils/flow_viz.py b/modules/RAFT/utils/flow_viz.py new file mode 100644 index 0000000..dcee65e --- /dev/null +++ b/modules/RAFT/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. 
+ + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/modules/RAFT/utils/frame_utils.py b/modules/RAFT/utils/frame_utils.py new file mode 100644 index 0000000..6c49113 --- /dev/null +++ b/modules/RAFT/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. 
Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/modules/RAFT/utils/utils.py b/modules/RAFT/utils/utils.py new file mode 100644 index 0000000..5f32d28 --- /dev/null +++ b/modules/RAFT/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 
- self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/modules/frame_inpaint.py b/modules/frame_inpaint.py new file mode 100644 index 0000000..bc52f3d --- /dev/null +++ b/modules/frame_inpaint.py @@ -0,0 +1,123 @@ +import sys, os, argparse +#sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) +import torch +import numpy as np +import cv2 + +from modules.DeepFill import DeepFill + + +class DeepFillv1(object): + def __init__(self, + pretrained_model=None, + image_shape=[512, 960], + res_shape=None, + device=torch.device('cuda:0')): + self.image_shape = image_shape + self.res_shape = res_shape + self.device = device + + self.deepfill = DeepFill.Generator().to(device) + model_weight = torch.load(pretrained_model) + self.deepfill.load_state_dict(model_weight, strict=True) + self.deepfill.eval() + print('Load Deepfill Model from', pretrained_model) + + def forward(self, img, mask): + + img, mask, small_mask = self.data_preprocess(img, mask, size=self.image_shape) + + image = torch.stack([img]) + mask = torch.stack([mask]) + small_mask = torch.stack([small_mask]) + + with torch.no_grad(): + _, inpaint_res, _ = self.deepfill(image.to(self.device), mask.to(self.device), small_mask.to(self.device)) + + res_complete = self.data_proprocess(image, mask, inpaint_res) + + return res_complete + + def data_preprocess(self, img, mask, enlarge_kernel=0, size=[512, 960]): + img = img / 127.5 - 1 + mask = (mask > 0).astype(np.int) + img = cv2.resize(img, (size[1], size[0])) + if enlarge_kernel > 0: + kernel = np.ones((enlarge_kernel, enlarge_kernel), np.uint8) + 
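+ # dilating with this square enlarge_kernel x enlarge_kernel structuring element slightly grows the hole mask before it is resized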
mask = cv2.dilate(mask, kernel, iterations=1) + mask = (mask > 0).astype(np.uint8) + + small_mask = cv2.resize(mask, (size[1] // 8, size[0] // 8), interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) + + if len(mask.shape) == 3: + mask = mask[:, :, 0:1] + else: + mask = np.expand_dims(mask, axis=2) + + if len(small_mask.shape) == 3: + small_mask = small_mask[:, :, 0:1] + else: + small_mask = np.expand_dims(small_mask, axis=2) + + img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float() + mask = torch.from_numpy(mask).permute(2, 0, 1).contiguous().float() + small_mask = torch.from_numpy(small_mask).permute(2, 0, 1).contiguous().float() + + return img*(1-mask), mask, small_mask + + def data_proprocess(self, img, mask, res): + img = img.cpu().data.numpy()[0] + mask = mask.data.numpy()[0] + res = res.cpu().data.numpy()[0] + + res_complete = res * mask + img * (1. - mask) + res_complete = (res_complete + 1) * 127.5 + res_complete = res_complete.transpose(1, 2, 0) + if self.res_shape is not None: + res_complete = cv2.resize(res_complete, + (self.res_shape[1], self.res_shape[0])) + + return res_complete + + +def parse_arges(): + parser = argparse.ArgumentParser() + parser.add_argument('--image_shape', type=int, nargs='+', + default=[512, 960]) + parser.add_argument('--res_shape', type=int, nargs='+', + default=None) + parser.add_argument('--pretrained_model', type=str, + default='/home/chengao/Weight/imagenet_deepfill.pth') + parser.add_argument('--test_img', type=str, + default='/work/cascades/chengao/DAVIS-540/bear_540p/00000.png') + parser.add_argument('--test_mask', type=str, + default='/work/cascades/chengao/DAVIS-540-baseline/mask_540p.png') + parser.add_argument('--output_path', type=str, + default='/home/chengao/res_00000.png') + + args = parser.parse_args() + + return args + + +def main(): + + args = parse_arges() + + deepfill = DeepFillv1(pretrained_model=args.pretrained_model, + image_shape=args.image_shape, + res_shape=args.res_shape) + + test_image = cv2.imread(args.test_img) + mask = cv2.imread(args.test_mask, cv2.IMREAD_UNCHANGED) + + with torch.no_grad(): + img_res = deepfill.forward(test_image, mask) + + cv2.imwrite(args.output_path, img_res) + print('Result Saved') + + +if __name__ == '__main__': + main() diff --git a/modules/get_flowNN_gradient.py b/modules/get_flowNN_gradient.py new file mode 100644 index 0000000..35500bd --- /dev/null +++ b/modules/get_flowNN_gradient.py @@ -0,0 +1,534 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import os +import cv2 +import copy +import numpy as np +import scipy.io as sio +from modules.utils.common_utils import interp, BFconsistCheck, \ + FBconsistCheck, consistCheck, get_KeySourceFrame_flowNN_gradient + + +def get_flowNN_gradient(args, + gradient_x, + gradient_y, + mask_RGB, + mask, + videoFlowF, + videoFlowB, + videoNonLocalFlowF, + videoNonLocalFlowB): + + # gradient_x: imgH x (imgW - 1 + 1) x 3 x nFrame + # gradient_y: (imgH - 1 + 1) x imgW x 3 x nFrame + # mask_RGB: imgH x imgW x nFrame + # mask: imgH x imgW x nFrame + # videoFlowF: imgH x imgW x 2 x (nFrame - 1) | [u, v] + # videoFlowB: imgH x imgW x 2 x (nFrame - 1) | [u, v] + # videoNonLocalFlowF: imgH x imgW x 2 x 3 x nFrame + # videoNonLocalFlowB: imgH x imgW x 2 x 3 x nFrame + + if args.Nonlocal: + num_candidate = 5 + else: + num_candidate = 2 + imgH, imgW, nFrame = mask.shape + numPix = np.sum(mask) + + # |--------------------| |--------------------| + # | y | | v | + # 
| x * | | u * | + # | | | | + # |--------------------| |--------------------| + + # sub: numPix * 3 | [y, x, t] + # flowNN: numPix * 3 * 2 | [y, x, t], [BN, FN] + # HaveFlowNN: imgH * imgW * nFrame * 2 + # numPixInd: imgH * imgW * nFrame + # consistencyMap: imgH * imgW * 5 * nFrame | [BN, FN, NL2, NL3, NL4] + # consistency_uv: imgH * imgW * [BN, FN] * [u, v] * nFrame + + # sub: numPix * [y, x, t] | position of mising pixels + sub = np.concatenate((np.where(mask == 1)[0].reshape(-1, 1), + np.where(mask == 1)[1].reshape(-1, 1), + np.where(mask == 1)[2].reshape(-1, 1)), axis=1) + + # flowNN: numPix * [y, x, t] * [BN, FN] | flow neighbors + flowNN = np.ones((numPix, 3, 2)) * 99999 # * -1 + HaveFlowNN = np.ones((imgH, imgW, nFrame, 2)) * 99999 + HaveFlowNN[mask, :] = 0 + numPixInd = np.ones((imgH, imgW, nFrame)) * -1 + consistencyMap = np.zeros((imgH, imgW, num_candidate, nFrame)) + consistency_uv = np.zeros((imgH, imgW, 2, 2, nFrame)) + + # numPixInd[y, x, t] gives the index of the missing pixel@[y, x, t] in sub, + # i.e. which row. numPixInd[y, x, t] = idx; sub[idx, :] = [y, x, t] + for idx in range(len(sub)): + numPixInd[sub[idx, 0], sub[idx, 1], sub[idx, 2]] = idx + + # Initialization + frameIndSetF = range(1, nFrame) + frameIndSetB = range(nFrame - 2, -1, -1) + + # 1. Forward Pass (backward flow propagation) + print('Forward Pass......') + + NN_idx = 0 # BN:0 + for indFrame in frameIndSetF: + + # Bool indicator of missing pixels at frame t + holepixPosInd = (sub[:, 2] == indFrame) + + # Hole pixel location at frame t, i.e. [y, x, t] + holepixPos = sub[holepixPosInd, :] + + # Calculate the backward flow neighbor. Should be located at frame t-1 + flowB_neighbor = copy.deepcopy(holepixPos) + flowB_neighbor = flowB_neighbor.astype(np.float32) + + flowB_vertical = videoFlowB[:, :, 1, indFrame - 1] # t --> t-1 + flowB_horizont = videoFlowB[:, :, 0, indFrame - 1] + flowF_vertical = videoFlowF[:, :, 1, indFrame - 1] # t-1 --> t + flowF_horizont = videoFlowF[:, :, 0, indFrame - 1] + + flowB_neighbor[:, 0] += flowB_vertical[holepixPos[:, 0], holepixPos[:, 1]] + flowB_neighbor[:, 1] += flowB_horizont[holepixPos[:, 0], holepixPos[:, 1]] + flowB_neighbor[:, 2] -= 1 + + # Round the backward flow neighbor location + flow_neighbor_int = np.round(copy.deepcopy(flowB_neighbor)).astype(np.int32) + + # Chen: I should combine the following two operations together + # Check the backward/forward consistency + IsConsist, _ = BFconsistCheck(flowB_neighbor, + flowF_vertical, + flowF_horizont, + holepixPos, + args.consistencyThres) + + BFdiff, BF_uv = consistCheck(videoFlowF[:, :, :, indFrame - 1], + videoFlowB[:, :, :, indFrame - 1]) + + # Check out-of-boundary + # Last column and last row does not have valid gradient + ValidPos = np.logical_and( + np.logical_and(flow_neighbor_int[:, 0] >= 0, + flow_neighbor_int[:, 0] < imgH - 1), + np.logical_and(flow_neighbor_int[:, 1] >= 0, + flow_neighbor_int[:, 1] < imgW - 1)) + + # Only work with pixels that are not out-of-boundary + holepixPos = holepixPos[ValidPos, :] + flowB_neighbor = flowB_neighbor[ValidPos, :] + flow_neighbor_int = flow_neighbor_int[ValidPos, :] + IsConsist = IsConsist[ValidPos] + + # For each missing pixel in holepixPos|[y, x, t], + # we check its backward flow neighbor flowB_neighbor|[y', x', t-1]. + + # Case 1: If mask[round(y'), round(x'), t-1] == 0, + # the backward flow neighbor of [y, x, t] is known. + # [y', x', t-1] is the backward flow neighbor. + + # KnownInd: Among all backward flow neighbors, which pixel is known. 
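+ # Illustrative sketch (hypothetical numbers): a hole pixel at [y, x, t] = [120, 64, 5]
+ # with backward flow (u, v) = (+3.2, -1.7) gives flowB_neighbor = [118.3, 67.2, 4];
+ # it is accepted as the BN neighbor only if mask[118, 67, 4] == 0 and the
+ # forward/backward round-trip error is below args.consistencyThres.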
+ KnownInd = mask[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame - 1] == 0 + + KnownIsConsist = np.logical_and(KnownInd, IsConsist) + + # We save backward flow neighbor flowB_neighbor in flowNN + flowNN[numPixInd[holepixPos[KnownIsConsist, 0], + holepixPos[KnownIsConsist, 1], + indFrame].astype(np.int32), :, NN_idx] = \ + flowB_neighbor[KnownIsConsist, :] + # flowNN[np.where(holepixPosInd == 1)[0][ValidPos][KnownIsConsist], :, 0] = \ + # flowB_neighbor[KnownIsConsist, :] + + # We mark [y, x, t] in HaveFlowNN as 1 + HaveFlowNN[holepixPos[KnownIsConsist, 0], + holepixPos[KnownIsConsist, 1], + indFrame, + NN_idx] = 1 + + # HaveFlowNN[:, :, :, 0] + # 0: Backward flow neighbor can not be reached + # 1: Backward flow neighbor can be reached + # -1: Pixels that do not need to be completed + + consistency_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], NN_idx, 0, indFrame] = np.abs(BF_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], 0]) + consistency_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], NN_idx, 1, indFrame] = np.abs(BF_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], 1]) + + # Case 2: If mask[round(y'), round(x'), t-1] == 1, + # the pixel@[round(y'), round(x'), t-1] is also occluded. + # We further check if we already assign a backward flow neighbor for the backward flow neighbor + # If HaveFlowNN[round(y'), round(x'), t-1] == 0, + # this is isolated pixel. Do nothing. + # If HaveFlowNN[round(y'), round(x'), t-1] == 1, + # we can borrow the value and refine it. + + UnknownInd = np.invert(KnownInd) + + # If we already assign a backward flow neighbor@[round(y'), round(x'), t-1] + HaveNNInd = HaveFlowNN[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame - 1, + NN_idx] == 1 + + # Unknown & IsConsist & HaveNNInd + Valid_ = np.logical_and.reduce((UnknownInd, HaveNNInd, IsConsist)) + + refineVec = np.concatenate(( + (flowB_neighbor[:, 0] - flow_neighbor_int[:, 0]).reshape(-1, 1), + (flowB_neighbor[:, 1] - flow_neighbor_int[:, 1]).reshape(-1, 1), + np.zeros((flowB_neighbor[:, 0].shape[0])).reshape(-1, 1)), 1) + + # Check if the transitive backward flow neighbor of [y, x, t] is known. + # Sometimes after refinement, it is no longer known. + flowNN_tmp = copy.deepcopy(flowNN[numPixInd[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame - 1].astype(np.int32), :, NN_idx] + refineVec[:, :]) + flowNN_tmp = np.round(flowNN_tmp).astype(np.int32) + + # Check out-of-boundary. flowNN_tmp may be out-of-boundary + ValidPos_ = np.logical_and( + np.logical_and(flowNN_tmp[:, 0] >= 0, + flowNN_tmp[:, 0] < imgH - 1), + np.logical_and(flowNN_tmp[:, 1] >= 0, + flowNN_tmp[:, 1] < imgW - 1)) + + # Change the out-of-boundary value to 0, in order to run mask[y,x,t] + # in the next line. 
It won't affect anything as ValidPos_ is saved already + flowNN_tmp[np.invert(ValidPos_), :] = 0 + ValidNN = mask[flowNN_tmp[:, 0], + flowNN_tmp[:, 1], + flowNN_tmp[:, 2]] == 0 + + # Valid = np.logical_and.reduce((Valid_, ValidNN, ValidPos_)) + Valid = np.logical_and.reduce((Valid_, ValidPos_)) + + # We save the transitive backward flow neighbor flowB_neighbor in flowNN + flowNN[numPixInd[holepixPos[Valid, 0], + holepixPos[Valid, 1], + indFrame].astype(np.int32), :, NN_idx] = \ + flowNN[numPixInd[flow_neighbor_int[Valid, 0], + flow_neighbor_int[Valid, 1], + indFrame - 1].astype(np.int32), :, NN_idx] + refineVec[Valid, :] + + # We mark [y, x, t] in HaveFlowNN as 1 + HaveFlowNN[holepixPos[Valid, 0], + holepixPos[Valid, 1], + indFrame, + NN_idx] = 1 + + consistency_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], NN_idx, 0, indFrame] = np.maximum(np.abs(BF_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], 0]), np.abs(consistency_uv[flow_neighbor_int[Valid, 0], flow_neighbor_int[Valid, 1], NN_idx, 0, indFrame - 1])) + consistency_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], NN_idx, 1, indFrame] = np.maximum(np.abs(BF_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], 1]), np.abs(consistency_uv[flow_neighbor_int[Valid, 0], flow_neighbor_int[Valid, 1], NN_idx, 1, indFrame - 1])) + + consistencyMap[:, :, NN_idx, indFrame] = (consistency_uv[:, :, NN_idx, 0, indFrame] ** 2 + consistency_uv[:, :, NN_idx, 1, indFrame] ** 2) ** 0.5 + + print("Frame {0:3d}: {1:8d} + {2:8d} = {3:8d}" + .format(indFrame, + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] == 1), + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] == 0), + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] != 99999))) + + # 2. Backward Pass (forward flow propagation) + print('Backward Pass......') + + NN_idx = 1 # FN:1 + for indFrame in frameIndSetB: + + # Bool indicator of missing pixels at frame t + holepixPosInd = (sub[:, 2] == indFrame) + + # Hole pixel location at frame t, i.e. [y, x, t] + holepixPos = sub[holepixPosInd, :] + + # Calculate the forward flow neighbor. 
Should be located at frame t+1 + flowF_neighbor = copy.deepcopy(holepixPos) + flowF_neighbor = flowF_neighbor.astype(np.float32) + + flowF_vertical = videoFlowF[:, :, 1, indFrame] # t --> t+1 + flowF_horizont = videoFlowF[:, :, 0, indFrame] + flowB_vertical = videoFlowB[:, :, 1, indFrame] # t+1 --> t + flowB_horizont = videoFlowB[:, :, 0, indFrame] + + flowF_neighbor[:, 0] += flowF_vertical[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 1] += flowF_horizont[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 2] += 1 + + # Round the forward flow neighbor location + flow_neighbor_int = np.round(copy.deepcopy(flowF_neighbor)).astype(np.int32) + + # Check the forawrd/backward consistency + IsConsist, _ = FBconsistCheck(flowF_neighbor, + flowB_vertical, + flowB_horizont, + holepixPos, + args.consistencyThres) + + FBdiff, FB_uv = consistCheck(videoFlowB[:, :, :, indFrame], + videoFlowF[:, :, :, indFrame]) + + # Check out-of-boundary + # Last column and last row does not have valid gradient + ValidPos = np.logical_and( + np.logical_and(flow_neighbor_int[:, 0] >= 0, + flow_neighbor_int[:, 0] < imgH - 1), + np.logical_and(flow_neighbor_int[:, 1] >= 0, + flow_neighbor_int[:, 1] < imgW - 1)) + + # Only work with pixels that are not out-of-boundary + holepixPos = holepixPos[ValidPos, :] + flowF_neighbor = flowF_neighbor[ValidPos, :] + flow_neighbor_int = flow_neighbor_int[ValidPos, :] + IsConsist = IsConsist[ValidPos] + + # Case 1: + KnownInd = mask[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame + 1] == 0 + + KnownIsConsist = np.logical_and(KnownInd, IsConsist) + flowNN[numPixInd[holepixPos[KnownIsConsist, 0], + holepixPos[KnownIsConsist, 1], + indFrame].astype(np.int32), :, NN_idx] = \ + flowF_neighbor[KnownIsConsist, :] + + HaveFlowNN[holepixPos[KnownIsConsist, 0], + holepixPos[KnownIsConsist, 1], + indFrame, + NN_idx] = 1 + + consistency_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], NN_idx, 0, indFrame] = np.abs(FB_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], 0]) + consistency_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], NN_idx, 1, indFrame] = np.abs(FB_uv[holepixPos[KnownIsConsist, 0], holepixPos[KnownIsConsist, 1], 1]) + + # Case 2: + UnknownInd = np.invert(KnownInd) + HaveNNInd = HaveFlowNN[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame + 1, + NN_idx] == 1 + + # Unknown & IsConsist & HaveNNInd + Valid_ = np.logical_and.reduce((UnknownInd, HaveNNInd, IsConsist)) + + refineVec = np.concatenate(( + (flowF_neighbor[:, 0] - flow_neighbor_int[:, 0]).reshape(-1, 1), + (flowF_neighbor[:, 1] - flow_neighbor_int[:, 1]).reshape(-1, 1), + np.zeros((flowF_neighbor[:, 0].shape[0])).reshape(-1, 1)), 1) + + # Check if the transitive backward flow neighbor of [y, x, t] is known. + # Sometimes after refinement, it is no longer known. + flowNN_tmp = copy.deepcopy(flowNN[numPixInd[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + indFrame + 1].astype(np.int32), :, NN_idx] + refineVec[:, :]) + flowNN_tmp = np.round(flowNN_tmp).astype(np.int32) + + # Check out-of-boundary. flowNN_tmp may be out-of-boundary + ValidPos_ = np.logical_and( + np.logical_and(flowNN_tmp[:, 0] >= 0, + flowNN_tmp[:, 0] < imgH - 1), + np.logical_and(flowNN_tmp[:, 1] >= 0, + flowNN_tmp[:, 1] < imgW - 1)) + + # Change the out-of-boundary value to 0, in order to run mask[y,x,t] + # in the next line. 
It won't affect anything as ValidPos_ is saved already + flowNN_tmp[np.invert(ValidPos_), :] = 0 + ValidNN = mask[flowNN_tmp[:, 0], + flowNN_tmp[:, 1], + flowNN_tmp[:, 2]] == 0 + + # Valid = np.logical_and.reduce((Valid_, ValidNN, ValidPos_)) + Valid = np.logical_and.reduce((Valid_, ValidPos_)) + + # We save the transitive backward flow neighbor flowB_neighbor in flowNN + flowNN[numPixInd[holepixPos[Valid, 0], + holepixPos[Valid, 1], + indFrame].astype(np.int32), :, NN_idx] = \ + flowNN[numPixInd[flow_neighbor_int[Valid, 0], + flow_neighbor_int[Valid, 1], + indFrame + 1].astype(np.int32), :, NN_idx] + refineVec[Valid, :] + + # We mark [y, x, t] in HaveFlowNN as 1 + HaveFlowNN[holepixPos[Valid, 0], + holepixPos[Valid, 1], + indFrame, + NN_idx] = 1 + + consistency_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], NN_idx, 0, indFrame] = np.maximum(np.abs(FB_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], 0]), np.abs(consistency_uv[flow_neighbor_int[Valid, 0], flow_neighbor_int[Valid, 1], NN_idx, 0, indFrame + 1])) + consistency_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], NN_idx, 1, indFrame] = np.maximum(np.abs(FB_uv[holepixPos[Valid, 0], holepixPos[Valid, 1], 1]), np.abs(consistency_uv[flow_neighbor_int[Valid, 0], flow_neighbor_int[Valid, 1], NN_idx, 1, indFrame + 1])) + + consistencyMap[:, :, NN_idx, indFrame] = (consistency_uv[:, :, NN_idx, 0, indFrame] ** 2 + consistency_uv[:, :, NN_idx, 1, indFrame] ** 2) ** 0.5 + + print("Frame {0:3d}: {1:8d} + {2:8d} = {3:8d}" + .format(indFrame, + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] == 1), + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] == 0), + np.sum(HaveFlowNN[:, :, indFrame, NN_idx] != 99999))) + + # Interpolation + gradient_x_BN = copy.deepcopy(gradient_x) + gradient_y_BN = copy.deepcopy(gradient_y) + gradient_x_FN = copy.deepcopy(gradient_x) + gradient_y_FN = copy.deepcopy(gradient_y) + + for indFrame in range(nFrame): + # Index of missing pixel whose backward flow neighbor is from frame indFrame + SourceFmInd = np.where(flowNN[:, 2, 0] == indFrame) + + print("{0:8d} pixels are from source Frame {1:3d}" + .format(len(SourceFmInd[0]), indFrame)) + # The location of the missing pixel whose backward flow neighbor is + # from frame indFrame flowNN[SourceFmInd, 0, 0], flowNN[SourceFmInd, 1, 0] + + if len(SourceFmInd[0]) != 0: + + # |--------------------| + # | y | + # | x * | + # | | + # |--------------------| + # sub: numPix x 3 [y, x, t] + # img: [y, x] + # interp(img, x, y) + + gradient_x_BN[sub[SourceFmInd[0], :][:, 0], + sub[SourceFmInd[0], :][:, 1], + :, sub[SourceFmInd[0], :][:, 2]] = \ + interp(gradient_x_BN[:, :, :, indFrame], + flowNN[SourceFmInd, 1, 0].reshape(-1), + flowNN[SourceFmInd, 0, 0].reshape(-1)) + + gradient_y_BN[sub[SourceFmInd[0], :][:, 0], + sub[SourceFmInd[0], :][:, 1], + :, sub[SourceFmInd[0], :][:, 2]] = \ + interp(gradient_y_BN[:, :, :, indFrame], + flowNN[SourceFmInd, 1, 0].reshape(-1), + flowNN[SourceFmInd, 0, 0].reshape(-1)) + + assert(((sub[SourceFmInd[0], :][:, 2] - indFrame) <= 0).sum() == 0) + + for indFrame in range(nFrame - 1, -1, -1): + # Index of missing pixel whose forward flow neighbor is from frame indFrame + SourceFmInd = np.where(flowNN[:, 2, 1] == indFrame) + print("{0:8d} pixels are from source Frame {1:3d}" + .format(len(SourceFmInd[0]), indFrame)) + if len(SourceFmInd[0]) != 0: + + gradient_x_FN[sub[SourceFmInd[0], :][:, 0], + sub[SourceFmInd[0], :][:, 1], + :, sub[SourceFmInd[0], :][:, 2]] = \ + interp(gradient_x_FN[:, :, :, indFrame], + flowNN[SourceFmInd, 1, 1].reshape(-1), + flowNN[SourceFmInd, 0, 
1].reshape(-1)) + + gradient_y_FN[sub[SourceFmInd[0], :][:, 0], + sub[SourceFmInd[0], :][:, 1], + :, sub[SourceFmInd[0], :][:, 2]] = \ + interp(gradient_y_FN[:, :, :, indFrame], + flowNN[SourceFmInd, 1, 1].reshape(-1), + flowNN[SourceFmInd, 0, 1].reshape(-1)) + + assert(((indFrame - sub[SourceFmInd[0], :][:, 2]) <= 0).sum() == 0) + + # New mask + mask_tofill = np.zeros((imgH, imgW, nFrame)).astype(np.bool) + + # videoNonLocalFlowB = np.empty(((imgH, imgW, 2, 3, nFrame)), dtype=np.float32) + # videoNonLocalFlowF = np.empty(((imgH, imgW, 2, 3, nFrame)), dtype=np.float32) + + for indFrame in range(nFrame): + if args.Nonlocal: + consistencyMap[:, :, 2, indFrame], _ = consistCheck( + videoNonLocalFlowB[:, :, :, 0, indFrame], + videoNonLocalFlowF[:, :, :, 0, indFrame]) + consistencyMap[:, :, 3, indFrame], _ = consistCheck( + videoNonLocalFlowB[:, :, :, 1, indFrame], + videoNonLocalFlowF[:, :, :, 1, indFrame]) + consistencyMap[:, :, 4, indFrame], _ = consistCheck( + videoNonLocalFlowB[:, :, :, 2, indFrame], + videoNonLocalFlowF[:, :, :, 2, indFrame]) + + HaveNN = np.zeros((imgH, imgW, num_candidate)) + + if args.Nonlocal: + HaveKeySourceFrameFlowNN, gradient_x_KeySourceFrameFlowNN, gradient_y_KeySourceFrameFlowNN = \ + get_KeySourceFrame_flowNN_gradient(sub, + indFrame, + mask, + videoNonLocalFlowB, + videoNonLocalFlowF, + gradient_x, + gradient_y, + args.consistencyThres) + + HaveNN[:, :, 2] = HaveKeySourceFrameFlowNN[:, :, 0] == 1 + HaveNN[:, :, 3] = HaveKeySourceFrameFlowNN[:, :, 1] == 1 + HaveNN[:, :, 4] = HaveKeySourceFrameFlowNN[:, :, 2] == 1 + + HaveNN[:, :, 0] = HaveFlowNN[:, :, indFrame, 0] == 1 + HaveNN[:, :, 1] = HaveFlowNN[:, :, indFrame, 1] == 1 + + NotHaveNN = np.logical_and(np.invert(HaveNN.astype(np.bool)), + np.repeat(np.expand_dims((mask[:, :, indFrame]), 2), num_candidate, axis=2)) + + if args.Nonlocal: + HaveNN_sum = np.logical_or.reduce((HaveNN[:, :, 0], + HaveNN[:, :, 1], + HaveNN[:, :, 2], + HaveNN[:, :, 3], + HaveNN[:, :, 4])) + else: + HaveNN_sum = np.logical_or.reduce((HaveNN[:, :, 0], + HaveNN[:, :, 1])) + + gradient_x_Candidate = np.zeros((imgH, imgW, 3, num_candidate)) + gradient_y_Candidate = np.zeros((imgH, imgW, 3, num_candidate)) + + gradient_x_Candidate[:, :, :, 0] = gradient_x_BN[:, :, :, indFrame] + gradient_y_Candidate[:, :, :, 0] = gradient_y_BN[:, :, :, indFrame] + gradient_x_Candidate[:, :, :, 1] = gradient_x_FN[:, :, :, indFrame] + gradient_y_Candidate[:, :, :, 1] = gradient_y_FN[:, :, :, indFrame] + + if args.Nonlocal: + gradient_x_Candidate[:, :, :, 2] = gradient_x_KeySourceFrameFlowNN[:, :, :, 0] + gradient_y_Candidate[:, :, :, 2] = gradient_y_KeySourceFrameFlowNN[:, :, :, 0] + gradient_x_Candidate[:, :, :, 3] = gradient_x_KeySourceFrameFlowNN[:, :, :, 1] + gradient_y_Candidate[:, :, :, 3] = gradient_y_KeySourceFrameFlowNN[:, :, :, 1] + gradient_x_Candidate[:, :, :, 4] = gradient_x_KeySourceFrameFlowNN[:, :, :, 2] + gradient_y_Candidate[:, :, :, 4] = gradient_y_KeySourceFrameFlowNN[:, :, :, 2] + + consistencyMap[:, :, :, indFrame] = np.exp( - consistencyMap[:, :, :, indFrame] / args.alpha) + + consistencyMap[NotHaveNN[:, :, 0], 0, indFrame] = 0 + consistencyMap[NotHaveNN[:, :, 1], 1, indFrame] = 0 + + if args.Nonlocal: + consistencyMap[NotHaveNN[:, :, 2], 2, indFrame] = 0 + consistencyMap[NotHaveNN[:, :, 3], 3, indFrame] = 0 + consistencyMap[NotHaveNN[:, :, 4], 4, indFrame] = 0 + + weights = (consistencyMap[HaveNN_sum, :, indFrame] * HaveNN[HaveNN_sum, :]) / ((consistencyMap[HaveNN_sum, :, indFrame] * HaveNN[HaveNN_sum, :]).sum(axis=1, keepdims=True)) + 
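+ # each row of weights renormalizes exp(-consistency / alpha) over the candidates that actually have a flow neighbor;
+ # e.g. (hypothetical) exp terms [0.61, 0.01] for [BN, FN] give weights of roughly [0.98, 0.02],
+ # so the more flow-consistent backward neighbor dominates the gradient fusion below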
+ # Fix the numerical issue. 0 / 0 + fix = np.where((consistencyMap[HaveNN_sum, :, indFrame] * HaveNN[HaveNN_sum, :]).sum(axis=1, keepdims=True) == 0)[0] + weights[fix, :] = HaveNN[HaveNN_sum, :][fix, :] / HaveNN[HaveNN_sum, :][fix, :].sum(axis=1, keepdims=True) + + # Fuse RGB channel independently + gradient_x[HaveNN_sum, 0, indFrame] = \ + np.sum(np.multiply(gradient_x_Candidate[HaveNN_sum, 0, :], weights), axis=1) + gradient_x[HaveNN_sum, 1, indFrame] = \ + np.sum(np.multiply(gradient_x_Candidate[HaveNN_sum, 1, :], weights), axis=1) + gradient_x[HaveNN_sum, 2, indFrame] = \ + np.sum(np.multiply(gradient_x_Candidate[HaveNN_sum, 2, :], weights), axis=1) + + gradient_y[HaveNN_sum, 0, indFrame] = \ + np.sum(np.multiply(gradient_y_Candidate[HaveNN_sum, 0, :], weights), axis=1) + gradient_y[HaveNN_sum, 1, indFrame] = \ + np.sum(np.multiply(gradient_y_Candidate[HaveNN_sum, 1, :], weights), axis=1) + gradient_y[HaveNN_sum, 2, indFrame] = \ + np.sum(np.multiply(gradient_y_Candidate[HaveNN_sum, 2, :], weights), axis=1) + + mask_tofill[np.logical_and(np.invert(HaveNN_sum), mask[:, :, indFrame]), indFrame] = True + + return gradient_x, gradient_y, mask_tofill diff --git a/modules/ftn.py b/modules/image_stack.py similarity index 100% rename from modules/ftn.py rename to modules/image_stack.py diff --git a/modules/spatial_inpaint.py b/modules/spatial_inpaint.py new file mode 100644 index 0000000..19f5acc --- /dev/null +++ b/modules/spatial_inpaint.py @@ -0,0 +1,16 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import os +import cv2 +import numpy as np +import torch + + +def spatial_inpaint(deepfill, mask, video_comp): + + keyFrameInd = np.argmax(np.sum(np.sum(mask, axis=0), axis=0)) + with torch.no_grad(): + img_res = deepfill.forward(video_comp[:, :, :, keyFrameInd] * 255., mask[:, :, keyFrameInd]) / 255. 
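+ # copy the DeepFill result into the hole pixels of the key frame and mark that frame as completely filled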
+ video_comp[mask[:, :, keyFrameInd], :, keyFrameInd] = img_res[mask[:, :, keyFrameInd], :] + mask[:, :, keyFrameInd] = False + + return mask, video_comp diff --git a/modules/utils/Poisson_blend.py b/modules/utils/Poisson_blend.py new file mode 100644 index 0000000..7a259e9 --- /dev/null +++ b/modules/utils/Poisson_blend.py @@ -0,0 +1,213 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import scipy.ndimage +from scipy.sparse.linalg import spsolve +from scipy import sparse +import scipy.io as sio +import numpy as np +from PIL import Image +import copy +import cv2 +import os +import argparse + + +def sub2ind(pi, pj, imgH, imgW): + return pj + pi * imgW + + +def Poisson_blend(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, edge=None): + + imgH, imgW, nCh = imgTrg.shape + + if not isinstance(edge, np.ndarray): + edge = np.zeros((imgH, imgW), dtype=np.float32) + + # Initialize the reconstructed image + imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32) + + # prepare discrete Poisson equation + A, b = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge) + + # Independently process each channel + for ch in range(nCh): + + # solve Poisson equation + x = scipy.sparse.linalg.lsqr(A, b[:, ch, None])[0] + imgRecon[:, :, ch] = x.reshape(imgH, imgW) + + # Combined with the known region in the target + holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh)) + imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg + + # Fill in edge pixel + pi = np.expand_dims(np.where((holeMask * edge) == 1)[0], axis=1) # y, i + pj = np.expand_dims(np.where((holeMask * edge) == 1)[1], axis=1) # x, j + + for k in range(len(pi)): + if pi[k, 0] - 1 >= 0: + if edge[pi[k, 0] - 1, pj[k, 0]] == 0: + imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :] + continue + if pi[k, 0] + 1 <= imgH - 1: + if edge[pi[k, 0] + 1, pj[k, 0]] == 0: + imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] + 1, pj[k, 0], :] + continue + if pj[k, 0] - 1 >= 0: + if edge[pi[k, 0], pj[k, 0] - 1] == 0: + imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :] + continue + if pj[k, 0] + 1 <= imgW - 1: + if edge[pi[k, 0], pj[k, 0] + 1] == 0: + imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :] + + return imgBlend + +def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge): + + # Prepare the linear system of equations for Poisson blending + imgH, imgW = holeMask.shape + N = imgH * imgW + + # Number of unknown variables + numUnknownPix = holeMask.sum() + + # 4-neighbors: dx and dy + dx = [1, 0, -1, 0] + dy = [0, 1, 0, -1] + + # 3 + # | + # 2 -- * -- 0 + # | + # 1 + # + + # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k) + I = np.empty((0, 1), dtype=np.float32) + J = np.empty((0, 1), dtype=np.float32) + S = np.empty((0, 1), dtype=np.float32) + + # Initialize b + b = np.empty((0, 2), dtype=np.float32) + + # Precompute unkonwn pixel position + pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i + pj = np.expand_dims(np.where(holeMask == 1)[1], axis=1) # x, j + pind = sub2ind(pi, pj, imgH, imgW) + + # |--------------------| + # | y (i) | + # | x (j) * | + # | | + # |--------------------| + + qi = np.concatenate((pi + dy[0], + pi + dy[1], + pi + dy[2], + pi + dy[3]), axis=1) + + qj = np.concatenate((pj + dx[0], + pj + dx[1], + pj + dx[2], + pj + dx[3]), axis=1) + + # Handling cases at image borders + validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1) + qind = np.zeros((validN.shape), 
dtype=np.float32) + qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW) + + e_start = 0 # equation counter start + e_stop = 0 # equation stop + + # 4 neighbors + I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + + nEqn = len(b) + # Construct the sparse matrix A + A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N)) + + return A, b + + +def constructEquation(n, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop): + + # Pixel that has valid neighbors + validNeighbor = validN[:, n] + + # Change the out-of-boundary value to 0, in order to run edge[y,x] + # in the next line. It won't affect anything as validNeighbor is saved already + + qi_tmp = copy.deepcopy(qi) + qj_tmp = copy.deepcopy(qj) + qi_tmp[np.invert(validNeighbor), n] = 0 + qj_tmp[np.invert(validNeighbor), n] = 0 + + # Not edge + NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0) + + # Boundary constraint + Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0 + valid = validNeighbor * NotEdge * Boundary + J_tmp = pind[valid, :] + + # num of equations: len(J_tmp) + e_stop = e_start + len(J_tmp) + I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) + e_start = e_stop + + S_tmp = np.ones(J_tmp.shape, dtype=np.float32) + + if n == 0: + b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 2: + b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 1: + b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 3: + b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + b = np.concatenate((b, b_tmp)) + + + # Non-boundary constraint + NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1 + valid = validNeighbor * NotEdge * NonBoundary + + J_tmp = pind[valid, :] + + # num of equations: len(J_tmp) + e_stop = e_start + len(J_tmp) + I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) + e_start = e_stop + + S_tmp = np.ones(J_tmp.shape, dtype=np.float32) + + if n == 0: + b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + elif n == 2: + b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + elif n == 1: + b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + elif n == 3: + b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + b = np.concatenate((b, b_tmp)) + + S_tmp = - np.ones(J_tmp.shape, dtype=np.float32) + J_tmp = qind[valid, n, None] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + + return I, J, S, b, e_start, e_stop diff --git a/modules/utils/Poisson_blend_img.py 
b/modules/utils/Poisson_blend_img.py new file mode 100644 index 0000000..28d3d11 --- /dev/null +++ b/modules/utils/Poisson_blend_img.py @@ -0,0 +1,270 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import scipy.ndimage +from scipy.sparse.linalg import spsolve +from scipy import sparse +import scipy.io as sio +import numpy as np +from PIL import Image +import copy +import cv2 +import os +import argparse + + +def sub2ind(pi, pj, imgH, imgW): + return pj + pi * imgW + + +def Poisson_blend_img(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, gradientMask=None, edge=None): + + imgH, imgW, nCh = imgTrg.shape + + if not isinstance(gradientMask, np.ndarray): + gradientMask = np.zeros((imgH, imgW), dtype=np.float32) + + if not isinstance(edge, np.ndarray): + edge = np.zeros((imgH, imgW), dtype=np.float32) + + # Initialize the reconstructed image + imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32) + + # prepare discrete Poisson equation + A, b, UnfilledMask = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, + gradientMask, edge) + + # Independently process each channel + for ch in range(nCh): + + # solve Poisson equation + x = scipy.sparse.linalg.lsqr(A, b[:, ch])[0] + + imgRecon[:, :, ch] = x.reshape(imgH, imgW) + + # Combined with the known region in the target + holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh)) + imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg + + + # while((UnfilledMask * edge).sum() != 0): + # # Fill in edge pixel + # pi = np.expand_dims(np.where((UnfilledMask * edge) == 1)[0], axis=1) # y, i + # pj = np.expand_dims(np.where((UnfilledMask * edge) == 1)[1], axis=1) # x, j + # + # for k in range(len(pi)): + # if pi[k, 0] - 1 >= 0: + # if (UnfilledMask * edge)[pi[k, 0] - 1, pj[k, 0]] == 0: + # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :] + # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 + # continue + # if pi[k, 0] + 1 <= imgH - 1: + # if (UnfilledMask * edge)[pi[k, 0] + 1, pj[k, 0]] == 0: + # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] + 1, pj[k, 0], :] + # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 + # continue + # if pj[k, 0] - 1 >= 0: + # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] - 1] == 0: + # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :] + # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 + # continue + # if pj[k, 0] + 1 <= imgW - 1: + # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] + 1] == 0: + # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :] + # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 + + return imgBlend, UnfilledMask + +def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, + gradientMask, edge): + + # UnfilledMask indicates the region that is not completed + UnfilledMask_topleft = copy.deepcopy(holeMask) + UnfilledMask_bottomright = copy.deepcopy(holeMask) + + # Prepare the linear system of equations for Poisson blending + imgH, imgW = holeMask.shape + N = imgH * imgW + + # Number of unknown variables + numUnknownPix = holeMask.sum() + + # 4-neighbors: dx and dy + dx = [1, 0, -1, 0] + dy = [0, 1, 0, -1] + + # 3 + # | + # 2 -- * -- 0 + # | + # 1 + # + + # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k) + I = np.empty((0, 1), dtype=np.float32) + J = np.empty((0, 1), dtype=np.float32) + S = np.empty((0, 1), dtype=np.float32) + + # Initialize b + b = np.empty((0, 3), dtype=np.float32) + + # Precompute unkonwn pixel position + pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i + pj = np.expand_dims(np.where(holeMask == 
1)[1], axis=1) # x, j + pind = sub2ind(pi, pj, imgH, imgW) + + # |--------------------| + # | y (i) | + # | x (j) * | + # | | + # |--------------------| + # p[y, x] + + qi = np.concatenate((pi + dy[0], + pi + dy[1], + pi + dy[2], + pi + dy[3]), axis=1) + + qj = np.concatenate((pj + dx[0], + pj + dx[1], + pj + dx[2], + pj + dx[3]), axis=1) + + # Handling cases at image borders + validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1) + qind = np.zeros((validN.shape), dtype=np.float32) + qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW) + + e_start = 0 # equation counter start + e_stop = 0 # equation stop + + # 4 neighbors + I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) + + nEqn = len(b) + # Construct the sparse matrix A + A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N)) + + # Check connected pixels + for ind in range(0, len(pi), 1): + ii = pi[ind, 0] + jj = pj[ind, 0] + + # check up (3) + if ii - 1 >= 0: + if UnfilledMask_topleft[ii - 1, jj] == 0 and gradientMask[ii - 1, jj] == 0: + UnfilledMask_topleft[ii, jj] = 0 + # check left (2) + if jj - 1 >= 0: + if UnfilledMask_topleft[ii, jj - 1] == 0 and gradientMask[ii, jj - 1] == 0: + UnfilledMask_topleft[ii, jj] = 0 + + + for ind in range(len(pi) - 1, -1, -1): + ii = pi[ind, 0] + jj = pj[ind, 0] + + # check bottom (1) + if ii + 1 <= imgH - 1: + if UnfilledMask_bottomright[ii + 1, jj] == 0 and gradientMask[ii, jj] == 0: + UnfilledMask_bottomright[ii, jj] = 0 + # check right (0) + if jj + 1 <= imgW - 1: + if UnfilledMask_bottomright[ii, jj + 1] == 0 and gradientMask[ii, jj] == 0: + UnfilledMask_bottomright[ii, jj] = 0 + + UnfilledMask = UnfilledMask_topleft * UnfilledMask_bottomright + + return A, b, UnfilledMask + + +def constructEquation(n, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop): + + # Pixel that has valid neighbors + validNeighbor = validN[:, n] + + # Change the out-of-boundary value to 0, in order to run edge[y,x] + # in the next line. 
It won't affect anything as validNeighbor is saved already + + qi_tmp = copy.deepcopy(qi) + qj_tmp = copy.deepcopy(qj) + qi_tmp[np.invert(validNeighbor), n] = 0 + qj_tmp[np.invert(validNeighbor), n] = 0 + + NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0) + + # Have gradient + if n == 0: + HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0 + elif n == 2: + HaveGrad = gradientMask[pi[:, 0], pj[:, 0] - 1] == 0 + elif n == 1: + HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0 + elif n == 3: + HaveGrad = gradientMask[pi[:, 0] - 1, pj[:, 0]] == 0 + + # Boundary constraint + Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0 + + valid = validNeighbor * NotEdge * HaveGrad * Boundary + + J_tmp = pind[valid, :] + + # num of equations: len(J_tmp) + e_stop = e_start + len(J_tmp) + I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) + e_start = e_stop + + S_tmp = np.ones(J_tmp.shape, dtype=np.float32) + + if n == 0: + b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 2: + b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 1: + b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + elif n == 3: + b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + b = np.concatenate((b, b_tmp)) + + # Non-boundary constraint + NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1 + valid = validNeighbor * NotEdge * HaveGrad * NonBoundary + + J_tmp = pind[valid, :] + + # num of equations: len(J_tmp) + e_stop = e_start + len(J_tmp) + I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) + e_start = e_stop + + S_tmp = np.ones(J_tmp.shape, dtype=np.float32) + + if n == 0: + b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + elif n == 2: + b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + elif n == 1: + b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + elif n == 3: + b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + b = np.concatenate((b, b_tmp)) + + S_tmp = - np.ones(J_tmp.shape, dtype=np.float32) + J_tmp = qind[valid, n, None] + + I = np.concatenate((I, I_tmp)) + J = np.concatenate((J, J_tmp)) + S = np.concatenate((S, S_tmp)) + + return I, J, S, b, e_start, e_stop diff --git a/modules/utils/__init__.py b/modules/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/utils/common_utils.py b/modules/utils/common_utils.py new file mode 100644 index 0000000..7836cf3 --- /dev/null +++ b/modules/utils/common_utils.py @@ -0,0 +1,613 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import torch +import torch.nn as nn +import cv2 +import copy +import numpy as np +import sys +import os +import time +from PIL import Image +import scipy.ndimage + + +def combine(img1, img2, slope=0.55, band_width=0.015, offset=0): + + imgH, imgW, _ = img1.shape + band_width = int(band_width * imgH) + + if img1.shape != img2.shape: + # img1 = cv2.resize(img1, (imgW, imgH)) + raise NameError('Shape does not match') + + center_point = (int(imgH / 2), int(imgW / 2 + offset)) + + b = (center_point[1] - 1) - slope * (center_point[0] - 1) + comp_img = np.zeros(img2.shape, dtype=np.float32) + + for x in range(imgH): + for y in 
range(imgW): + if y < (slope * x + b): + comp_img[x, y, :] = img1[x, y, :] + elif y > (slope * x + b): + comp_img[x, y, :] = img2[x, y, :] + + start_point = (int(b - 0.5 * band_width), 0) + end_point = (int(slope * (imgW - 1) + b - 0.5 * band_width), imgW - 1) + + color = (1, 1, 1) + comp_img = cv2.line(comp_img, start_point, end_point, color, band_width, lineType=cv2.LINE_AA) + + return comp_img + + +def save_video(in_dir, out_dir, optimize=False): + + _, ext = os.path.splitext(sorted(os.listdir(in_dir))[0]) + dir = '"' + os.path.join(in_dir, '*' + ext) + '"' + + if optimize: + os.system('ffmpeg -y -pattern_type glob -f image2 -i {} -pix_fmt yuv420p -preset veryslow -crf 27 {}'.format(dir, out_dir)) + else: + os.system('ffmpeg -y -pattern_type glob -f image2 -i {} -pix_fmt yuv420p {}'.format(dir, out_dir)) + +def create_dir(dir): + if not os.path.exists(dir): + os.makedirs(dir) + + +def bboxes_mask(imgH, imgW, type='ori'): + mask = np.zeros((imgH, imgW), dtype=np.float32) + factor = 1920 * 2 // imgW + + for indFrameH in range(int(imgH / (256 * 2 // factor))): + for indFrameW in range(int(imgW / (384 * 2 // factor))): + mask[indFrameH * (256 * 2 // factor) + (128 * 2 // factor) - (64 * 2 // factor) : + indFrameH * (256 * 2 // factor) + (128 * 2 // factor) + (64 * 2 // factor), + indFrameW * (384 * 2 // factor) + (192 * 2 // factor) - (64 * 2 // factor) : + indFrameW * (384 * 2 // factor) + (192 * 2 // factor) + (64 * 2 // factor)] = 1 + + if type == 'ori': + return mask + elif type == 'flow': + # Dilate 25 pixel so that all known pixel is trustworthy + return scipy.ndimage.binary_dilation(mask, iterations=15) + +def bboxes_mask_large(imgH, imgW, type='ori'): + mask = np.zeros((imgH, imgW), dtype=np.float32) + # mask[50 : 450, 280: 680] = 1 + mask[150 : 350, 350: 650] = 1 + + if type == 'ori': + return mask + elif type == 'flow': + # Dilate 35 pixel so that all known pixel is trustworthy + return scipy.ndimage.binary_dilation(mask, iterations=35) + +def gradient_mask(mask): + + gradient_mask = np.logical_or.reduce((mask, + np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0), + np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1))) + + return gradient_mask + + +def flow_edge(flow, mask=None): + # mask: 1 indicates the missing region + if not isinstance(mask, np.ndarray): + mask = None + else: + # using 'mask' parameter prevents canny to detect edges for the masked regions + mask = (1 - mask).astype(np.bool) + + flow_mag = (flow[:, :, 0] ** 2 + flow[:, :, 1] ** 2) ** 0.5 + flow_mag = flow_mag / flow_mag.max() + + edge_canny_flow = canny_flow(flow_mag, flow, mask=mask) + edge_canny = canny(flow_mag, sigma=2, mask=mask) + + if edge_canny_flow.sum() > edge_canny.sum(): + return edge_canny_flow + else: + return edge_canny + + +def np_to_torch(img_np): + '''Converts image in numpy.array to torch.Tensor. + From C x W x H [0..1] to C x W x H [0..1] + ''' + return torch.from_numpy(img_np)[None, :] + + +def torch_to_np(img_var): + '''Converts an image in torch.Tensor format to np.array. + From 1 x C x W x H [0..1] to C x W x H [0..1] + ''' + return img_var.detach().cpu().numpy()[0] + + +def sigmoid_(x, thres): + return 1. 
/ (1 + np.exp(-x + thres)) + + +# def softmax(x): +# e_x = np.exp(x - np.max(x)) +# return e_x / e_x.sum() + + +def softmax(x, axis=None, mask_=None): + + if mask_ is None: + mask_ = np.ones(x.shape) + x = (x - x.max(axis=axis, keepdims=True)) + y = np.multiply(np.exp(x), mask_) + return y / y.sum(axis=axis, keepdims=True) + + +# Bypass cv2's SHRT_MAX limitation +def interp(img, x, y): + + x = x.astype(np.float32).reshape(1, -1) + y = y.astype(np.float32).reshape(1, -1) + + assert(x.shape == y.shape) + + numPix = x.shape[1] + len_padding = (numPix // 1024 + 1) * 1024 - numPix + padding = np.zeros((1, len_padding)).astype(np.float32) + + map_x = np.concatenate((x, padding), axis=1).reshape(1024, numPix // 1024 + 1) + map_y = np.concatenate((y, padding), axis=1).reshape(1024, numPix // 1024 + 1) + + # Note that cv2 takes the input in opposite order, i.e. cv2.remap(img, x, y) + mapped_img = cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR) + + if len(img.shape) == 2: + mapped_img = mapped_img.reshape(-1)[:numPix] + else: + mapped_img = mapped_img.reshape(-1, img.shape[2])[:numPix, :] + + return mapped_img + + +def imsave(img, path): + im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze()) + im.save(path) + + +def postprocess(img): + # [0, 1] => [0, 255] + img = img * 255.0 + img = img.permute(0, 2, 3, 1) + return img.int() + + +# Backward flow propagating and forward flow propagating consistency check +def BFconsistCheck(flowB_neighbor, flowF_vertical, flowF_horizont, + holepixPos, consistencyThres): + + flowBF_neighbor = copy.deepcopy(flowB_neighbor) + + # After the backward and forward propagation, the pixel should go back + # to the original location. + flowBF_neighbor[:, 0] += interp(flowF_vertical, + flowB_neighbor[:, 1], + flowB_neighbor[:, 0]) + flowBF_neighbor[:, 1] += interp(flowF_horizont, + flowB_neighbor[:, 1], + flowB_neighbor[:, 0]) + flowBF_neighbor[:, 2] += 1 + + # Check photometric consistency + BFdiff = ((flowBF_neighbor - holepixPos)[:, 0] ** 2 + + (flowBF_neighbor - holepixPos)[:, 1] ** 2) ** 0.5 + IsConsist = BFdiff < consistencyThres + + return IsConsist, BFdiff + + +# Forward flow propagating and backward flow propagating consistency check +def FBconsistCheck(flowF_neighbor, flowB_vertical, flowB_horizont, + holepixPos, consistencyThres): + + flowFB_neighbor = copy.deepcopy(flowF_neighbor) + + # After the forward and backward propagation, the pixel should go back + # to the original location. 
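+    # Illustration (made-up numbers): a hole pixel at (y, x) = (40, 60) is
+    # carried by the forward flow to roughly (42.3, 57.8) in frame t + 1.
+    # Sampling the backward flow there should bring it back near (40, 60);
+    # the leftover distance is FBdiff, and only neighbors with
+    # FBdiff < consistencyThres are kept as reliable correspondences.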
+ flowFB_neighbor[:, 0] += interp(flowB_vertical, + flowF_neighbor[:, 1], + flowF_neighbor[:, 0]) + flowFB_neighbor[:, 1] += interp(flowB_horizont, + flowF_neighbor[:, 1], + flowF_neighbor[:, 0]) + flowFB_neighbor[:, 2] -= 1 + + # Check photometric consistency + FBdiff = ((flowFB_neighbor - holepixPos)[:, 0] ** 2 + + (flowFB_neighbor - holepixPos)[:, 1] ** 2) ** 0.5 + IsConsist = FBdiff < consistencyThres + + return IsConsist, FBdiff + + +def consistCheck(flowF, flowB): + + # |--------------------| |--------------------| + # | y | | v | + # | x * | | u * | + # | | | | + # |--------------------| |--------------------| + + # sub: numPix * [y x t] + + imgH, imgW, _ = flowF.shape + + (fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32) + fxx = fx + flowB[:, :, 0] # horizontal + fyy = fy + flowB[:, :, 1] # vertical + + u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx) + v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy) + BFdiff = (u ** 2 + v ** 2) ** 0.5 + + return BFdiff, np.stack((u, v), axis=2) + + +def get_KeySourceFrame_flowNN(sub, + indFrame, + mask, + videoNonLocalFlowB, + videoNonLocalFlowF, + video, + consistencyThres): + + imgH, imgW, _, _, nFrame = videoNonLocalFlowF.shape + KeySourceFrame = [0, nFrame // 2, nFrame - 1] + + # Bool indicator of missing pixels at frame t + holepixPosInd = (sub[:, 2] == indFrame) + + # Hole pixel location at frame t, i.e. [x, y, t] + holepixPos = sub[holepixPosInd, :] + + HaveKeySourceFrameFlowNN = np.zeros((imgH, imgW, 3)) + imgKeySourceFrameFlowNN = np.zeros((imgH, imgW, 3, 3)) + + for KeySourceFrameIdx in range(3): + + # flowF_neighbor + flowF_neighbor = copy.deepcopy(holepixPos) + flowF_neighbor = flowF_neighbor.astype(np.float32) + flowF_vertical = videoNonLocalFlowF[:, :, 1, KeySourceFrameIdx, indFrame] + flowF_horizont = videoNonLocalFlowF[:, :, 0, KeySourceFrameIdx, indFrame] + flowB_vertical = videoNonLocalFlowB[:, :, 1, KeySourceFrameIdx, indFrame] + flowB_horizont = videoNonLocalFlowB[:, :, 0, KeySourceFrameIdx, indFrame] + + flowF_neighbor[:, 0] += flowF_vertical[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 1] += flowF_horizont[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 2] = KeySourceFrame[KeySourceFrameIdx] + + # Round the forward flow neighbor location + flow_neighbor_int = np.round(copy.deepcopy(flowF_neighbor)).astype(np.int32) + + # Check the forawrd/backward consistency + IsConsist, _ = FBconsistCheck(flowF_neighbor, flowB_vertical, + flowB_horizont, holepixPos, consistencyThres) + + # Check out-of-boundary + ValidPos = np.logical_and( + np.logical_and(flow_neighbor_int[:, 0] >= 0, + flow_neighbor_int[:, 0] < imgH), + np.logical_and(flow_neighbor_int[:, 1] >= 0, + flow_neighbor_int[:, 1] < imgW)) + + holepixPos_ = copy.deepcopy(holepixPos)[ValidPos, :] + flow_neighbor_int = flow_neighbor_int[ValidPos, :] + flowF_neighbor = flowF_neighbor[ValidPos, :] + IsConsist = IsConsist[ValidPos] + + KnownInd = mask[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + KeySourceFrame[KeySourceFrameIdx]] == 0 + + KnownInd = np.logical_and(KnownInd, IsConsist) + + imgKeySourceFrameFlowNN[:, :, :, KeySourceFrameIdx] = \ + copy.deepcopy(video[:, :, :, indFrame]) + + imgKeySourceFrameFlowNN[holepixPos_[KnownInd, 0], + holepixPos_[KnownInd, 1], + :, KeySourceFrameIdx] = \ + interp(video[:, :, :, KeySourceFrame[KeySourceFrameIdx]], + flowF_neighbor[KnownInd, 1].reshape(-1), + flowF_neighbor[KnownInd, 0].reshape(-1)) + + HaveKeySourceFrameFlowNN[holepixPos_[KnownInd, 0], + 
holepixPos_[KnownInd, 1], + KeySourceFrameIdx] = 1 + + return HaveKeySourceFrameFlowNN, imgKeySourceFrameFlowNN +# +def get_KeySourceFrame_flowNN_gradient(sub, + indFrame, + mask, + videoNonLocalFlowB, + videoNonLocalFlowF, + gradient_x, + gradient_y, + consistencyThres): + + imgH, imgW, _, _, nFrame = videoNonLocalFlowF.shape + KeySourceFrame = [0, nFrame // 2, nFrame - 1] + + # Bool indicator of missing pixels at frame t + holepixPosInd = (sub[:, 2] == indFrame) + + # Hole pixel location at frame t, i.e. [x, y, t] + holepixPos = sub[holepixPosInd, :] + + HaveKeySourceFrameFlowNN = np.zeros((imgH, imgW, 3)) + gradient_x_KeySourceFrameFlowNN = np.zeros((imgH, imgW, 3, 3)) + gradient_y_KeySourceFrameFlowNN = np.zeros((imgH, imgW, 3, 3)) + + for KeySourceFrameIdx in range(3): + + # flowF_neighbor + flowF_neighbor = copy.deepcopy(holepixPos) + flowF_neighbor = flowF_neighbor.astype(np.float32) + + flowF_vertical = videoNonLocalFlowF[:, :, 1, KeySourceFrameIdx, indFrame] + flowF_horizont = videoNonLocalFlowF[:, :, 0, KeySourceFrameIdx, indFrame] + flowB_vertical = videoNonLocalFlowB[:, :, 1, KeySourceFrameIdx, indFrame] + flowB_horizont = videoNonLocalFlowB[:, :, 0, KeySourceFrameIdx, indFrame] + + flowF_neighbor[:, 0] += flowF_vertical[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 1] += flowF_horizont[holepixPos[:, 0], holepixPos[:, 1]] + flowF_neighbor[:, 2] = KeySourceFrame[KeySourceFrameIdx] + + # Round the forward flow neighbor location + flow_neighbor_int = np.round(copy.deepcopy(flowF_neighbor)).astype(np.int32) + + # Check the forawrd/backward consistency + IsConsist, _ = FBconsistCheck(flowF_neighbor, flowB_vertical, + flowB_horizont, holepixPos, consistencyThres) + + # Check out-of-boundary + ValidPos = np.logical_and( + np.logical_and(flow_neighbor_int[:, 0] >= 0, + flow_neighbor_int[:, 0] < imgH - 1), + np.logical_and(flow_neighbor_int[:, 1] >= 0, + flow_neighbor_int[:, 1] < imgW - 1)) + + holepixPos_ = copy.deepcopy(holepixPos)[ValidPos, :] + flow_neighbor_int = flow_neighbor_int[ValidPos, :] + flowF_neighbor = flowF_neighbor[ValidPos, :] + IsConsist = IsConsist[ValidPos] + + KnownInd = mask[flow_neighbor_int[:, 0], + flow_neighbor_int[:, 1], + KeySourceFrame[KeySourceFrameIdx]] == 0 + + KnownInd = np.logical_and(KnownInd, IsConsist) + + gradient_x_KeySourceFrameFlowNN[:, :, :, KeySourceFrameIdx] = \ + copy.deepcopy(gradient_x[:, :, :, indFrame]) + gradient_y_KeySourceFrameFlowNN[:, :, :, KeySourceFrameIdx] = \ + copy.deepcopy(gradient_y[:, :, :, indFrame]) + + gradient_x_KeySourceFrameFlowNN[holepixPos_[KnownInd, 0], + holepixPos_[KnownInd, 1], + :, KeySourceFrameIdx] = \ + interp(gradient_x[:, :, :, KeySourceFrame[KeySourceFrameIdx]], + flowF_neighbor[KnownInd, 1].reshape(-1), + flowF_neighbor[KnownInd, 0].reshape(-1)) + + gradient_y_KeySourceFrameFlowNN[holepixPos_[KnownInd, 0], + holepixPos_[KnownInd, 1], + :, KeySourceFrameIdx] = \ + interp(gradient_y[:, :, :, KeySourceFrame[KeySourceFrameIdx]], + flowF_neighbor[KnownInd, 1].reshape(-1), + flowF_neighbor[KnownInd, 0].reshape(-1)) + + HaveKeySourceFrameFlowNN[holepixPos_[KnownInd, 0], + holepixPos_[KnownInd, 1], + KeySourceFrameIdx] = 1 + + return HaveKeySourceFrameFlowNN, gradient_x_KeySourceFrameFlowNN, gradient_y_KeySourceFrameFlowNN + +class Progbar(object): + """Displays a progress bar. + + Arguments: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. 
+ verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + """ + + def __init__(self, target, width=25, verbose=1, interval=0.05, + stateful_metrics=None): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules or + 'posix' in sys.modules) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): + """Updates the progress bar. + + Arguments: + current: Index of current step. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. + """ + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [v * (current - self._seen_so_far), + current - self._seen_so_far] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += (current - self._seen_so_far) + else: + self._values[k] = v + self._seen_so_far = current + + now = time.time() + info = ' - %.0fs' % (now - self._start) + if self.verbose == 1: + if (now - self._last_update < self.interval and + self.target is not None and current < self.target): + return + + prev_total_width = self._total_width + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + + if self.target is not None: + numdigits = int(np.floor(np.log10(self.target))) + 1 + barstr = '%%%dd/%d [' % (numdigits, self.target) + bar = barstr % current + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += ('=' * (prog_width - 1)) + if current < self.target: + bar += '>' + else: + bar += '=' + bar += ('.' 
* (self.width - prog_width)) + bar += ']' + else: + bar = '%7d/Unknown' % current + + self._total_width = len(bar) + sys.stdout.write(bar) + + if current: + time_per_unit = (now - self._start) / current + else: + time_per_unit = 0 + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, + eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format + else: + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) + else: + info += ' %.0fus/step' % (time_per_unit * 1e6) + + for k in self._values_order: + info += ' - %s:' % k + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if abs(avg) > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + else: + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) + + if self.target is not None and current >= self.target: + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + elif self.verbose == 2: + if self.target is None or current >= self.target: + for k in self._values_order: + info += ' - %s:' % k + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if avg > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + +class PSNR(nn.Module): + def __init__(self, max_val): + super(PSNR, self).__init__() + + base10 = torch.log(torch.tensor(10.0)) + max_val = torch.tensor(max_val).float() + + self.register_buffer('base10', base10) + self.register_buffer('max_val', 20 * torch.log(max_val) / base10) + + def __call__(self, a, b): + mse = torch.mean((a.float() - b.float()) ** 2) + + if mse == 0: + return torch.tensor(0) + + return self.max_val - 10 * torch.log(mse) / self.base10 +# Get surrounding integer postiion +def IntPos(CurPos): + + x_floor = np.expand_dims(np.floor(CurPos[:, 0]).astype(np.int32), 1) + x_ceil = np.expand_dims(np.ceil(CurPos[:, 0]).astype(np.int32), 1) + y_floor = np.expand_dims(np.floor(CurPos[:, 1]).astype(np.int32), 1) + y_ceil = np.expand_dims(np.ceil(CurPos[:, 1]).astype(np.int32), 1) + Fm = np.expand_dims(np.floor(CurPos[:, 2]).astype(np.int32), 1) + + Pos_tl = np.concatenate((x_floor, y_floor, Fm), 1) + Pos_tr = np.concatenate((x_ceil, y_floor, Fm), 1) + Pos_bl = np.concatenate((x_floor, y_ceil, Fm), 1) + Pos_br = np.concatenate((x_ceil, y_ceil, Fm), 1) + + return Pos_tl, Pos_tr, Pos_bl, Pos_br diff --git a/modules/utils/region_fill.py b/modules/utils/region_fill.py new file mode 100644 index 0000000..e8de05a --- /dev/null +++ b/modules/utils/region_fill.py @@ -0,0 +1,125 @@ +import numpy as np +import cv2 +from scipy import sparse +from scipy.sparse.linalg import spsolve + + +def regionfill(I, mask, factor=1.0): + if np.count_nonzero(mask) == 0: + return I.copy() + resize_mask = cv2.resize( + mask.astype(float), (0, 0), fx=factor, fy=factor) > 0 + resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor) + maskPerimeter = findBoundaryPixels(resize_mask) + regionfillLaplace(resize_I, resize_mask, maskPerimeter) + resize_I = 
cv2.resize(resize_I, (I.shape[1], I.shape[0])) + resize_I[mask == 0] = I[mask == 0] + return resize_I + + +def findBoundaryPixels(mask): + kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) + maskDilated = cv2.dilate(mask.astype(float), kernel) + return (maskDilated > 0) & (mask == 0) + + +def regionfillLaplace(I, mask, maskPerimeter): + height, width = I.shape + rightSide = formRightSide(I, maskPerimeter) + + # Location of mask pixels + maskIdx = np.where(mask) + + # Only keep values for pixels that are in the mask + rightSide = rightSide[maskIdx] + + # Number the mask pixels in a grid matrix + grid = -np.ones((height, width)) + grid[maskIdx] = range(0, maskIdx[0].size) + # Pad with zeros to avoid "index out of bounds" errors in the for loop + grid = padMatrix(grid) + gridIdx = np.where(grid >= 0) + + # Form the connectivity matrix D=sparse(i,j,s) + # Connect each mask pixel to itself + i = np.arange(0, maskIdx[0].size) + j = np.arange(0, maskIdx[0].size) + # The coefficient is the number of neighbors over which we average + numNeighbors = computeNumberOfNeighbors(height, width) + s = numNeighbors[maskIdx] + # Now connect the N,E,S,W neighbors if they exist + for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)): + # Possible neighbors in the current direction + neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]] + # ConDnect mask points to neighbors with -1's + index = (neighbors >= 0) + i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]])) + j = np.concatenate((j, neighbors[index])) + s = np.concatenate((s, -np.ones(np.count_nonzero(index)))) + + D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr() + sol = spsolve(D, rightSide) + I[maskIdx] = sol + return I + + +def formRightSide(I, maskPerimeter): + height, width = I.shape + perimeterValues = np.zeros((height, width)) + perimeterValues[maskPerimeter] = I[maskPerimeter] + rightSide = np.zeros((height, width)) + + rightSide[1:height - 1, 1:width - 1] = ( + perimeterValues[0:height - 2, 1:width - 1] + + perimeterValues[2:height, 1:width - 1] + + perimeterValues[1:height - 1, 0:width - 2] + + perimeterValues[1:height - 1, 2:width]) + + rightSide[1:height - 1, 0] = ( + perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] + + perimeterValues[1:height - 1, 1]) + + rightSide[1:height - 1, width - 1] = ( + perimeterValues[0:height - 2, width - 1] + + perimeterValues[2:height, width - 1] + + perimeterValues[1:height - 1, width - 2]) + + rightSide[0, 1:width - 1] = ( + perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] + + perimeterValues[0, 2:width]) + + rightSide[height - 1, 1:width - 1] = ( + perimeterValues[height - 2, 1:width - 1] + + perimeterValues[height - 1, 0:width - 2] + + perimeterValues[height - 1, 2:width]) + + rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0] + rightSide[0, width - 1] = ( + perimeterValues[0, width - 2] + perimeterValues[1, width - 1]) + rightSide[height - 1, 0] = ( + perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1]) + rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] + + perimeterValues[height - 1, width - 2]) + return rightSide + + +def computeNumberOfNeighbors(height, width): + # Initialize + numNeighbors = np.zeros((height, width)) + # Interior pixels have 4 neighbors + numNeighbors[1:height - 1, 1:width - 1] = 4 + # Border pixels have 3 neighbors + numNeighbors[1:height - 1, (0, width - 1)] = 3 + numNeighbors[(0, height - 1), 1:width - 1] = 3 + # Corner pixels have 2 neighbors + 
numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0, + width - 1)] = 2 + return numNeighbors + + +def padMatrix(grid): + height, width = grid.shape + gridPadded = -np.ones((height + 2, width + 2)) + gridPadded[1:height + 1, 1:width + 1] = grid + gridPadded = gridPadded.astype(grid.dtype) + return gridPadded diff --git a/run_data_extract.py b/run_extractor.py similarity index 94% rename from run_data_extract.py rename to run_extractor.py index 3b62d3a..5412dc2 100644 --- a/run_data_extract.py +++ b/run_extractor.py @@ -2,7 +2,7 @@ import numpy as np import sys import time -from modules import ftn +from modules import image_stack # Functions @@ -22,7 +22,7 @@ def init(): strFormat = "%-13s%-35s%2s" print("=" * 50) print("%s%48s" % ("==", "==")) - print(strFormat % ("==", "Invisible Cloak Project", "==")) + print(strFormat % ("==", "Invisibility Cloak Project", "==")) print(strFormat % ("==", "Project ver. : v1.3", "==")) print(strFormat % ("==", "Python ver. : " + pyVersion, "==")) print(strFormat % ("==", "OpenCV ver. : " + cv2.__version__, "==")) @@ -130,12 +130,12 @@ def main(argv): imgBlank = np.zeros((100, 100), np.uint8) imgResult = cv2.addWeighted(res1, 1, res2, 1, 0) - imgStack = ftn.stackImages(0.7, ([img, imgHSV, mask], [res1, res2, imgResult])) + imgStack = image_stack.stackImages(0.7, ([img, imgHSV, mask], [res1, res2, imgResult])) # Only display the image if it is not empty if ret: - frame, prevTime, strFPS = ftn.videoText(imgStack, frame, baseTime, prevTime, strRun, strFPS) + frame, prevTime, strFPS = image_stack.videoText(imgStack, frame, baseTime, prevTime, strRun, strFPS) cv2.imshow("Result", imgStack) out.write(img) diff --git a/run_video_inpainting_test.py b/run_inpainting.py similarity index 74% rename from run_video_inpainting_test.py rename to run_inpainting.py index fa79376..c78a883 100644 --- a/run_video_inpainting_test.py +++ b/run_inpainting.py @@ -16,7 +16,7 @@ from skimage.feature import canny # Custom -import edgeconnect.utils +import modules.EdgeConnect.utils import PIL.ImageOps import shutil import skimage.io @@ -25,23 +25,22 @@ from logging import warning as warn # RAFT -from RAFT import utils -from RAFT import RAFT +from modules.RAFT import utils +from modules.RAFT import RAFT # EdgeConnect -from edgeconnect.networks import EdgeGenerator_ +from modules.EdgeConnect.networks import EdgeGenerator_ -# tool -from tool.get_flowNN import get_flowNN -from tool.get_flowNN_gradient import get_flowNN_gradient -from tool.spatial_inpaint import spatial_inpaint -from tool.frame_inpaint import DeepFillv1 +# tools +from modules.frame_inpaint import DeepFillv1 +from modules.get_flowNN_gradient import get_flowNN_gradient +from modules.spatial_inpaint import spatial_inpaint # utils -import utils.region_fill as rf -from utils.Poisson_blend import Poisson_blend -from utils.Poisson_blend_img import Poisson_blend_img -from utils.common_utils import flow_edge +from modules.utils.common_utils import flow_edge +from modules.utils.Poisson_blend import Poisson_blend +from modules.utils.Poisson_blend_img import Poisson_blend_img +import modules.utils.region_fill as rf # Custom root warnings. 
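The hunk above relocates the fill and blend helpers under `modules.utils`. As a minimal sanity check of the new import path (synthetic data and toy sizes, run from the repository root; not part of the pipeline), the sketch below fills a rectangular hole in a grayscale ramp with `regionfill`, the Laplace-based hole filler added in `modules/utils/region_fill.py`:

```python
import numpy as np

import modules.utils.region_fill as rf  # new location after the refactor

# Horizontal intensity ramp with a rectangular hole (1 marks missing pixels).
img = np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float32), (48, 1))
mask = np.zeros((48, 64), dtype=np.float32)
mask[20:30, 25:40] = 1

# regionfill solves a Laplace equation inside the hole, so the missing
# values are smoothly interpolated from the hole boundary.
filled = rf.regionfill(img, mask)
print(filled.shape, float(np.abs(filled - img)[mask > 0].max()))
```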
@@ -52,6 +51,7 @@ def _precision_warn(p1, p2, extra=""): ) warn(msg.format(p1, p2, extra, p2)) + def silence_imageio_warning(*args, **kwarge): pass @@ -129,6 +129,7 @@ def calculate_flow(args, model, video, mode): print("Loading {0}".format(flow_name), '\r', end='') flow = utils.frame_utils.readFlow(flow_name) Flow = np.concatenate((Flow, flow[..., None]), axis=-1) + return Flow create_dir(os.path.join(args.outroot, '1_flow', mode + '_flo')) @@ -161,7 +162,7 @@ def calculate_flow(args, model, video, mode): flow_img.save(os.path.join(args.outroot, '1_flow', mode + '_png', '%05d.png'%i)) utils.frame_utils.writeFlow(os.path.join(args.outroot, '1_flow', mode + '_flo', '%05d.flo'%i), flow) - # Convert image to gif + # Convert image to gif. flow_gif.append(imageio.imread(os.path.join(args.outroot, '1_flow', mode + '_png', '%05d.png' % i))) # Save gif. @@ -171,7 +172,7 @@ def calculate_flow(args, model, video, mode): def edge_completion(args, EdgeGenerator, corrFlow, flow_mask, mode): - """2. Calculate flow edge and complete it. + """2 ~ 3. Calculate flow edge and complete it. """ if mode not in ['forward', 'backward']: @@ -186,6 +187,7 @@ def edge_completion(args, EdgeGenerator, corrFlow, flow_mask, mode): print("Loading {0}".format(edge_name), '\r', end='') edge = np.load(edge_name) Edge = np.concatenate((Edge, edge[..., None]), axis=-1) + return Edge create_dir(os.path.join(args.outroot, '2_edge_canny', mode + '_png')) @@ -260,7 +262,7 @@ def edge_completion(args, EdgeGenerator, corrFlow, flow_mask, mode): def complete_flow(args, corrFlow, flow_mask, mode, edge=None): - """3. Completes flow. + """4. Completes flow. """ if mode not in ['forward', 'backward']: raise NotImplementedError @@ -270,10 +272,12 @@ def complete_flow(args, corrFlow, flow_mask, mode, edge=None): # If already exist flow_comp, load and return. if os.path.isdir(os.path.join(args.outroot, '4_flow_comp', mode + '_flo')): compFlow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) + for flow_name in sorted(glob.glob(os.path.join(args.outroot, '4_flow_comp', mode + '_flo', '*.flo'))): print("Loading {0}".format(flow_name), '\r', end='') flow = utils.frame_utils.readFlow(flow_name) compFlow = np.concatenate((compFlow, flow[..., None]), axis=-1) + return compFlow create_dir(os.path.join(args.outroot, '4_flow_comp', mode + '_flo')) @@ -317,7 +321,7 @@ def complete_flow(args, corrFlow, flow_mask, mode, edge=None): flow_img = utils.flow_viz.flow_to_image(compFlow[:, :, :, i]) flow_img = Image.fromarray(flow_img) - # Saves the flow and flow_img. + # Save the flow and flow_img. flow_img.save(os.path.join(args.outroot, '4_flow_comp', mode + '_png', '%05d.png'%i)) utils.frame_utils.writeFlow(os.path.join(args.outroot, '4_flow_comp', mode + '_flo', '%05d.flo'%i), compFlow[:, :, :, i]) @@ -330,122 +334,6 @@ def complete_flow(args, corrFlow, flow_mask, mode, edge=None): return compFlow -def video_completion(args): - - # Flow model. - RAFT_model = initialize_RAFT(args) - - # Loads frames. - filename_list = glob.glob(os.path.join(args.path, '*.png')) + \ - glob.glob(os.path.join(args.path, '*.jpg')) - - # Obtains imgH, imgW and nFrame. - imgH, imgW = np.array(Image.open(filename_list[0])).shape[:2] - nFrame = len(filename_list) - print('Image Size : {0} x {1} x {2} frames'.format(imgH, imgW, nFrame)) - - # Loads video. 
- video = [] - for filename in sorted(filename_list): - video.append(torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)[..., :3]).permute(2, 0, 1).float()) - - video = torch.stack(video, dim=0) - video = video.to('cuda') - - # Calcutes the corrupted flow. - corrFlowF = calculate_flow(args, RAFT_model, video, 'forward') - corrFlowB = calculate_flow(args, RAFT_model, video, 'backward') - print('\nFinish flow prediction.') - - # Makes sure video is in BGR (opencv) format. - video = video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255. - - '''Object removal without seamless - ''' - # Loads masks. - filename_list = glob.glob(os.path.join(args.path_mask, '*.png')) + \ - glob.glob(os.path.join(args.path_mask, '*.jpg')) - - mask = [] - flow_mask = [] - for filename in sorted(filename_list): - mask_img = np.array(Image.open(filename).convert('L')) - mask.append(mask_img) - - # Dilate 15 pixel so that all known pixel is trustworthy - flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=15) - # Close the small holes inside the foreground objects - flow_mask_img = cv2.morphologyEx(flow_mask_img.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(np.bool) - flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.bool) - flow_mask.append(flow_mask_img) - - # mask indicating the missing region in the video. - mask = np.stack(mask, -1).astype(np.bool) - flow_mask = np.stack(flow_mask, -1).astype(np.bool) - - if args.edge_guide: - # Edge completion model. - EdgeGenerator = EdgeGenerator_() - EdgeComp_ckpt = torch.load(args.edge_completion_model) - EdgeGenerator.load_state_dict(EdgeComp_ckpt['generator']) - EdgeGenerator.to(torch.device('cuda:0')) - EdgeGenerator.eval() - - # Edge completion. - FlowF_edge = edge_completion(args, EdgeGenerator, corrFlowF, flow_mask, 'forward') - FlowB_edge = edge_completion(args, EdgeGenerator, corrFlowB, flow_mask, 'backward') - print('\nFinish edge completion.') - else: - FlowF_edge, FlowB_edge = None, None - - # Completes the flow. - videoFlowF = complete_flow(args, corrFlowF, flow_mask, 'forward', FlowF_edge) - videoFlowB = complete_flow(args, corrFlowB, flow_mask, 'backward', FlowB_edge) - print('\nFinish flow completion.') - - #return - iter = 0 - mask_tofill = mask - video_comp = video - - # Image inpainting model. - deepfill = DeepFillv1(pretrained_model=args.deepfill_model, image_shape=[imgH, imgW]) - - # We iteratively complete the video. - while(np.sum(mask_tofill) > 0): - create_dir(os.path.join(args.outroot, '5_frame_comp_' + str(iter))) - - # Color propagation. - video_comp, mask_tofill, _ = get_flowNN(args, - video_comp, - mask_tofill, - videoFlowF, - videoFlowB, - None, - None) - - for i in range(nFrame): - mask_tofill[:, :, i] = scipy.ndimage.binary_dilation(mask_tofill[:, :, i], iterations=2) - img = video_comp[:, :, :, i] * 255 - # Green indicates the regions that are not filled yet. 
- img[mask_tofill[:, :, i]] = [0, 255, 0] - cv2.imwrite(os.path.join(args.outroot, '5_frame_comp_' + str(iter), '%05d.png'%i), img) - - # video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1] - # imageio.mimwrite(os.path.join(args.outroot, 'frame_comp_' + str(iter), 'intermediate_{0}.mp4'.format(str(iter))), video_comp_, fps=12, quality=8, macro_block_size=1) - # imageio.mimsave(os.path.join(args.outroot, 'frame_comp_' + str(iter), 'intermediate_{0}.gif'.format(str(iter))), video_comp_, format='gif', fps=12) - mask_tofill, video_comp = spatial_inpaint(deepfill, mask_tofill, video_comp) - iter += 1 - - create_dir(os.path.join(args.outroot, '5_frame_comp_' + 'final')) - video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1] - for i in range(nFrame): - img = video_comp[:, :, :, i] * 255 - cv2.imwrite(os.path.join(args.outroot, '5_frame_comp_' + 'final', '%05d.png'%i), img) - imageio.mimwrite(os.path.join(args.outroot, '5_frame_comp_' + 'final', 'final.mp4'), video_comp_, fps=20, quality=8, macro_block_size=1) - imageio.mimsave(os.path.join(args.outroot, '0_process', '5_frame_comp_final.gif'), video_comp_, format='gif', fps=20) - - def video_completion_seamless(args): # Flow model. @@ -504,20 +392,17 @@ def video_completion_seamless(args): mask_dilated = np.stack(mask_dilated, -1).astype(np.bool) flow_mask = np.stack(flow_mask, -1).astype(np.bool) - if args.edge_guide: - # Edge completion model. - EdgeGenerator = EdgeGenerator_() - EdgeComp_ckpt = torch.load(args.edge_completion_model) - EdgeGenerator.load_state_dict(EdgeComp_ckpt['generator']) - EdgeGenerator.to(torch.device('cuda:0')) - EdgeGenerator.eval() - - # Edge completion. - FlowF_edge = edge_completion(args, EdgeGenerator, corrFlowF, flow_mask, 'forward') - FlowB_edge = edge_completion(args, EdgeGenerator, corrFlowB, flow_mask, 'backward') - print('\nFinish edge completion.') - else: - FlowF_edge, FlowB_edge = None, None + # Edge completion model. + EdgeGenerator = EdgeGenerator_() + EdgeComp_ckpt = torch.load(args.edge_model) + EdgeGenerator.load_state_dict(EdgeComp_ckpt['generator']) + EdgeGenerator.to(torch.device('cuda:0')) + EdgeGenerator.eval() + + # Edge completion. + FlowF_edge = edge_completion(args, EdgeGenerator, corrFlowF, flow_mask, 'forward') + FlowB_edge = edge_completion(args, EdgeGenerator, corrFlowB, flow_mask, 'backward') + print('\nFinish edge completion.') # Completes the flow. videoFlowF = complete_flow(args, corrFlowF, flow_mask, 'forward', FlowF_edge) @@ -546,8 +431,8 @@ def video_completion_seamless(args): else: create_dir(os.path.join(args.outroot, '5_gradient', 'x_npy')) - create_dir(os.path.join(args.outroot, '5_gradient', 'y_npy')) create_dir(os.path.join(args.outroot, '5_gradient', 'x_png')) + create_dir(os.path.join(args.outroot, '5_gradient', 'y_npy')) create_dir(os.path.join(args.outroot, '5_gradient', 'y_png')) grad_x_gif = [] grad_y_gif = [] @@ -579,7 +464,7 @@ def video_completion_seamless(args): # print("grad_x shape: {}, dimension: {}".format(grad_x.shape, grad_x.ndim)) # print("grad_y shape: {}, dimension: {}".format(grad_y.shape, grad_y.ndim)) - # Conveert image to gif. + # Convert image to gif. grad_x_gif.append(imageio.imread(os.path.join(args.outroot, '5_gradient', 'x_png', '%05d.png' % indFrame))) grad_y_gif.append(imageio.imread(os.path.join(args.outroot, '5_gradient', 'y_png', '%05d.png' % indFrame))) @@ -599,13 +484,11 @@ def video_completion_seamless(args): # Image inpainting model. 
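    # DeepFillv1 is the single-image fallback: every outer iteration below ends
    # with spatial_inpaint(deepfill, ...), so pixels that no flow neighbor could
    # reach are still filled before the next propagation pass.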
deepfill = DeepFillv1(pretrained_model=args.deepfill_model, image_shape=[imgH, imgW]) - # We iteratively complete the video. while(np.sum(mask) > 0): - create_dir(os.path.join(args.outroot, '6_frame_seamless_comp_' + str(iter))) + create_dir(os.path.join(args.outroot, '7_frame_seamless_comp_' + str(iter))) # Gradient propagation. - gradient_x_filled, gradient_y_filled, mask_gradient = \ get_flowNN_gradient(args, gradient_x_filled, @@ -617,34 +500,40 @@ def video_completion_seamless(args): None, None) - create_dir(os.path.join(args.outroot, '6_gradient', 'x_png')) - create_dir(os.path.join(args.outroot, '6_gradient', 'y_png')) - - for indFrame in range(nFrame): - grad_x_filled = gradient_x_filled[:, :, 0, indFrame] - grad_y_filled = gradient_y_filled[:, :, 0, indFrame] - skimage.io.imsave(os.path.join(args.outroot, '6_gradient', 'x_png', '%05d.png' % indFrame), grad_x_filled) - skimage.io.imsave(os.path.join(args.outroot, '6_gradient', 'y_png', '%05d.png' % indFrame), grad_y_filled) - + create_dir(os.path.join(args.outroot, '6_gradient_filled', 'x_png')) + create_dir(os.path.join(args.outroot, '6_gradient_filled', 'y_png')) + grad_x_filled_gif = [] + grad_y_filled_gif = [] # if there exist holes in mask, Poisson blending will fail. So I did this trick. I sacrifice some value. Another solution is to modify Poisson blending. for indFrame in range(nFrame): mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(mask_gradient[:, :, indFrame]).astype(np.bool) + + # Save the gradient filled img. + grad_x_filled = gradient_x_filled[:, :, 0, indFrame] + grad_y_filled = gradient_y_filled[:, :, 0, indFrame] + skimage.io.imsave(os.path.join(args.outroot, '6_gradient_filled', 'x_png', '%05d.png' % indFrame), grad_x_filled) + skimage.io.imsave(os.path.join(args.outroot, '6_gradient_filled', 'y_png', '%05d.png' % indFrame), grad_y_filled) + + # Convert image to gif. + grad_x_filled_gif.append(imageio.imread(os.path.join(args.outroot, '6_gradient_filled', 'x_png', '%05d.png' % indFrame))) + grad_y_filled_gif.append(imageio.imread(os.path.join(args.outroot, '6_gradient_filled', 'y_png', '%05d.png' % indFrame))) + # Save gif. + imageio.mimsave(os.path.join(args.outroot, '0_process', '6_gradient_filled' + 'x.gif'), grad_x_filled_gif, format='gif', fps=20) + imageio.mimsave(os.path.join(args.outroot, '0_process', '6_gradient_filled' + 'y.gif'), grad_y_filled_gif, format='gif', fps=20) + # After one gradient propagation iteration # gradient --> RGB for indFrame in range(nFrame): print("Poisson blending frame {0:3d}".format(indFrame)) if mask[:, :, indFrame].sum() > 0: - ''' try: frameBlend, UnfilledMask = Poisson_blend_img(video_comp[:, :, :, indFrame], gradient_x_filled[:, 0 : imgW - 1, :, indFrame], gradient_y_filled[0 : imgH - 1, :, :, indFrame], mask[:, :, indFrame], mask_gradient[:, :, indFrame]) - UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool) + # UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool) except: frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame] - ''' - frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame] frameBlend = np.clip(frameBlend, 0, 1.0) tmp = cv2.inpaint((frameBlend * 255).astype(np.uint8), UnfilledMask.astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. 
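For reference, the blending call re-enabled in this hunk uses `Poisson_blend_img` from the new `modules/utils/Poisson_blend_img.py` above. Below is a minimal single-frame sketch with synthetic data; the toy sizes are illustrative, and the array shapes mirror the per-frame call in `video_completion_seamless` (x-gradients of width `imgW - 1`, y-gradients of height `imgH - 1`):

```python
import numpy as np

from modules.utils.Poisson_blend_img import Poisson_blend_img

H, W = 48, 64
target = np.random.rand(H, W, 3).astype(np.float32)  # frame containing the hole
gx = np.zeros((H, W - 1, 3), dtype=np.float32)       # filled x-gradients
gy = np.zeros((H - 1, W, 3), dtype=np.float32)       # filled y-gradients
hole = np.zeros((H, W), dtype=np.float32)
hole[20:30, 25:40] = 1                                # 1 = pixel to reconstruct

# With zero gradients the hole becomes a smooth membrane anchored to the
# known boundary values of `target`; real runs pass the propagated gradients.
blended, unfilled = Poisson_blend_img(target, gx, gy, hole)
print(blended.shape, int(unfilled.sum()))
```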
@@ -659,12 +548,12 @@ def video_completion_seamless(args): else: frameBlend_ = video_comp[:, :, :, indFrame] - cv2.imwrite(os.path.join(args.outroot, '6_frame_seamless_comp_' + str(iter), '%05d.png'%indFrame), frameBlend_ * 255.) + cv2.imwrite(os.path.join(args.outroot, '7_frame_seamless_comp_' + str(iter), '%05d.png' % indFrame), frameBlend_ * 255.) video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1] - # imageio.mimwrite(os.path.join(args.outroot, '6_frame_seamless_comp_' + str(iter), 'intermediate_{0}.mp4'.format(str(iter))), video_comp_, fps=20, quality=8, macro_block_size=1) - imageio.mimsave(os.path.join(args.outroot, '6_frame_seamless_comp_' + str(iter), 'intermediate_{0}.gif'.format(str(iter))), video_comp_, format='gif', fps=20) - return + # imageio.mimwrite(os.path.join(args.outroot, '7_frame_seamless_comp_' + str(iter), 'intermediate_{0}.mp4'.format(str(iter))), video_comp_, fps=20, quality=8, macro_block_size=1) + # imageio.mimsave(os.path.join(args.outroot, '7_frame_seamless_comp_' + str(iter), 'intermediate_{0}.gif'.format(str(iter))), video_comp_, format='gif', fps=20) + mask, video_comp = spatial_inpaint(deepfill, mask, video_comp) iter += 1 @@ -678,28 +567,45 @@ def video_completion_seamless(args): gradient_x_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0 gradient_y_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0 - create_dir(os.path.join(args.outroot, '6_frame_seamless_comp_' + 'final')) + create_dir(os.path.join(args.outroot, '7_frame_seamless_comp_' + 'final')) video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1] + for i in range(nFrame): img = video_comp[:, :, :, i] * 255 - cv2.imwrite(os.path.join(args.outroot, '6_frame_seamless_comp_' + 'final', '%05d.png' % i), img) - imageio.mimwrite(os.path.join(args.outroot, '6_frame_seamless_comp_' + 'final', 'final.mp4'), video_comp_, fps=20, quality=8, macro_block_size=1) - imageio.mimsave(os.path.join(args.outroot, '0_process', '6_frame_seamless_comp_final.gif'), video_comp_, format='gif', fps=20) + cv2.imwrite(os.path.join(args.outroot, '7_frame_seamless_comp_' + 'final', '%05d.png' % i), img) + imageio.mimwrite(os.path.join(args.outroot, '7_frame_seamless_comp_' + 'final', 'final.mp4'), video_comp_, fps=20, quality=8, macro_block_size=1) + imageio.mimsave(os.path.join(args.outroot, '0_process', '7_frame_seamless_comp_final.gif'), video_comp_, format='gif', fps=20) def args_list(args): - print("\n") - print("================================================================") - print("== TEST ==") - print("================================================================") - print("\n") + print("=" * 50) + print("=" + " " * 48 + "=") + print("%-5s%-35s%10s" % ("=", "Invisibility Cloak Project", "=")) + print("%-5s%-35s%10s" % ("=", "Updated : 2021.11.27 (Sat.)", "=")) + print("=" + " " * 48 + "=") + print("%-5s%-35s%10s" % ("=", "Project : v1.3", "=")) + print("%-5s%-35s%10s" % ("=", "Python : v3.8.12", "=")) + print("%-5s%-35s%10s" % ("=", "OpenCV : v4.5.4", "=")) + print("%-5s%-35s%10s" % ("=", "PyTorch : v1.6.0", "=")) + print("%-5s%-35s%10s" % ("=", "CUDA : v10.2.89", "=")) + print("%-5s%-35s%10s" % ("=", "Matplotlib : v3.4.3", "=")) + print("%-5s%-35s%10s" % ("=", "Scipy : v1.6.2", "=")) + print("=" + " " * 48 + "=") args_dict = vars(args) for key in args_dict: val = args_dict[key] - print("%s : %s" % (key, val)) - - print("") + + if len(str(key)) > 10: + key = str(key)[:7] + "..." + + if len(str(val)) > 21: + val = str(val)[:18] + "..." 
+ + print("%-5s%-10s%-25s%10s" % ("=", key, " : " + str(val), "=")) + + print("=" + " " * 48 + "=") + print("=" * 50) def main(args): @@ -707,11 +613,11 @@ def main(args): "Accepted modes: 'object_removal', 'video_extrapolation', but input is %s" ) % args.mode + # Custom warnings warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=UserWarning) imageio.core.util._precision_warn = silence_imageio_warning - args.edge_guide = True args_list(args) if args.clean: @@ -719,12 +625,7 @@ def main(args): if args.run: create_dir(os.path.join(args.outroot, '0_process')) - - if args.seamless: - # args.outroot = 'D:/_data/tennis_result_seamless' - video_completion_seamless(args) - else: - video_completion(args) + video_completion_seamless(args) @@ -733,13 +634,11 @@ def main(args): # video completion parser.add_argument('--mode', default='object_removal', help="modes: object_removal / video_extrapolation") - parser.add_argument('--path', default='D:/_data/tennis', help="dataset for evaluation") - parser.add_argument('--path_mask', default='D:/_data/square_mask', help="mask for object removal") - parser.add_argument('--outroot', default='D:/_data/tennis_result', help="output directory") + parser.add_argument('--path', default='./data/color', help="dataset for evaluation") + parser.add_argument('--path_mask', default='./data/mask', help="mask for object removal") + parser.add_argument('--outroot', default='./data/result', help="output directory") # options - parser.add_argument('--seamless', action='store_true', help='Whether operate in the gradient domain') - parser.add_argument('--edge_guide', action='store_true', help='Whether use edge as guidance to complete flow') parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool) parser.add_argument('--alpha', dest='alpha', default=0.1, type=float) parser.add_argument('--consistencyThres', dest='consistencyThres', default=np.inf, type=float, help='flow consistency error threshold') @@ -751,12 +650,12 @@ def main(args): parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') # Edge completion - parser.add_argument('--edge_completion_model', default='./weight/edge_completion.pth', help="restore checkpoint") + parser.add_argument('--edge_model', default='./weight/edge_completion.pth', help="restore checkpoint") # Deepfill parser.add_argument('--deepfill_model', default='./weight/imagenet_deepfill.pth', help="restore checkpoint") - # custom + # Custom parser.add_argument('--run', action='store_true', help='run video completion') parser.add_argument('--merge', action='store_true', help='merge image canny edge and completed edge') parser.add_argument('--clean', action='store_true', help='clear result directory')