diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..36d8041 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +sagemaker_job.py +export_model.py +secrets.env +wandb/* +__pycache__/* \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..37c6482 --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +Implementation + +#### Dataset + +- This model is trained on the CCIHP dataset, which contains 22 class labels. + +Please download the ImageNet-pretrained ResNet-101 weights from [Baidu drive](https://pan.baidu.com/s/1NoxI_JetjSVa7uqgVSKdPw) or [Google drive](https://drive.google.com/open?id=1rzLU-wK6rEorCNJfwrmIu5hY2wRMyKTK) and put them into the dataset folder. + +#### Training + +- Set the necessary arguments and run `train_simplified.py`. + +#### Citation + +@InProceedings{Liu_2022_CVPR, + author = {Liu, Kunliang and Choi, Ouk and Wang, Jianming and Hwang, Wonjun}, + title = {CDGNet: Class Distribution Guided Network for Human Parsing}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {4473-4482} +} + diff --git a/Requirements b/Requirements new file mode 100644 index 0000000..41c9288 --- /dev/null +++ b/Requirements @@ -0,0 +1,8 @@ +Requirements + +PyTorch 1.9.0 +torchvision 0.11.0 +scipy 1.5.2 +cudatoolkit 11.3.1 +tensorboardX 2.2 +Python 3.7 diff --git a/dataset/.DS_Store b/dataset/.DS_Store new file mode 100644 index 0000000..a408936 Binary files /dev/null and b/dataset/.DS_Store differ diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/__pycache__/__init__.cpython-310.pyc b/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..560f7e4 Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/dataset/__pycache__/__init__.cpython-36.pyc b/dataset/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..8e4fa8a Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-36.pyc differ diff --git a/dataset/__pycache__/__init__.cpython-37.pyc b/dataset/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..35f1599 Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-37.pyc differ diff --git a/dataset/__pycache__/data_augmentation.cpython-36.pyc b/dataset/__pycache__/data_augmentation.cpython-36.pyc new file mode 100644 index 0000000..b897daa Binary files /dev/null and b/dataset/__pycache__/data_augmentation.cpython-36.pyc differ diff --git a/dataset/__pycache__/data_augmentation.cpython-37.pyc b/dataset/__pycache__/data_augmentation.cpython-37.pyc new file mode 100644 index 0000000..fdd82f2 Binary files /dev/null and b/dataset/__pycache__/data_augmentation.cpython-37.pyc differ diff --git a/dataset/__pycache__/dataset_LIP.cpython-36.pyc b/dataset/__pycache__/dataset_LIP.cpython-36.pyc new file mode 100644 index 0000000..f177d50 Binary files /dev/null and b/dataset/__pycache__/dataset_LIP.cpython-36.pyc differ diff --git a/dataset/__pycache__/datasets.cpython-310.pyc b/dataset/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000..b76e520 Binary files /dev/null and b/dataset/__pycache__/datasets.cpython-310.pyc differ diff --git a/dataset/__pycache__/datasets.cpython-36.pyc b/dataset/__pycache__/datasets.cpython-36.pyc new file mode 100644 index 0000000..b8540a5 Binary files /dev/null and b/dataset/__pycache__/datasets.cpython-36.pyc differ diff --git
a/dataset/__pycache__/datasets.cpython-37.pyc b/dataset/__pycache__/datasets.cpython-37.pyc new file mode 100644 index 0000000..02b5cf7 Binary files /dev/null and b/dataset/__pycache__/datasets.cpython-37.pyc differ diff --git a/dataset/__pycache__/joint_transformation.cpython-36.pyc b/dataset/__pycache__/joint_transformation.cpython-36.pyc new file mode 100644 index 0000000..deaf4e3 Binary files /dev/null and b/dataset/__pycache__/joint_transformation.cpython-36.pyc differ diff --git a/dataset/__pycache__/joint_transformation.cpython-37.pyc b/dataset/__pycache__/joint_transformation.cpython-37.pyc new file mode 100644 index 0000000..615d6d0 Binary files /dev/null and b/dataset/__pycache__/joint_transformation.cpython-37.pyc differ diff --git a/dataset/__pycache__/target_generation.cpython-310.pyc b/dataset/__pycache__/target_generation.cpython-310.pyc new file mode 100644 index 0000000..cd1e8d5 Binary files /dev/null and b/dataset/__pycache__/target_generation.cpython-310.pyc differ diff --git a/dataset/__pycache__/target_generation.cpython-36.pyc b/dataset/__pycache__/target_generation.cpython-36.pyc new file mode 100644 index 0000000..43051da Binary files /dev/null and b/dataset/__pycache__/target_generation.cpython-36.pyc differ diff --git a/dataset/__pycache__/target_generation.cpython-37.pyc b/dataset/__pycache__/target_generation.cpython-37.pyc new file mode 100644 index 0000000..b49b120 Binary files /dev/null and b/dataset/__pycache__/target_generation.cpython-37.pyc differ diff --git a/dataset/__pycache__/voc.cpython-36.pyc b/dataset/__pycache__/voc.cpython-36.pyc new file mode 100644 index 0000000..aab1dfd Binary files /dev/null and b/dataset/__pycache__/voc.cpython-36.pyc differ diff --git a/dataset/datasets.py b/dataset/datasets.py new file mode 100644 index 0000000..a8551c5 --- /dev/null +++ b/dataset/datasets.py @@ -0,0 +1,259 @@ +import os +import numpy as np +import random +import torch +import cv2 +import json +import sys +sys.path.insert(0, '.') +from torch.utils import data +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt +from dataset.target_generation import generate_edge, generate_hw_gt +from utils.transforms import get_affine_transform +from utils.ImgTransforms import AugmentationBlock, autoaug_imagenet_policies +from utils.utils import decode_parsing + + + +# statisticSeg=[ 30462,7026,21054,2404,1660,23165,1201,8182,2178,16224, +# 455,518,634,24418,18539,20033,4763,4832,8126,8166] +class LIPDataSet(data.Dataset): + def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25, + rotation_factor=30, ignore_label=255, transform=None): + """ + :rtype: + """ + self.root = root + self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0] + self.crop_size = np.asarray(crop_size) + self.ignore_label = ignore_label + self.scale_factor = scale_factor + self.rotation_factor = rotation_factor + self.flip_prob = 0.5 + self.flip_pairs = [[0, 5], [1, 4], [2, 3], [11, 14], [12, 13], [10, 15]] + self.transform = transform + self.dataset = dataset + # self.statSeg = np.array( statisticSeg, dtype ='float') + # self.statSeg = self.statSeg/30462 + + list_path = os.path.join(self.root, self.dataset + '_id.txt') + + self.im_list = [i_id.strip() for i_id in open(list_path)] + # if dataset != 'val': + # im_list_2 = [] + # for i in range(len(self.im_list)): + # if i % 5 ==0: + # im_list_2.append(self.im_list[i]) + # self.im_list = im_list_2 + self.number_samples = len(self.im_list) + 
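+        # NOTE (annotation, usage sketch with placeholder paths): the path joins in __getitem__
+        # imply this layout under `root`:
+        #   <root>/train_id.txt                   one image id per line
+        #   <root>/train_images/<id>.jpg          RGB images
+        #   <root>/train_segmentations/<id>.png   per-pixel class labels
+        # A minimal way to build a loader, mirroring the commented-out test block at the bottom of
+        # this file (normalization as in evaluate.py, batch size as in run.sh; the root is a placeholder):
+        #   transform = transforms.Compose([transforms.ToTensor(),
+        #       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+        #   train_set = LIPDataSet('/path/to/CCIHP', 'train', crop_size=[473, 473], transform=transform)
+        #   loader = DataLoader(train_set, batch_size=8, shuffle=True)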
#================================================================================ + self.augBlock = AugmentationBlock( autoaug_imagenet_policies ) + #================================================================================ + def __len__(self): + return self.number_samples + + def _box2cs(self, box): + x, y, w, h = box[:4] + return self._xywh2cs(x, y, w, h) + + def _xywh2cs(self, x, y, w, h): + center = np.zeros((2), dtype=np.float32) + center[0] = x + w * 0.5 + center[1] = y + h * 0.5 + if w > self.aspect_ratio * h: + h = w * 1.0 / self.aspect_ratio + elif w < self.aspect_ratio * h: + w = h * self.aspect_ratio + scale = np.array([w * 1.0, h * 1.0], dtype=np.float32) + + return center, scale + + def __getitem__(self, index): + # Load training image + im_name = self.im_list[index] + + im_path = os.path.join(self.root, self.dataset + '_images', im_name + '.jpg') + #print(im_path) + parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', im_name + '.png') + + im = cv2.imread(im_path, cv2.IMREAD_COLOR) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + #================================================= + if self.dataset != 'val': + im = self.augBlock( im ) + #================================================= + h, w, _ = im.shape + parsing_anno = np.zeros((h, w), dtype=np.long) + + # Get center and scale + center, s = self._box2cs([0, 0, w - 1, h - 1]) + r = 0 + + if self.dataset != 'test': + parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE) + + if self.dataset == 'train' or self.dataset == 'trainval': + + sf = self.scale_factor + rf = self.rotation_factor + s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) + r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \ + if random.random() <= 0.6 else 0 + + if random.random() <= self.flip_prob: + im = im[:, ::-1, :] + parsing_anno = parsing_anno[:, ::-1] + + center[0] = im.shape[1] - center[0] - 1 + right_idx = [15, 17, 19] + left_idx = [14, 16, 18] + for i in range(0, 3): + right_pos = np.where(parsing_anno == right_idx[i]) + left_pos = np.where(parsing_anno == left_idx[i]) + parsing_anno[right_pos[0], right_pos[1]] = left_idx[i] + parsing_anno[left_pos[0], left_pos[1]] = right_idx[i] + + trans = get_affine_transform(center, s, r, self.crop_size) + input = cv2.warpAffine( + im, + trans, + (int(self.crop_size[1]), int(self.crop_size[0])), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0, 0, 0)) + + if self.transform: + input = self.transform(input) + + meta = { + 'name': im_name, + 'center': center, + 'height': h, + 'width': w, + 'scale': s, + 'rotation': r + } + + if self.dataset != 'train': + return input, meta + else: + label_parsing = cv2.warpAffine( + parsing_anno, + trans, + (int(self.crop_size[1]), int(self.crop_size[0])), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(255)) + + # label_edge = generate_edge(label_parsing) + hgt, wgt, hwgt = generate_hw_gt(label_parsing) + label_parsing = torch.from_numpy(label_parsing) + # label_edge = torch.from_numpy(label_edge) + + return input, label_parsing, hgt,wgt,hwgt, meta + +class LIPDataValSet(data.Dataset): + def __init__(self, root, dataset='val', crop_size=[512, 512], transform=None, flip=False): + self.root = root + self.crop_size = crop_size + self.transform = transform + self.flip = flip + self.dataset = dataset + self.root = root + self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0] + self.crop_size = np.asarray(crop_size) + + list_path = os.path.join(self.root, self.dataset + 
'_id.txt') + val_list = [i_id.strip() for i_id in open(list_path)] + + self.val_list = val_list + self.number_samples = len(self.val_list) + + def __len__(self): + return len(self.val_list) + + def _box2cs(self, box): + x, y, w, h = box[:4] + return self._xywh2cs(x, y, w, h) + + def _xywh2cs(self, x, y, w, h): + center = np.zeros((2), dtype=np.float32) + center[0] = x + w * 0.5 + center[1] = y + h * 0.5 + if w > self.aspect_ratio * h: + h = w * 1.0 / self.aspect_ratio + elif w < self.aspect_ratio * h: + w = h * self.aspect_ratio + scale = np.array([w * 1.0, h * 1.0], dtype=np.float32) + + return center, scale + + def __getitem__(self, index): + val_item = self.val_list[index] + # Load training image + im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg') + im = cv2.imread(im_path, cv2.IMREAD_COLOR) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + h, w, _ = im.shape + # Get person center and scale + person_center, s = self._box2cs([0, 0, w - 1, h - 1]) + r = 0 + trans = get_affine_transform(person_center, s, r, self.crop_size) + input = cv2.warpAffine( + im, + trans, + (int(self.crop_size[1]), int(self.crop_size[0])), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0, 0, 0)) + input = self.transform(input) + flip_input = input.flip(dims=[-1]) + if self.flip: + batch_input_im = torch.stack([input, flip_input]) + else: + batch_input_im = input + + meta = { + 'name': val_item, + 'center': person_center, + 'height': h, + 'width': w, + 'scale': s, + 'rotation': r + } + + return batch_input_im, meta + +''' +root = '/home/vrushank/Spyne/CCIHP' +dataset = 'train' +data1 = LIPDataValSet(root, dataset, crop_size=[512, 512]) +loader = DataLoader(data1, batch_size = 1, shuffle = True) + +for idx, (input, label_parsing, hgt,wgt,hwgt, meta) in enumerate(loader): + + if idx == 0: + + print(input.shape) + print(label_parsing.shape) + + ip = input.squeeze(0).cpu().numpy() + label = decode_parsing(label_parsing, num_classes = 22) + print(type(label)) + label = label[0].data.cpu().numpy() + label = label.transpose((1,2,0)) + #label = cv2.cvtColor(label, cv2.COLOR_GRAY2BGR) + print(ip.shape) + print(label.shape) + res = np.concatenate((ip, label), axis = 1) + plt.imshow(res) + plt.show() + #print(f'{hgt}: {hgt.shape}') + #print(f'{wgt}: {wgt.shape}') + #print(f'{hwgt}: {hwgt.shape}') + + else: + + break +''' \ No newline at end of file diff --git a/dataset/list/.DS_Store b/dataset/list/.DS_Store new file mode 100644 index 0000000..74dc7d8 Binary files /dev/null and b/dataset/list/.DS_Store differ diff --git a/dataset/target_generation.py b/dataset/target_generation.py new file mode 100644 index 0000000..1b82d28 --- /dev/null +++ b/dataset/target_generation.py @@ -0,0 +1,82 @@ +import os +import sys +import numpy as np +import random +import cv2 +import torch +from torch.nn import functional as F + +def generate_hw_gt( target, class_num = 22 ): + h,w = target.shape + target = torch.from_numpy(target) + target_c = target.clone() + target_c[target_c==255]=0 + target_c = target_c.long() + target_c = target_c.view(h*w) + target_c = target_c.unsqueeze(1) + target_onehot = torch.zeros(h*w,class_num) + target_onehot.scatter_( 1, target_c, 1 ) #h*w,class_num + target_onehot = target_onehot.transpose(0,1) + target_onehot = target_onehot.view(class_num,h,w) + # h distribution ground truth + hgt = torch.zeros((class_num,h)) + hgt=( torch.sum( target_onehot, dim=2 ) ).float() + hgt[0,:] = 0 + max = torch.max(hgt,dim=1)[0] #c,1 + min = torch.min(hgt,dim=1)[0] + max = 
max.unsqueeze(1) + min = min.unsqueeze(1) + hgt = hgt / ( max + 1e-5 ) + # w distribution gound truth + wgt = torch.zeros((class_num,w)) + wgt=( torch.sum(target_onehot, dim=1 ) ).float() + wgt[0,:]=0 + max = torch.max(wgt,dim=1)[0] #c,1 + min = torch.min(wgt,dim=1)[0] + max = max.unsqueeze(1) + min = min.unsqueeze(1) + wgt = wgt / ( max + 1e-5 ) + #=========================================================== + hwgt = torch.matmul( hgt.transpose(0,1), wgt ) + max = torch.max( hwgt.view(-1), dim=0 )[0] + # print(max) + hwgt = hwgt / ( max + 1.0e-5 ) + #==================================================================== + return hgt, wgt, hwgt #,cch, ccw gt_hw + +def generate_edge(label, edge_width=3): + label = label.type(torch.cuda.FloatTensor) + if len(label.shape) == 2: + label = label.unsqueeze(0) + n, h, w = label.shape + edge = torch.zeros(label.shape, dtype=torch.float).cuda() + # right + edge_right = edge[:, 1:h, :] + edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255) + & (label[:, :h - 1, :] != 255)] = 1 + + # up + edge_up = edge[:, :, :w - 1] + edge_up[(label[:, :, :w - 1] != label[:, :, 1:w]) + & (label[:, :, :w - 1] != 255) + & (label[:, :, 1:w] != 255)] = 1 + + # upright + edge_upright = edge[:, :h - 1, :w - 1] + edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w]) + & (label[:, :h - 1, :w - 1] != 255) + & (label[:, 1:h, 1:w] != 255)] = 1 + + # bottomright + edge_bottomright = edge[:, :h - 1, 1:w] + edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1]) + & (label[:, :h - 1, 1:w] != 255) + & (label[:, 1:h, :w - 1] != 255)] = 1 + + kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda() + with torch.no_grad(): + edge = edge.unsqueeze(1) + edge = F.conv2d(edge, kernel, stride=1, padding=1) + edge[edge!=0] = 1 + edge = edge.squeeze() + return edge \ No newline at end of file diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..5edee44 --- /dev/null +++ b/engine.py @@ -0,0 +1,133 @@ +import os +import os.path as osp +import time +import argparse + +import torch +import torch.distributed as dist + +from utils.logger import get_logger +from utils.pyt_utils import parse_devices, all_reduce_tensor, extant_file +''' +try: + from apex.parallel import DistributedDataParallel, SyncBatchNorm +except ImportError: + raise ImportError( + "Please install apex from https://www.github.com/nvidia/apex .") +''' + +logger = get_logger() + + +class Engine(object): + def __init__(self, custom_parser=None): + logger.info( + "PyTorch Version {}".format(torch.__version__)) + self.devices = None + self.distributed = False + + if custom_parser is None: + self.parser = argparse.ArgumentParser() + else: + assert isinstance(custom_parser, argparse.ArgumentParser) + self.parser = custom_parser + + self.inject_default_parser() + self.args = self.parser.parse_args() + + self.continue_state_object = self.args.continue_fpath + + # if not self.args.gpu == 'None': + # os.environ["CUDA_VISIBLE_DEVICES"]=self.args.gpu + + if 'WORLD_SIZE' in os.environ: + self.distributed = int(os.environ['WORLD_SIZE']) > 1 + + if self.distributed: + self.local_rank = self.args.local_rank + self.world_size = int(os.environ['WORLD_SIZE']) + torch.cuda.set_device(self.local_rank) + dist.init_process_group(backend="nccl", init_method='env://') + self.devices = [i for i in range(self.world_size)] + else: + gpus = os.environ["CUDA_VISIBLE_DEVICES"] + self.devices = [i for i in range(len(gpus.split(',')))] + + def inject_default_parser(self): + p = 
self.parser + p.add_argument('-d', '--devices', default='', + help='set data parallel training') + p.add_argument('-c', '--continue', type=extant_file, + metavar="FILE", + dest="continue_fpath", + help='continue from one certain checkpoint') + p.add_argument('--local_rank', default=0, type=int, + help='process rank on node') + + def data_parallel(self, model): + if self.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.local_rank], + output_device=self.local_rank,) + else: + model = torch.nn.DataParallel(model) + return model + + def get_train_loader(self, train_dataset): + train_sampler = None + is_shuffle = True + batch_size = self.args.batch_size + + if self.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset) + batch_size = self.args.batch_size // self.world_size + is_shuffle = False + + train_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=batch_size, + num_workers=self.args.num_workers, + drop_last=False, + shuffle=is_shuffle, + pin_memory=True, + sampler=train_sampler) + + return train_loader, train_sampler + + def get_test_loader(self, test_dataset): + test_sampler = None + is_shuffle = False + batch_size = self.args.batch_size + + if self.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset) + batch_size = self.args.batch_size // self.world_size + + test_loader = torch.utils.data.DataLoader(test_dataset, + batch_size=batch_size, + num_workers=self.args.num_workers, + drop_last=False, + shuffle=is_shuffle, + pin_memory=True, + sampler=test_sampler) + + return test_loader, test_sampler + + + def all_reduce_tensor(self, tensor, norm=True): + if self.distributed: + return all_reduce_tensor(tensor, world_size=self.world_size, norm=norm) + else: + return torch.mean(tensor) + + + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + torch.cuda.empty_cache() + if type is not None: + logger.warning( + "An exception occurred during Engine initialization, " + "giving up the running process") + return False diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..29d7f99 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,274 @@ +import argparse +import numpy as np +import torch +torch.multiprocessing.set_start_method("spawn", force=True) +from torch.utils import data +from tqdm import tqdm +from networks.CDGNet import Res_Deeplab +from dataset.datasets import LIPDataSet +import os +import torchvision.transforms as transforms +from utils.miou import compute_mean_ioU +from copy import deepcopy + +from PIL import Image as PILImage + +DATA_DIRECTORY = '/ssd1/liuting14/Dataset/LIP/' +DATA_LIST_PATH = './dataset/list/lip/valList.txt' +IGNORE_LABEL = 255 +NUM_CLASSES = 20 +SNAPSHOT_DIR = './snapshots/' +INPUT_SIZE = '473,473' + +# colour map +COLORS = [(0,0,0) + # 0=background + ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) + # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle + ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) + # 6=bus, 7=car, 8=cat, 9=chair, 10=cow + ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) + # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person + ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] + # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor + +def get_ccihp_pallete(): + + palette = [ + 120, 120, 120, + 127, 0, 0, + 254, 0, 0, + 0, 84, 0, + 169, 0, 50, + 254, 84, 0, + 255, 0, 84, + 0, 118, 220, + 84, 84, 0, + 0, 84, 84, + 84, 50, 0, + 51, 85, 127, +
0, 127, 0, + 0, 0, 254, + 50, 169, 220, + 0, 254, 254, + 84, 254, 169, + 169, 254, 84, + 254, 254, 0, + 254, 169, 0, + 102, 254, 0, + 182, 255, 0 + + ] + + return palette + +def get_lip_palette(): + palette = [0,0,0, + 128,0,0, + 255,0,0, + 0,85,0, + 170,0,51, + 255,85,0, + 0,0,85, + 0,119,221, + 85,85,0, + 0,85,85, + 85,51,0, + 52,86,128, + 0,128,0, + 0,0,255, + 51,170,221, + 0,255,255, + 85,255,170, + 170,255,85, + 255,255,0, + 255,170,0] + return palette +def get_palette(num_cls): + """ Returns the color map for visualizing the segmentation mask. + + Inputs: + =num_cls= + Number of classes. + + Returns: + The color map. + """ + n = num_cls + palette = [0] * (n * 3) + for j in range(0, n): + lab = j + palette[j * 3 + 0] = 0 + palette[j * 3 + 1] = 0 + palette[j * 3 + 2] = 0 + i = 0 + while lab: + palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) + palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) + palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) + i += 1 + lab >>= 3 + return palette + +def get_arguments(): + """Parse all the arguments provided from the CLI. + + Returns: + A list of parsed arguments. + """ + parser = argparse.ArgumentParser(description="CE2P Network") + parser.add_argument("--batch-size", type=int, default=1, + help="Number of images sent to the network in one step.") + parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, + help="Path to the directory containing the PASCAL VOC dataset.") + parser.add_argument("--dataset", type=str, default='val', + help="Path to the file listing the images in the dataset.") + parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, + help="The index of the label to ignore during the training.") + parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, + help="Number of classes to predict (including background).") + parser.add_argument("--restore-from", type=str, + help="Where restore model parameters from.") + parser.add_argument("--gpu", type=str, default='0', + help="choose gpu device.") + parser.add_argument("--input-size", type=str, default=INPUT_SIZE, + help="Comma-separated string with height and width of images.") + + return parser.parse_args() + +def valid(model, valloader, input_size, num_samples, gpus): + + model.eval() + parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]), + dtype=np.uint8) + + scales = np.zeros((num_samples, 2), dtype=np.float32) + centers = np.zeros((num_samples, 2), dtype=np.int32) + + idx = 0 + interp = torch.nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) + with torch.no_grad(): + loop = tqdm(valloader, position = 0, leave = True) + for index, batch in enumerate(valloader): + image, meta = batch + num_images = image.size(0) + if index % 100 == 0: + print('%d processed' % (index * num_images)) + + c = meta['center'].numpy() + s = meta['scale'].numpy() + scales[idx:idx + num_images, :] = s[:, :] + centers[idx:idx + num_images, :] = c[:, :] + #==================================================================================== + org_img = image.numpy() + normal_img = org_img + flipped_img = org_img[:,:,:,::-1] + fused_img = np.concatenate( (normal_img,flipped_img), axis=0 ) + with torch.no_grad(): + outputs = model( torch.from_numpy(fused_img).cuda()) + prediction = interp( outputs[0][-1].cpu()).data.numpy().transpose(0, 2, 3, 1) #N,H,W,C + single_out = prediction[:num_images,:,:,:] + single_out_flip = np.zeros( single_out.shape ) + single_out_tmp = prediction[num_images:, :,:,:] + for c in range(14): + 
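+                # NOTE (annotation): channels 0-13 have no left/right counterpart and are copied
+                # unchanged by this loop; the assignments after it swap the paired left/right
+                # channels (14<->15, 16<->17, 18<->19 in the 20-class LIP label layout) and then
+                # mirror the map horizontally, so the flipped prediction can be averaged with the
+                # unflipped one.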
single_out_flip[:,:, :, c] = single_out_tmp[:, :, :, c] + single_out_flip[:, :, :, 14] = single_out_tmp[:, :, :, 15] + single_out_flip[:, :, :, 15] = single_out_tmp[:, :, :, 14] + single_out_flip[:, :, :, 16] = single_out_tmp[:, :, :, 17] + single_out_flip[:, :, :, 17] = single_out_tmp[:, :, :, 16] + single_out_flip[:, :, :, 18] = single_out_tmp[:, :, :, 19] + single_out_flip[:, :, :, 19] = single_out_tmp[:, :, :, 18] + single_out_flip = single_out_flip[:, :, ::-1, :] + # Fuse two outputs + single_out = ( single_out+single_out_flip ) / 2 + parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(single_out, axis=3), dtype=np.uint8) + #==================================================================================== + # outputs = model(image.cuda()) + # if gpus > 1: + # for output in outputs: + # parsing = output[0][-1] + # nums = len(parsing) + # parsing = interp(parsing).data.cpu().numpy() + # parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC + # parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) + + # idx += nums + # else: + # parsing = outputs[0][-1] + # parsing = interp(parsing).data.cpu().numpy() + # parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC + # parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) + + idx += num_images + + parsing_preds = parsing_preds[:num_samples, :, :] + print(f'Parsing preds: {parsing_preds.shape}') + + return parsing_preds, scales, centers + +def main(): + """Create the model and start the evaluation process.""" + args = get_arguments() + + os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu + gpus = [int(i) for i in args.gpu.split(',')] + + h, w = map(int, args.input_size.split(',')) + + input_size = (h, w) + + model = Res_Deeplab(num_classes=args.num_classes) + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + transform = transforms.Compose([ + transforms.ToTensor(), + normalize, + ]) + + lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform) + num_samples = len(lip_dataset) + + valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), + shuffle=False, pin_memory=True) + + restore_from = args.restore_from + + state_dict = model.state_dict().copy() + state_dict_old = torch.load(restore_from) + + for key, nkey in zip(state_dict_old.keys(), state_dict.keys()): + if key != nkey: + # remove the 'module.' 
in the 'key' + state_dict[key[7:]] = deepcopy(state_dict_old[key]) + else: + state_dict[key] = deepcopy(state_dict_old[key]) + + model.load_state_dict(state_dict) + + model.eval() + model.cuda() + + parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, len(gpus)) + + #================================================================= + # list_path = os.path.join(args.data_dir, args.dataset + '_id.txt') + # val_id = [i_id.strip() for i_id in open(list_path)] + # pred_root = os.path.join( args.data_dir, 'pred_parsing') + # if not os.path.exists( pred_root ): + # os.makedirs( pred_root ) + # palette = get_lip_palette() + # output_parsing = parsing_preds + # for i in range( num_samples ): + # output_image = PILImage.fromarray( output_parsing[i] ) + # output_image.putpalette( palette ) + # output_image.save( os.path.join( pred_root, str(val_id[i])+'.png')) + #================================================================= + + mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) + + print(mIoU) + +if __name__ == '__main__': + main() diff --git a/evaluate_multi.py b/evaluate_multi.py new file mode 100644 index 0000000..956baa0 --- /dev/null +++ b/evaluate_multi.py @@ -0,0 +1,252 @@ +import argparse +import numpy as np +import torch +torch.multiprocessing.set_start_method("spawn", force=True) +from torch.utils import data +from networks.CDGNet import Res_Deeplab +from dataset.datasets import LIPDataValSet +import os +import torchvision.transforms as transforms +from utils.miou import compute_mean_ioU +from copy import deepcopy +import cv2 + +from PIL import Image as PILImage + +DATA_DIRECTORY = '/ssd1/liuting14/Dataset/LIP/' +DATA_LIST_PATH = './dataset/list/lip/valList.txt' +IGNORE_LABEL = 255 +NUM_CLASSES = 20 +SNAPSHOT_DIR = './snapshots/' +INPUT_SIZE = (473,473) + +# colour map +COLORS = [(0,0,0) + # 0=background + ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) + # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle + ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) + # 6=bus, 7=car, 8=cat, 9=chair, 10=cow + ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) + # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person + ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] + # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor +def get_lip_palette(): + palette = [0,0,0, + 128,0,0, + 255,0,0, + 0,85,0, + 170,0,51, + 255,85,0, + 0,0,85, + 0,119,221, + 85,85,0, + 0,85,85, + 85,51,0, + 52,86,128, + 0,128,0, + 0,0,255, + 51,170,221, + 0,255,255, + 85,255,170, + 170,255,85, + 255,255,0, + 255,170,0] + return palette +def get_palette(num_cls): + """ Returns the color map for visualizing the segmentation mask. + + Inputs: + =num_cls= + Number of classes. + + Returns: + The color map. + """ + n = num_cls + palette = [0] * (n * 3) + for j in range(0, n): + lab = j + palette[j * 3 + 0] = 0 + palette[j * 3 + 1] = 0 + palette[j * 3 + 2] = 0 + i = 0 + while lab: + palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) + palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) + palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) + i += 1 + lab >>= 3 + return palette + +def get_arguments(): + """Parse all the arguments provided from the CLI. + + Returns: + A list of parsed arguments. 
+ """ + parser = argparse.ArgumentParser(description="CE2P Network") + parser.add_argument("--batch-size", type=int, default=1, + help="Number of images sent to the network in one step.") + parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, + help="Path to the directory containing the PASCAL VOC dataset.") + parser.add_argument("--dataset", type=str, default='val', + help="Path to the file listing the images in the dataset.") + parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, + help="The index of the label to ignore during the training.") + parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, + help="Number of classes to predict (including background).") + parser.add_argument("--restore-from", type=str, + help="Where restore model parameters from.") + parser.add_argument("--gpu", type=str, default='0', + help="choose gpu device.") + parser.add_argument("--input-size", type=str, default=INPUT_SIZE, + help="Comma-separated string with height and width of images.") + + return parser.parse_args() + +# def scale_image(image, scale): +# image = image[0, :, :, :] +# image = image.transpose((1, 2, 0)) +# image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) +# image = image.transpose((2, 0, 1)) +# return image + +def valid(model, valloader, input_size, num_samples, gpus): + model.eval() + + parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]), + dtype=np.uint8) + + scales = np.zeros((num_samples, 2), dtype=np.float32) + centers = np.zeros((num_samples, 2), dtype=np.int32) + + hpreds_lst = [] + wpreds_lst = [] + + idx = 0 + interp = torch.nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) + eval_scale=[0.75,1.0,1.25] + # eval_scale=[1.0] + flipped_idx = (15, 14, 17, 16, 19, 18) + with torch.no_grad(): + for index, batch in enumerate(valloader): + image, meta = batch + # num_images = image.size(0) + # print( image.size() ) + image = image.squeeze() + if index % 10 == 0: + print('%d processd' % (index * 1)) + c = meta['center'].numpy()[0] + s = meta['scale'].numpy()[0] + scales[idx, :] = s + centers[idx, :] = c + #==================================================================================== + mul_outputs = [] + for scale in eval_scale: + interp_img = torch.nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True) + scaled_img = interp_img( image ) + # print( scaled_img.size() ) + outputs = model( scaled_img.cuda() ) + prediction = outputs[0][-1] + #========================================================== + hPreds = outputs[2][0] + wPreds = outputs[2][1] + hpreds_lst.append( hPreds[0].data.cpu().numpy() ) + wpreds_lst.append( wPreds[0].data.cpu().numpy() ) + #========================================================== + single_output = prediction[0] + flipped_output = prediction[1] + flipped_output[14:20,:,:]=flipped_output[flipped_idx,:,:] + single_output += flipped_output.flip(dims=[-1]) + single_output *=0.5 + # print( single_output.size() ) + single_output = interp( single_output.unsqueeze(0) ) + mul_outputs.append( single_output[0] ) + fused_prediction = torch.stack( mul_outputs ) + fused_prediction = fused_prediction.mean(0) + fused_prediction = fused_prediction.permute(1, 2, 0) # HWC + fused_prediction = torch.argmax(fused_prediction, dim=2) + fused_prediction = fused_prediction.data.cpu().numpy() + parsing_preds[idx, :, :] = np.asarray(fused_prediction, dtype=np.uint8) + 
#==================================================================================== + idx += 1 + + parsing_preds = parsing_preds[:num_samples, :, :] + + + return parsing_preds, scales, centers, hpreds_lst, wpreds_lst + +def main(): + """Create the model and start the evaluation process.""" + args = get_arguments() + + os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu + gpus = [int(i) for i in args.gpu.split(',')] + + h, w = map(int, args.input_size.split(',')) + + input_size = (h, w) + + model = Res_Deeplab(num_classes=args.num_classes) + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + transform = transforms.Compose([ + transforms.ToTensor(), + normalize, + ]) + + lip_dataset = LIPDataValSet(args.data_dir, 'val', crop_size=input_size, transform=transform, flip = True ) + num_samples = len(lip_dataset) + + valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), + shuffle=False, pin_memory=True) + + restore_from = args.restore_from + + state_dict = model.state_dict().copy() + state_dict_old = torch.load(restore_from) + + for key, nkey in zip(state_dict_old.keys(), state_dict.keys()): + if key != nkey: + # remove the 'module.' in the 'key' + state_dict[key[7:]] = deepcopy(state_dict_old[key]) + else: + state_dict[key] = deepcopy(state_dict_old[key]) + + model.load_state_dict(state_dict) + + model.eval() + model.cuda() + + parsing_preds, scales, centers, hpredLst, wpredLst = valid(model, valloader, input_size, num_samples, len(gpus)) + + #================================================================= + # list_path = os.path.join(args.data_dir, args.dataset + '_id.txt') + # val_id = [i_id.strip() for i_id in open(list_path)] + # # pred_root = os.path.join( args.data_dir, 'pred_parsing') + # pred_root = os.path.join( os.getcwd(), 'pred_parsing') + # print( pred_root ) + # if not os.path.exists( pred_root ): + # os.makedirs( pred_root ) + # palette = get_lip_palette() + # output_parsing = parsing_preds + # for i in range( num_samples ): + # output_image = PILImage.fromarray( output_parsing[i] ) + # output_image.putpalette( palette ) + # output_image.save( os.path.join( pred_root, str(val_id[i])+'.png')) + # i=0 + # for i in range(len( hpredLst )): + # filenameh = os.path.join( pred_root, str( val_id[i] ) + "_h" ) + # np.save( filenameh, hpredLst[i] ) + # filenamew = os.path.join( pred_root, str( val_id[i] ) + "_w" ) + # np.save( filenamew, wpredLst[i] ) + #================================================================= + mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) + + print(mIoU) + +if __name__ == '__main__': + main() diff --git a/networks/.DS_Store b/networks/.DS_Store new file mode 100644 index 0000000..820d82b Binary files /dev/null and b/networks/.DS_Store differ diff --git a/networks/CDGNet.py b/networks/CDGNet.py new file mode 100644 index 0000000..006375a --- /dev/null +++ b/networks/CDGNet.py @@ -0,0 +1,334 @@ +import torch.nn as nn +import sys +sys.path.insert(0, '.') +from torch.nn import functional as F +import math +import torch.utils.model_zoo as model_zoo +import torch +import numpy as np +from torch.autograd import Variable +affine_par = True +import functools + +import sys, os +from utils.attention import CDGAttention, C2CAttention +from torch.nn import BatchNorm2d as BatchNorm2d + +def InPlaceABNSync(in_channel): + layers = [ + BatchNorm2d(in_channel), + nn.ReLU(), + ] + return nn.Sequential(*layers) + +def conv3x3(in_planes, out_planes, stride=1): + 
"3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class Bottleneck(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=False) + self.relu_inplace = nn.ReLU(inplace=True) + self.downsample = downsample + self.dilation = dilation + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu_inplace(out) + + return out + +class ASPPModule(nn.Module): + """ + Reference: + Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."* + """ + def __init__(self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)): + super(ASPPModule, self).__init__() + + self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), + nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(inner_features)) + self.conv2 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(inner_features)) + self.conv3 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False), + InPlaceABNSync(inner_features)) + self.conv4 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False), + InPlaceABNSync(inner_features)) + self.conv5 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False), + InPlaceABNSync(inner_features)) + + self.bottleneck = nn.Sequential( + nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(out_features), + nn.Dropout2d(0.1) + ) + + def forward(self, x): + + _, _, h, w = x.size() + + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1) + + bottle = self.bottleneck(out) + return bottle + +class Edge_Module(nn.Module): + + def __init__(self,in_fea=[256,512,1024], mid_fea=256, out_fea=2): + super(Edge_Module, self).__init__() + + self.conv1 = nn.Sequential( + nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(mid_fea) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(mid_fea) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(mid_fea) + ) + self.conv4 = nn.Conv2d(mid_fea,out_fea, kernel_size=3, padding=1, 
dilation=1, bias=True) + self.conv5 = nn.Conv2d(out_fea*3,out_fea, kernel_size=1, padding=0, dilation=1, bias=True) + + def forward(self, x1, x2, x3): + _, _, h, w = x1.size() + + edge1_fea = self.conv1(x1) + edge1 = self.conv4(edge1_fea) + edge2_fea = self.conv2(x2) + edge2 = self.conv4(edge2_fea) + edge3_fea = self.conv3(x3) + edge3 = self.conv4(edge3_fea) + + edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear',align_corners=True) + edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear',align_corners=True) + edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear',align_corners=True) + edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear',align_corners=True) + + edge = torch.cat([edge1, edge2, edge3], dim=1) + edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1) + edge = self.conv5(edge) + return edge, edge_fea + +class PSPModule(nn.Module): + """ + Reference: + Zhao, Hengshuang, et al. *"Pyramid scene parsing network."* + """ + def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)): + super(PSPModule, self).__init__() + + self.stages = [] + self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes]) + self.bottleneck = nn.Sequential( + nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=3, padding=1, dilation=1, bias=False), + InPlaceABNSync(out_features), + ) + + def _make_stage(self, features, out_features, size): + prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) + conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False) + bn = InPlaceABNSync(out_features) + return nn.Sequential(prior, conv, bn) + + def forward(self, feats): + h, w = feats.size(2), feats.size(3) + priors = [ F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] + bottle = self.bottleneck(torch.cat(priors, 1)) + return bottle + +class Decoder_Module(nn.Module): + + def __init__(self, num_classes): + super(Decoder_Module, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(256) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(256, 48, kernel_size=3, stride=1, padding=1, dilation=1, bias=False), + InPlaceABNSync(48) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(304, 256, kernel_size=3, padding=1, dilation=1, bias=False), + InPlaceABNSync(256), + nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(256) + ) + self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) + #========================================================================================= + self.addCAM = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, padding=1, dilation=1, bias=False), + InPlaceABNSync(256), + ) + #======================================================================================= + def PCM(self, cam, f): + n,c,h,w = f.size() + cam = F.interpolate(cam, (h,w), mode='bilinear', align_corners=True).view(n,-1,h*w) + f = f.view(n,-1,h*w) + aff = torch.matmul(f.transpose(1,2), f) + aff = ( c ** -0.5 ) * aff + aff = F.softmax( aff, dim = -1 ) # accuracy was 57.28 without the transpose and about 56.64 with it + cam_rv = torch.matmul(cam, aff).view(n,-1,h,w) + return cam_rv + def forward(self, xt, xl, xPCM = None ): + _, _, h, w = xl.size() + xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True) + xl = self.conv2(xl) + x = torch.cat([xt, xl], dim=1) + x = self.conv3(x) + with torch.no_grad(): + xM = 
F.relu( x.detach() ) + xPCM = F.interpolate( self.PCM( xM, xPCM ), size=(h, w), mode='bilinear', align_corners=True ) + x = torch.cat( [x, xPCM ], dim = 1 ) + x = self.addCAM( x ) + seg = self.conv4(x) + return seg,x + +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=False) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=False) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=False) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1,1,1)) + + self.layer5 = PSPModule(2048,512) + + self.edge_layer = Edge_Module() + self.layer6 = Decoder_Module(num_classes) + + self.layer7 = nn.Sequential( + nn.Conv2d(1024, 256, kernel_size=3, padding=1, dilation=1, bias=False), + InPlaceABNSync(256), + nn.Dropout2d(0.1), + nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) + ) + #=================================================================================== + self.sq4 = nn.Sequential( + nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(256) + ) + self.sq5 = nn.Sequential( + nn.Conv2d( 256, 256, kernel_size=1, padding=0, dilation=1, bias=False ), + InPlaceABNSync(256) + ) + self.f9 = nn.Sequential( + nn.Conv2d(256+256+3, 256, kernel_size=1, padding=0, dilation=1, bias=False), + InPlaceABNSync(256) + ) + #=============================================================== + self.hwAttention = CDGAttention(512, 256, num_classes, [473//4,473//4], 7 ) + self.L = nn.Conv2d(1024, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) + #================================================================ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion)) + + layers = [] + generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1 + layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) + + return nn.Sequential(*layers) + + def forward(self, x): + x_org = x + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + x2 = self.layer1(x) #1/4 256 + x3 = self.layer2(x2) #1/8 512 + x4 = self.layer3(x3) #1/16 1024 + seg0 = self.L(x4) + x5 = self.layer4(x4) #1/16 2048 + x = self.layer5(x5) #1/16 512 + 
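+        # NOTE (annotation): x is now the PSP-fused 512-channel feature map at 1/16 resolution.
+        # The rest of forward() produces the structure returned below,
+        #   [[seg0, seg1, seg2], [edge], [fea_h1, fea_w1]],
+        # where seg2 (outputs[0][-1]) is the final parsing prediction used in evaluate.py, edge is
+        # the 2-channel edge map from Edge_Module, and fea_h1/fea_w1 are the per-class height/width
+        # distributions from CDGAttention, corresponding to the hgt/wgt targets generated by
+        # generate_hw_gt in dataset/target_generation.py.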
#============================================================== + x,fea_h1, fea_w1 = self.hwAttention(x) + #============================================================== + edge,edge_fea = self.edge_layer(x2,x3,x4) + #============================================================== + n, c, h, w = x4.size() + fr1 = self.sq5( x2 ) + fr1 = F.interpolate( fr1, (h,w), mode='bilinear', align_corners= True ) + fr1 = F.relu( fr1, inplace= True ) + fr2 = self.sq4( x4 ) + fr2 = F.interpolate( fr2, (h,w), mode='bilinear', align_corners= True ) + fr2 = F.relu( fr2, inplace= True ) + frOrg = F.interpolate( x_org,(h,w), mode='bilinear', align_corners=True ) + fCat = torch.cat([frOrg, fr1, fr2 ], dim = 1) + fCat = self.f9( fCat ) + #============================================================== + seg1,x = self.layer6( x,x2, fCat ) + #============================================================= + x = torch.cat([x, edge_fea], dim=1) + seg2 = self.layer7(x) + return [[seg0, seg1, seg2], [edge],[fea_h1,fea_w1]] + +def Res_Deeplab(num_classes=21): + model = ResNet(Bottleneck,[3, 4, 23, 3], num_classes) + return model + + diff --git a/networks/__pycache__/CDGNet.cpython-310.pyc b/networks/__pycache__/CDGNet.cpython-310.pyc new file mode 100644 index 0000000..a90bb8c Binary files /dev/null and b/networks/__pycache__/CDGNet.cpython-310.pyc differ diff --git a/networks/__pycache__/CDGNet.cpython-36.pyc b/networks/__pycache__/CDGNet.cpython-36.pyc new file mode 100644 index 0000000..1d69f3e Binary files /dev/null and b/networks/__pycache__/CDGNet.cpython-36.pyc differ diff --git a/networks/__pycache__/CE2P.cpython-36.pyc b/networks/__pycache__/CE2P.cpython-36.pyc new file mode 100644 index 0000000..07f2be3 Binary files /dev/null and b/networks/__pycache__/CE2P.cpython-36.pyc differ diff --git a/networks/__pycache__/CE2P.cpython-37.pyc b/networks/__pycache__/CE2P.cpython-37.pyc new file mode 100644 index 0000000..96bba2e Binary files /dev/null and b/networks/__pycache__/CE2P.cpython-37.pyc differ diff --git a/networks/__pycache__/CE2P.cpython-38.pyc b/networks/__pycache__/CE2P.cpython-38.pyc new file mode 100644 index 0000000..eb8c87e Binary files /dev/null and b/networks/__pycache__/CE2P.cpython-38.pyc differ diff --git a/networks/__pycache__/CE2PHybrid.cpython-36.pyc b/networks/__pycache__/CE2PHybrid.cpython-36.pyc new file mode 100644 index 0000000..5e9f190 Binary files /dev/null and b/networks/__pycache__/CE2PHybrid.cpython-36.pyc differ diff --git a/networks/__pycache__/CE2P_test.cpython-36.pyc b/networks/__pycache__/CE2P_test.cpython-36.pyc new file mode 100644 index 0000000..3c0498a Binary files /dev/null and b/networks/__pycache__/CE2P_test.cpython-36.pyc differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f57f139 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +wandb +tqdm \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..7914403 --- /dev/null +++ b/run.sh @@ -0,0 +1,38 @@ +#!/bin/bash +uname -a +#date +#env +date +CS_PATH='/mnt/data/humanparsing/LIP' +LR=3.0e-3 +WD=5e-4 +BS=8 +GPU_IDS=0,1,2,3 +RESTORE_FROM='./dataset/resnet101-imagenet.pth' +INPUT_SIZE='473,473' +SNAPSHOT_DIR='./snapshots' +DATASET='train' +NUM_CLASSES=20 +EPOCHS=150 + +if [[ ! 
-e ${SNAPSHOT_DIR} ]]; then + mkdir -p ${SNAPSHOT_DIR} +fi + + python -m torch.distributed.launch --nproc_per_node=4 --nnode=1 \ + --node_rank=0 --master_addr=222.32.33.224 --master_port 29500 train.py \ + --data-dir ${CS_PATH} \ + --random-mirror\ + --random-scale\ + --restore-from ${RESTORE_FROM}\ + --gpu ${GPU_IDS}\ + --learning-rate ${LR}\ + --weight-decay ${WD}\ + --batch-size ${BS} \ + --input-size ${INPUT_SIZE}\ + --snapshot-dir ${SNAPSHOT_DIR}\ + --dataset ${DATASET}\ + --num-classes ${NUM_CLASSES} \ + --epochs ${EPOCHS} + +# python evaluate.py diff --git a/run_evaluate.sh b/run_evaluate.sh new file mode 100644 index 0000000..d604a46 --- /dev/null +++ b/run_evaluate.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# CS_PATH='./dataset/LIP' +CS_PATH='/mnt/data/humanparsing/LIP' +BS=1 +GPU_IDS='1' +INPUT_SIZE='473,473' +SNAPSHOT_FROM='./snapshots/LIP_epoch_149.pth' +DATASET='val' +NUM_CLASSES=20 + +CUDA_VISIBLE_DEVICES=1 python evaluate.py --data-dir ${CS_PATH} \ + --gpu ${GPU_IDS} \ + --batch-size ${BS} \ + --input-size ${INPUT_SIZE}\ + --restore-from ${SNAPSHOT_FROM}\ + --dataset ${DATASET}\ + --num-classes ${NUM_CLASSES} diff --git a/run_evaluate_multiScale.sh b/run_evaluate_multiScale.sh new file mode 100644 index 0000000..9c36623 --- /dev/null +++ b/run_evaluate_multiScale.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# CS_PATH='./dataset/LIP' +CS_PATH='/mnt/data/humanparsing/LIP' +# CS_PATH='/mnt/data/humanparsing/CIHP' +BS=1 +GPU_IDS='1' +INPUT_SIZE='473,473' +SNAPSHOT_FROM='./snapshots/LIP_epoch_149.pth' +DATASET='val' +NUM_CLASSES=20 + +CUDA_VISIBLE_DEVICES=1 python evaluate_multi.py --data-dir ${CS_PATH} \ + --gpu ${GPU_IDS} \ + --batch-size ${BS} \ + --input-size ${INPUT_SIZE}\ + --restore-from ${SNAPSHOT_FROM}\ + --dataset ${DATASET}\ + --num-classes ${NUM_CLASSES} diff --git a/sage_train.sh b/sage_train.sh new file mode 100644 index 0000000..7be137f --- /dev/null +++ b/sage_train.sh @@ -0,0 +1,3 @@ +pip3 install -r requirements.txt + +python train_simplified.py \ No newline at end of file diff --git a/trace_model.py b/trace_model.py new file mode 100644 index 0000000..b3a0826 --- /dev/null +++ b/trace_model.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as fun +import torchvision.transforms as T +import numpy as np +import cv2 +from PIL import Image +import os + +from networks.CDGNet import Res_Deeplab + +net = Res_Deeplab(22).cuda() +net.load_state_dict(torch.load('')) + +data = cv2.imread() +data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) +data = cv2.resize(data, (512,512)) +data = torch.from_numpy(data[None]).to('cuda') + +def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0, agnostic = False) : + + if not agnostic: + palette = [ + 0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51, + 254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85, + 85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220, + 0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0, + 0,0,0,0,0,0,0,0,0 + ] + if agnostic: + palette = [ + 0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 0, 0, 169, 0, 51, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 85, 85, 0, 0, 85, 85, + 0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 254, 0, 0, 0, + 0, 0, 0, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0, + 0,0,0,0,0,0,0,0,0 + ] + input = input.detach() + if multi_channel : + input = ndim_tensor2im(input,batch=batch) + else : + input = input[batch][0].cpu() + input = np.asarray(input) + input = input.astype(np.uint8) + input = Image.fromarray(input, 'P') + input.putpalette(palette) + + if tensor_out : 
+ trans = T.ToTensor() + return trans(input.convert('RGB')) + + return input + + +def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0): + image_numpy = image_tensor[batch].cpu().float().numpy() + result = np.argmax(image_numpy, axis=0) + return result.astype(imtype) + + + +class WrappedModel(nn.Module): + + def __init__(self, model): + + super().__init__() + self.model = model + self.mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().reshape([1, 3, 1, 1]) + self.std = torch.Tensor([0.229, 0.224, 0.225]).cuda().reshape([1, 3, 1, 1]) + + @torch.inference_mode() + def forward(self, data, fp16=True): + + data = data.permute(0, 3, 1, 2).contiguous() + data = data.div(255).sub(self.mean).div_(self.std) + pred = self.model(data) + pred = fun.interpolate(pred[0][-1], (1024, 768), mode = 'bilinear') + + return pred.contiguous() + + +wrp_model = WrappedModel(net).cuda().eval() +torch.cuda.synchronize() + +with torch.no_grad(): + svd_out = wrp_model(data) + +torch.cuda.synchronize() +print(svd_out.shape) + +w1 = visualize_segmap(svd_out, tensor_out = False) +w1.save('w1.png') + +OUT_PATH = "out" +os.makedirs(OUT_PATH, exist_ok=True) + +wrp_model = wrp_model.half() + +with torch.inference_mode(), torch.jit.optimized_execution(True): + traced_script_module = torch.jit.trace(wrp_model, data) + traced_script_module = torch.jit.optimize_for_inference( + traced_script_module) + + +print(traced_script_module.code) +print(f"{OUT_PATH}/model.pt") +traced_script_module.save(f"{OUT_PATH}/model.pt") + +traced_script_module = torch.jit.load(f"{OUT_PATH}/model.pt") + +torch.cuda.synchronize() +with torch.no_grad(): + o = traced_script_module(data) +torch.cuda.synchronize() +print(o.shape) +w2 = visualize_segmap(o, tensor_out = False) +w2.save('w2.png') + + + diff --git a/train.py b/train.py new file mode 100644 index 0000000..65438ad --- /dev/null +++ b/train.py @@ -0,0 +1,359 @@ +import argparse + +import torch +torch.multiprocessing.set_start_method("spawn", force=True) +from torch.utils import data +import numpy as np +from PIL import Image +import torch.optim as optim +import torchvision.utils as vutils +import torch.backends.cudnn as cudnn +import os +import os.path as osp +from tqdm import tqdm +from networks.CDGNet import Res_Deeplab +from dataset.datasets import LIPDataSet +from dataset.target_generation import generate_edge +import torchvision.transforms as transforms +import timeit +import torch.distributed as dist +import wandb +#from tensorboardX import SummaryWriter +from utils.utils import decode_parsing, inv_preprocess +from utils.criterion import CriterionAll +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler + +from utils.miou import compute_mean_ioU +from evaluate import get_ccihp_pallete, valid + +start = timeit.default_timer() + +BATCH_SIZE = 2 +DATA_DIRECTORY = '/home/vrushank/Spyne/HR-Viton/CCIHP' +#DATA_LIST_PATH = './dataset/list/cityscapes/train.lst' +IGNORE_LABEL = 255 +INPUT_SIZE = '32, 32' +LEARNING_RATE = 3e-4 +MOMENTUM = 0.9 +NUM_CLASSES = 22 +POWER = 0.9 +RANDOM_SEED = 1234 +RESTORE_FROM= '/home/vrushank/Spyne/HR-Viton/CCIHP/resnet101-imagenet.pth' +SAVE_NUM_IMAGES = 2 +SAVE_PRED_EVERY = 10000 +SNAPSHOT_DIR = './snapshots/' +WEIGHT_DECAY = 0.0005 +GPU_IDS = '0' + + + + +def reduce_loss(tensor, rank, world_size): + with torch.no_grad(): + dist.reduce(tensor, dst=0) + if rank == 0: + tensor /= world_size + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in 
('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def get_arguments(): + """Parse all the arguments provided from the CLI. + + Returns: + A list of parsed arguments. + """ + parser = argparse.ArgumentParser(description="CE2P Network") + parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, + help="Number of images sent to the network in one step.") + parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, + help="Path to the directory containing the dataset.") + parser.add_argument("--dataset", type=str, default='train', choices=['train', 'val', 'trainval', 'test'], + help="Path to the file listing the images in the dataset.") + parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, + help="The index of the label to ignore during the training.") + parser.add_argument("--input-size", type=str, default=INPUT_SIZE, + help="Comma-separated string with height and width of images.") + parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, + help="Base learning rate for training with polynomial decay.") + parser.add_argument("--momentum", type=float, default=MOMENTUM, + help="Momentum component of the optimiser.") + parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, + help="Number of classes to predict (including background).") + parser.add_argument("--start-iters", type=int, default=0, + help="Number of classes to predict (including background).") + parser.add_argument("--power", type=float, default=POWER, + help="Decay parameter to compute the learning rate.") + parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, + help="Regularisation parameter for L2-loss.") + parser.add_argument("--random-mirror", action="store_true", + help="Whether to randomly mirror the inputs during the training.") + parser.add_argument("--random-scale", action="store_true", + help="Whether to randomly scale the inputs during the training.") + parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, + help="Random seed to have reproducible results.") + parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, + help="Where restore model parameters from.") + parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, + help="How many images to save.") + parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, + help="Where to save snapshots of the model.") + parser.add_argument("--gpu", type=str, default=GPU_IDS, + help="choose gpu device.") + parser.add_argument("--start-epoch", type=int, default=0, + help="choose the number of recurrence.") + parser.add_argument("--epochs", type=int, default=150, + help="choose the number of recurrence.") + parser.add_argument('--local_rank', type=int, default = 0, help="local gpu id") + # os.environ['MASTER_ADDR'] = '202.30.29.226' + # os.environ['MASTER_PORT'] = '8888' + return parser.parse_args() + + +args = get_arguments() + + +def lr_poly(base_lr, iter, max_iter, power): + return base_lr * ((1 - float(iter) / max_iter) ** (power)) + + +def adjust_learning_rate(optimizer, i_iter, total_iters): + """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs""" + lr = lr_poly(args.learning_rate, i_iter, total_iters, args.power) + optimizer.param_groups[0]['lr'] = lr + return lr + + +def adjust_learning_rate_pose(optimizer, epoch): + decay = 0 + if epoch + 1 >= 135: + decay = 0.05 + elif epoch + 1 >= 100: + decay = 0.1 + elif epoch + 1 >= 60: + 
decay = 0.25 + elif epoch + 1 >= 40: + decay = 0.5 + else: + decay = 1 + + lr = args.learning_rate * decay + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + + +def set_bn_eval(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + + +def set_bn_momentum(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1 or classname.find('InPlaceABN') != -1: + m.momentum = 0.0003 + + +def main(): + """Create the model and start the training.""" + + if not os.path.exists(args.snapshot_dir): + os.makedirs(args.snapshot_dir) + + #writer = SummaryWriter(args.snapshot_dir) + gpus = [int(i) for i in args.gpu.split(',')] + if not args.gpu == 'None': + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + + h, w = map(int, args.input_size.split(',')) + input_size = [h, w] + + cudnn.enabled = True + # cudnn related setting + cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.enabled = True + + dist.init_process_group( backend='nccl', init_method='env://' ) + torch.cuda.set_device( int(args.local_rank) ) + gloabl_rank = dist.get_rank() + world_size = dist.get_world_size() + print( world_size ) + #if world_size == 1: + # return + #dist.barrier() + print('Loading model') + deeplab = Res_Deeplab(num_classes=args.num_classes) + + saved_state_dict = torch.load(args.restore_from) + new_params = deeplab.state_dict().copy() + for i in saved_state_dict: + i_parts = i.split('.') + # print(i_parts) + if not i_parts[0] == 'fc': + new_params['.'.join(i_parts[0:])] = saved_state_dict[i] + + deeplab.load_state_dict(new_params) + print('Model Loaded') + deeplab.cuda() + #model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(deeplab) + model = DDP(deeplab, device_ids=[args.local_rank], output_device=args.local_rank ) + print(model) + + criterion = CriterionAll() + # criterion = DataParallelCriterion(criterion) + criterion.cuda() + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + transform = transforms.Compose([ + transforms.ToTensor(), + normalize, + ]) + lipDataset = LIPDataSet(args.data_dir, args.dataset, crop_size=[128, 128], transform=transform) + sampler = DistributedSampler(lipDataset) + trainloader = data.DataLoader(lipDataset, + batch_size=args.batch_size, shuffle=False, + sampler = sampler, + num_workers=4, + pin_memory=True) + lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform) + num_samples = len(lip_dataset) + + valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), + shuffle=False, pin_memory=True) + + optimizer = optim.Adam( + model.parameters(), + lr=args.learning_rate, + betas= (0.5, 0.999), + weight_decay=args.weight_decay, + ) + scaler = torch.cuda.amp.grad_scaler.GradScaler() + optimizer.zero_grad(set_to_none = True) + total_iters = args.epochs * len(trainloader) + + # path = osp.join( args.snapshot_dir, 'model_LIP'+'.pth') + # if os.path.exists( path ): + # checkpoint = torch.load(path) + # model.load_state_dict(checkpoint['model']) + # optimizer.load_state_dict(checkpoint['optimizer']) + # epoch = checkpoint['epoch'] + # print( epoch ) + # args.start_epoch = epoch + # print( 'Load model first!') + # else: + # print( 'No model exits from beginning!') + + model.train() + for epoch in range(args.start_epoch, args.epochs): + sampler.set_epoch(epoch) + + loop = tqdm(trainloader, position = 0, leave = True) + for i_iter, batch in enumerate(loop): + i_iter += len(trainloader) * 
epoch + lr = adjust_learning_rate(optimizer, i_iter, total_iters) + + images, labels, hgt, wgt, hwgt, _= batch + labels = labels.cuda(non_blocking=True) + edges = generate_edge(labels) + labels = labels.type(torch.cuda.LongTensor) + edges = edges.type(torch.cuda.LongTensor) + hgt = hgt.float().cuda(non_blocking=True) + wgt = wgt.float().cuda(non_blocking=True) + hwgt = hwgt.float().cuda(non_blocking=True) + optimizer.zero_grad(set_to_none = True) + #[[seg0, seg1, seg2], [edge],[fea_h1,fea_w1]] + with torch.cuda.amp.autocast_mode.autocast(): + preds = model(images) + loss = criterion(preds, [labels, edges],[hgt,wgt,hwgt]) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + torch.cuda.synchronize() + + reduce_loss( loss, gloabl_rank, world_size ) + + if i_iter % 500 == 0: + + wandb.log({'Learning rate': lr, 'Loss': loss.data.cpu().numpy()}) + #writer.add_scalar('learning_rate', lr, i_iter) + #writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter) + + if i_iter % 2000 == 0: + + images_inv = inv_preprocess(images, args.save_num_images) + labels_colors = decode_parsing(labels, args.save_num_images, args.num_classes, is_pred=False) + edges_colors = decode_parsing(edges, args.save_num_images, 2, is_pred=False) + #[[seg0, seg1, seg2], [edge],[fea_h1,fea_w1]] + if isinstance(preds, list): + preds = preds[0] + preds_colors = decode_parsing(preds[-1], args.save_num_images, args.num_classes, is_pred=True) + pred_edges = decode_parsing(preds[-1], args.save_num_images, 2, is_pred=True) + + img = vutils.make_grid(images_inv*255, normalize=False, scale_each=True) + lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True) + pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True) + edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True) + pred_edge = vutils.make_grid(pred_edges, normalize=False, scale_each=True) + + img_wb = wandb.Image(img.to(torch.uint8).cpu().numpy().transpose((1, 2, 0)), caption = f'{i_iter}') + label_wb = wandb.Image(lab.to(torch.uint8).cpu().numpy().transpose((1, 2, 0)), caption = f'{i_iter}') + pred_wb = wandb.Image(pred.to(torch.uint8).cpu().numpy().transpose((1, 2, 0)), caption = f'{i_iter}') + edge_wb = wandb.Image(edge.to(torch.uint8).cpu().numpy().transpose((1, 2, 0)), caption = f'{i_iter}') + pred_edge_wb = wandb.Image(pred_edge.to(torch.uint8).cpu().numpy().transpose((1, 2, 0)), caption = f'{i_iter}') + + wandb.log({ + 'Image': img_wb, + 'Target': label_wb, + 'Pred': pred_wb, + 'Target Edge': edge_wb, + 'Predicted Edge': pred_edge_wb + }) + + + if gloabl_rank == 0 and i_iter % 500 == 0 : + print('Epoch:{} iter = {} of {} completed, loss = {}'.format(epoch, i_iter, total_iters, loss.data.cpu().numpy())) + + if epoch > 140 and gloabl_rank == 0: + torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'CCIHP_epoch_' + str(epoch) + '.pth')) + + if epoch % 5 == 0 and gloabl_rank == 0: + path = osp.join(args.snapshot_dir, 'CCIHP_epoch_' + str(epoch) + '.pth') + state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } + torch.save(state, path) + + if epoch % 2 == 0: + num_samples = len(lip_dataset) + parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, len(gpus)) + output_parsing = parsing_preds + mIoU, pixel_acc, mean_acc = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) + print(mIoU) + palette = get_ccihp_pallete() + wandb.log({'Valid MIou': mIoU, 'Valid Pixel Accuracy': pixel_acc, 'Valid Mean 
Accuracy': mean_acc}) + for i in range(10): + output_image = Image.fromarray( output_parsing[i] ) + output_image.putpalette( palette ) + output_label_wb = wandb.Image(output_image) + wandb.log({'Valid pred': output_label_wb}) + + + + end = timeit.default_timer() + print(end - start, 'seconds') + + +if __name__ == '__main__': + + wandb.init(project="Viton_Segmentation_CDGNet",config={"name": "Virtual Try-on"}) + main() diff --git a/train_simplified.py b/train_simplified.py new file mode 100644 index 0000000..eaf9623 --- /dev/null +++ b/train_simplified.py @@ -0,0 +1,299 @@ +import torch +import torch.nn as nn +import torchvision.transforms as T +import torchvision.utils as vutils +import torch.nn.functional as fun +from torch.utils.data import Dataset, DataLoader +import numpy as np +from PIL import Image +import os +from tqdm import tqdm +import argparse +import wandb + +from networks.CDGNet import Res_Deeplab +from dataset.datasets import LIPDataSet, LIPDataValSet +from dataset.target_generation import generate_edge +from utils.utils import decode_parsing, inv_preprocess, AverageMeter +from utils.criterion import CriterionAll +from utils.miou import compute_mean_ioU +from evaluate import get_ccihp_pallete, valid + + +def get_arguments(): + """Parse all the arguments provided from the CLI. + + Returns: + A list of parsed arguments. + """ + + BATCH_SIZE = 8 + try: + DATA_DIRECTORY = os.environ['SM_CHANNEL_TRAIN'] + except KeyError: + DATA_DIRECTORY = '/home/vrushank/Spyne/HR-Viton/CCIHP' + + IGNORE_LABEL = 255 + INPUT_SIZE = '512, 512' + LEARNING_RATE = 3e-4 + MOMENTUM = 0.9 + NUM_CLASSES = 22 + POWER = 0.9 + RANDOM_SEED = 1234 + try: + RESTORE_FROM= 'resnet101-imagenet.pth' + except FileNotFoundError: + RESTORE_FROM = '/home/vrushank/Spyne/HR-Viton/CCIHP/resnet101-imagenet.pth' + SAVE_NUM_IMAGES = 2 + SAVE_PRED_EVERY = 10000 + try: + SNAPSHOT_DIR = '/opt/ml/checkpoints/' + except KeyError: + SNAPSHOT_DIR = 'snapshots/' + + WEIGHT_DECAY = 0.0005 + + parser = argparse.ArgumentParser(description="CDG Network") + parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, + help="Number of images sent to the network in one step.") + parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, + help="Path to the directory containing the dataset.") + parser.add_argument("--dataset", type=str, default='train', choices=['train', 'val', 'trainval', 'test'], + help="Path to the file listing the images in the dataset.") + parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, + help="The index of the label to ignore during the training.") + parser.add_argument("--input-size", type=str, default=INPUT_SIZE, + help="Comma-separated string with height and width of images.") + parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, + help="Base learning rate for training with polynomial decay.") + parser.add_argument("--momentum", type=float, default=MOMENTUM, + help="Momentum component of the optimiser.") + parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, + help="Number of classes to predict (including background).") + parser.add_argument("--start-iters", type=int, default=0, + help="Number of classes to predict (including background).") + parser.add_argument("--power", type=float, default=POWER, + help="Decay parameter to compute the learning rate.") + parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, + help="Regularisation parameter for L2-loss.") + parser.add_argument("--random-mirror", action="store_true", + 
help="Whether to randomly mirror the inputs during the training.") + parser.add_argument("--random-scale", action="store_true", + help="Whether to randomly scale the inputs during the training.") + parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, + help="Random seed to have reproducible results.") + parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, + help="Where restore model parameters from.") + parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, + help="How many images to save.") + parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, + help="Where to save snapshots of the model.") + parser.add_argument("--start-epoch", type=int, default=0, + help="choose the number of recurrence.") + parser.add_argument("--num_epochs", type=int, default=150, + help="choose the number of recurrence.") + + return parser.parse_args() + + + +def lr_poly(base_lr, iter, max_iter, power): + return base_lr * ((1 - float(iter) / max_iter) ** (power)) + + +def adjust_learning_rate(optimizer, i_iter, total_iters): + """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs""" + args = get_arguments() + lr = lr_poly(args.learning_rate, i_iter, total_iters, args.power) + optimizer.param_groups[0]['lr'] = lr + return lr + + +def set_bn_eval(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + + +def set_bn_momentum(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1 or classname.find('InPlaceABN') != -1: + m.momentum = 0.0003 + + +def train(loader, valid_loader, model, opt, scaler, criterion, total_iters, epoch, args): + + + model.train() + loop = tqdm(loader, position = 0, leave = True) + loss_ = AverageMeter() + for idx, batch in enumerate(loop): + + idx += len(loader) * epoch + lr = adjust_learning_rate(opt, idx, total_iters) + + imgs, labels, hgt, wgt, hwgt, _ = batch + imgs, labels = imgs.cuda(non_blocking = True), labels.cuda(non_blocking = True) + edges = generate_edge(labels) + labels = labels.type(torch.cuda.LongTensor) #Check LongStorage which torch.cuda recommended + edges = edges.type(torch.cuda.LongTensor) + hgt = hgt.float().cuda(non_blocking = True) + wgt = wgt.float().cuda(non_blocking = True) + hwgt = hwgt.float().cuda(non_blocking = True) + opt.zero_grad(set_to_none = True) + + with torch.cuda.amp.autocast_mode.autocast(): + + preds = model(imgs) + loss = criterion(preds, [labels, edges], [hgt, wgt, hwgt]) + + loss_.update(loss.detach(), imgs.size(0)) + scaler.scale(loss).backward() + scaler.step(opt) + scaler.update() + + if idx % 500 == 0: + + wandb.log({'Training Loss': loss_.avg, 'Learning Rate': lr}) + print(f'Epoch [{epoch}/{args.num_epochs}] iter [{idx}/{len(loader)}] Learning Rate: {lr} Loss: {loss_.avg}') + + if idx % 2000 == 0: + + #print(imgs.shape) + imgs_inv = inv_preprocess(imgs, args.save_num_images) + labels_colors = decode_parsing(labels, args.save_num_images, is_pred = False) + edges_colors = decode_parsing(edges, args.save_num_images, is_pred = False) + #if isinstance(preds, list): + # preds = preds[0] + pred = fun.interpolate(preds[0][-1],(512,512), mode='bilinear', align_corners=True ) + pred_edge = fun.interpolate(preds[1][-1],(512,512), mode='bilinear', align_corners=True ) + preds_colors = decode_parsing(pred, args.save_num_images, is_pred = True) + #Check the position of edges in the list + pred_edges_colors = decode_parsing(pred_edge, args.save_num_images, 2, is_pred = True) + + #preds_colors = 
fun.interpolate(preds_colors, (512, 512), mode = 'bilinear', align_corners = True) + #pred_edges_colors = fun.interpolate(pred_edges_colors, (512, 512), mode = 'bilinear', align_corners = True) + + img = vutils.make_grid(imgs_inv*255, normalize = False, scale_each = True) + lab = vutils.make_grid(labels_colors, normalize = False, scale_each = True) + pred = vutils.make_grid(preds_colors, normalize = False, scale_each = True) + edge = vutils.make_grid(edges_colors, normalize = False, scale_each = True) + pred_edge = vutils.make_grid(pred_edges_colors, normalize = False, scale_each = True) + + img_wb = wandb.Image(img.to(torch.uint8).cpu().numpy().transpose((1,2,0))) + labels_wb = wandb.Image(lab.to(torch.uint8).cpu().numpy().transpose((1,2,0))) + pred_wb = wandb.Image(pred.to(torch.uint8).cpu().numpy().transpose((1,2,0))) + edge_wb = wandb.Image(edge.to(torch.uint8).cpu().numpy().transpose((1,2,0))) + pred_edge_wb = wandb.Image(pred_edge.to(torch.uint8).cpu().numpy().transpose((1,2,0))) + + wandb.log({ + 'Images': img_wb, + 'Target': labels_wb, + 'Pred': pred_wb, + 'Edges': edge_wb, + 'Pred Edges': pred_edge_wb + }) + + if epoch % 2 == 0: + + num_samples = len(valid_loader) * args.batch_size + parsing_preds, img, scales, centers = valid(model, valid_loader, [512, 512], num_samples, 1) + if isinstance(parsing_preds, np.ndarray): + output_parsing = parsing_preds.copy() + if isinstance(parsing_preds, torch.Tensor): + output_parsing = parsing_preds.clone() + else: + output_parsing = parsing_preds + + mIoU, pixel_acc, mean_acc = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, [512, 512]) + print('Printing MIoU Values...') + for k, v in mIoU.items(): + print(f'{k}: {v}') + print(f'Pixel Accuracy: {pixel_acc}') + print(f'Mean Accuracy: {mean_acc}') + palette = get_ccihp_pallete() + wandb.log({ + 'Valid MIoU': mIoU, + 'Valid Pixel Accuracy': pixel_acc, + 'Valid Mean Accuracy': mean_acc + }) + print('Values Logged on wandb') + for i in range(10): + print('Inside Loop') + #ip_img = Image.fromarray(img[i]) + op_img = Image.fromarray(output_parsing[i]) + op_img.putpalette(palette) + #ip_img_wb = wandb.Image(ip_img) + op_label_wb = wandb.Image(op_img) + wandb.log({'Valid Pred': op_label_wb}) + + + + +def main(): + + transform = T.Compose([ + T.ToTensor(), + T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]) + ]) + args = get_arguments() + dataset = LIPDataSet(args.data_dir, args.dataset, [512, 512], transform = transform) + train_loader = DataLoader(dataset, + batch_size = args.batch_size, + shuffle = True, + num_workers = 4, + pin_memory = True) + + val_dataset = LIPDataValSet(args.data_dir, transform = transform) + valid_loader = DataLoader(val_dataset, + batch_size = args.batch_size, + shuffle = False, + pin_memory= True) + + model = Res_Deeplab(num_classes = args.num_classes) + print("Loading Model...") + ckpt = torch.load(os.path.join(os.getcwd(), args.restore_from)) + new_params = model.state_dict().copy() + + for i in ckpt: + i_parts = i.split('.') + if not i_parts[0] == 'fc': + new_params['.'.join(i_parts[0:])] = ckpt[i] + + model.load_state_dict(new_params) + model.cuda() + print('Model Loaded.') + + criterion = CriterionAll().cuda() + opt = torch.optim.SGD( + model.parameters(), + lr = args.learning_rate, + momentum = 0.9, + weight_decay = args.weight_decay, + nesterov = True + ) + + scaler = torch.cuda.amp.grad_scaler.GradScaler() + total_iters = len(train_loader) * args.num_epochs + + for epoch in range(args.num_epochs): + + 
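+ # each iteration runs one full epoch; train() also logs to wandb every 500 iters and validates every other epoch
+ # note: unlike train.py, this simplified script never writes checkpoints; saving model.state_dict() into args.snapshot_dir after each epoch would have to be added separately if snapshots are needed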
train(train_loader, + valid_loader, + model, + opt, + scaler, + criterion, + total_iters, + epoch, + args + ) + + +if __name__ == '__main__': + + wandb.init(project = 'Human Parsing') + main() + + + \ No newline at end of file diff --git a/try.py b/try.py new file mode 100644 index 0000000..f4e19b2 --- /dev/null +++ b/try.py @@ -0,0 +1,148 @@ +import torch +import torchvision.transforms as T +import torch.nn.functional as fun +import torchvision.transforms as T +import cv2 +import numpy as np +from PIL import Image +import os +import matplotlib.pyplot as plt +from concurrent.futures import ThreadPoolExecutor +from glob import glob +CUDA_LAUNCH_BLOCKING = 1 +from networks.CDGNet import Res_Deeplab +from utils.utils import decode_parsing, decode_parsing_agnostic, inv_preprocess + +imgs = glob('/home/ubuntu/Vrushank/CDGNet/VITON-data/train/image/*') +print(len(imgs)) + +model = Res_Deeplab(22).cuda() +model.load_state_dict(torch.load('/home/ubuntu/Vrushank/CDGNet/snapshots/model_latest.pth')) +print('Done') +model.eval() + +out_dir = '/home/ubuntu/Vrushank/CDGNet/VITON-data/train/image-parse-agnosticv3.2' +#out_dir1 = '/home/ubuntu/Vrushank/CDGNet/VITON-data/train/image-parse-agnostic' +#out_dir2 = '/home/ubuntu/Vrushank/CDGNet/VITON-data/parse-down' + +if not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok = True) + +#if not os.path.exists(out_dir1): +# os.makedirs(out_dir1, exist_ok = True) + +#if not os.path.exists(out_dir2): +# os.makedirs(out_dir2, exist_ok = True) +def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0, agnostic = False) : + + if not agnostic: + palette = [ + 0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51, + 254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85, + 85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220, + 0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0, + 0,0,0,0,0,0,0,0,0 + ] + if agnostic: + palette = [ + 0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 0, 0, 169, 0, 51, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 85, 85, 0, 0, 85, 85, + 0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 254, 0, 0, 0, + 0, 0, 0, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0, + 0,0,0,0,0,0,0,0,0 + ] + input = input.detach() + if multi_channel : + input = ndim_tensor2im(input,batch=batch) + else : + input = input[batch][0].cpu() + input = np.asarray(input) + input = input.astype(np.uint8) + input = Image.fromarray(input, 'P') + input.putpalette(palette) + + if tensor_out : + trans = T.ToTensor() + return trans(input.convert('RGB')) + + return input + + +def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0): + image_numpy = image_tensor[batch].cpu().float().numpy() + result = np.argmax(image_numpy, axis=0) + return result.astype(imtype) + + +transform = T.Compose([ + T.Resize((512, 512)), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) + +def get_outputs(p): + + #og = cv2.imread(p) + #og = cv2.resize(og, (768, 1024)) + name = p.split('/')[-1].split('.')[0] + img = Image.open(p).convert('RGB') + w, h = img.size # + img = transform(img) + img = img.unsqueeze(0).cuda() + + with torch.no_grad(): + + preds = model(img) + + #img_inv = inv_preprocess(img, 1) + pred = fun.interpolate(preds[0][-1], (1024, 768), mode = 'bilinear') + #p1 = pred.squeeze(0).cpu().numpy() + #print(p1.shape) + + label = visualize_segmap(pred, tensor_out=False, agnostic=True) + #print(label.getbands()) + #arr = np.array(label) + #print(arr.max()) + #print(arr) + label.save(f'{out_dir}/{name}.png') + #y = label.cpu().numpy().transpose(2,1,0) + #y = y * 
255.0 + #print(type(y)) + #print(y.shape) + #print(y.max()) + #cv2.imwrite('y.png', y) + #y = Image.fromarray(y) + + + #print(y.getbands()) + #y.save('y.png') + #label = decode_parsing(pred, 1, is_pred = True) + #label_ag = decode_parsing_agnostic(pred, 1, is_pred = True) + + #img1 = img_inv.squeeze(0).to(torch.uint8).cpu().numpy().transpose((1,2,0)) + #pred = label.squeeze(0).to(torch.uint8).cpu().numpy().transpose((1,2,0)) + #pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR) + #pred = cv2.resize(pred, (w, h)) + #cv2.imwrite('x.png', pred) + #cv2.imwrite(f'{out_dir}/{name}.png', pred) + + #pred1 = label_ag.squeeze(0).to(torch.uint8).cpu().numpy().transpose((1,2,0)) + #pred1 = cv2.cvtColor(pred1, cv2.COLOR_RGB2BGR) + #pred1 = cv2.resize(pred1, (w, h)) + #cv2.imwrite(f'{out_dir1}/{name}.png', pred1) + #pred_gs = np.argmax(pred1, axis = -1) + #pred_gs = (pred_gs / 22) * 255 + #pred_gs = np.expand_dims(pred_gs, axis = -1) + #pred_gs = cv2.resize(pred_gs, (768, 1024)) + #cv2.imwrite(f'{out_dir1}/{name}.png', pred_gs) + #pred_gs_down = cv2.resize(pred_gs, (384, 512)) + #cv2.imwrite(f'{out_dir2}/{name}.png', pred_gs_down) + #print(pred1.shape) + #pred1 = cv2.cvtColor(pred1, cv2.COLOR_RGB2GRAY) + #res = np.concatenate((og, pred1), axis = 1) + +#for p in imgs: +# get_outputs(p) +with ThreadPoolExecutor() as executor: + + executor.map(get_outputs, imgs) \ No newline at end of file diff --git a/utils/.DS_Store b/utils/.DS_Store new file mode 100644 index 0000000..3bbdb52 Binary files /dev/null and b/utils/.DS_Store differ diff --git a/utils/ImgTransforms.py b/utils/ImgTransforms.py new file mode 100644 index 0000000..c2ba07b --- /dev/null +++ b/utils/ImgTransforms.py @@ -0,0 +1,394 @@ +import os +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image, ImageFilter +import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw +from torchvision import transforms +# from torchvision import models,datasets +# import matplotlib.pyplot as plt +import random +import cv2 + +RESAMPLE_MODE=Image.BICUBIC + +# cat=cv2.imread('d:/testpy/839_482127.jpg') + +random_mirror = True + +def ShearX(img, v): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if random_mirror and random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0), + RESAMPLE_MODE) + +def ShearY(img, v): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if random_mirror and random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0), + RESAMPLE_MODE) + +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if random_mirror and random.random() > 0.5: + v = -v + v = v * img.size[0] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), + RESAMPLE_MODE) + +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if random_mirror and random.random() > 0.5: + v = -v + v = v * img.size[1] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), + RESAMPLE_MODE) + +def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), + RESAMPLE_MODE) + + +def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), + RESAMPLE_MODE) + +def Rotate(img, v): # [-30, 30] + assert -30 <= v <= 30 + if 
random_mirror and random.random() > 0.5: + v = -v + return img.rotate(v) + +def AutoContrast(img, _): + return PIL.ImageOps.autocontrast(img,1) + +def Invert(img, _): + return PIL.ImageOps.invert(img) + +def Equalize(img, _): + return PIL.ImageOps.equalize(img) + +def Flip(img, _): # not from the paper + return PIL.ImageOps.mirror(img) + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return PIL.ImageOps.solarize(img, v) + +def SolarizeAdd(img, addition=0, threshold=128): + img_np = np.array(img).astype(np.int) + img_np = img_np + addition + img_np = np.clip(img_np, 0, 255) + img_np = img_np.astype(np.uint8) + img = Image.fromarray(img_np) + return PIL.ImageOps.solarize(img, threshold) + +def Posterize(img, v): # [4, 8] + #assert 4 <= v <= 8 + v = int(v) + return PIL.ImageOps.posterize(img, v) + +def Contrast(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Contrast(img).enhance(v) + +def Color(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Color(img).enhance(v) + +def Brightness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Brightness(img).enhance(v) + +def Sharpness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Sharpness(img).enhance(v) + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + # assert 0 <= v <= 20 + if v < 0: + return img + w, h = img.size + x0 = np.random.uniform(w) + y0 = np.random.uniform(h) + + x0 = int(max(0, x0 - v / 2.)) + y0 = int(max(0, y0 - v / 2.)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + # color = (0, 0, 0) + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + return img + +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] + assert 0.0 <= v <= 0.2 + if v <= 0.: + return img + + v = v * img.size[0] + return CutoutAbs(img, v) + +def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v <= 10 + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), + resample=RESAMPLE_MODE) + + +def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v <= 10 + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), + resample=RESAMPLE_MODE) + +def Posterize2(img, v): # [0, 4] + assert 0 <= v <= 4 + v = int(v) + return PIL.ImageOps.posterize(img, v) + +def SamplePairing(imgs): # [0, 0.4] + def f(img1, v): + i = np.random.choice(len(imgs)) + img2 = Image.fromarray(imgs[i]) + return Image.blend(img1, img2, v) + + return f + +def augment_list(for_autoaug=True): # 16 oeprations and their ranges + l = [ + (ShearX, -0.3, 0.3), # 0 + (ShearY, -0.3, 0.3), # 1 + (TranslateX, -0.45, 0.45), # 2 + (TranslateY, -0.45, 0.45), # 3 + (Rotate, -30, 30), # 4 + (AutoContrast, 0, 1), # 5 + (Invert, 0, 1), # 6 + (Equalize, 0, 1), # 7 + (Solarize, 0, 256), # 8 + (Posterize, 4, 8), # 9 + (Contrast, 0.1, 1.9), # 10 + (Color, 0.1, 1.9), # 11 + (Brightness, 0.1, 1.9), # 12 + (Sharpness, 0.1, 1.9), # 13 + (Cutout, 0, 0.2), # 14 + # (SamplePairing(imgs), 0, 0.4), # 15 + ] + if for_autoaug: + l += [ + (CutoutAbs, 0, 20), # compatible with auto-augment + (Posterize2, 0, 4), # 9 + (TranslateXAbs, 0, 10), # 9 + (TranslateYAbs, 0, 10), # 9 + ] + return l + +augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} + +PARAMETER_MAX = 10 + + +def float_parameter(level, maxval): + return float(level) * maxval / PARAMETER_MAX + + +def int_parameter(level, maxval): 
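+ # map an integer augmentation level in [0, PARAMETER_MAX] to an integer magnitude in [0, maxval]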
+ return int(float_parameter(level, maxval)) + +def rand_augment_list(): # 16 oeprations and their ranges + l = [ + (AutoContrast, 0, 1), + (Equalize, 0, 1), + (Invert, 0, 1), + (Rotate, 0, 30), + (Posterize, 0, 4), + (Solarize, 0, 256), + (SolarizeAdd, 0, 110), + (Color, 0.1, 1.9), + (Contrast, 0.1, 1.9), + (Brightness, 0.1, 1.9), + (Sharpness, 0.1, 1.9), + (ShearX, 0., 0.3), + (ShearY, 0., 0.3), + (CutoutAbs, 0, 40), + (TranslateXabs, 0., 100), + (TranslateYabs, 0., 100), + ] + + return l + +def autoaug2fastaa(f): + def autoaug(): + mapper = defaultdict(lambda: lambda x: x) + mapper.update({ + 'ShearX': lambda x: float_parameter(x, 0.3), + 'ShearY': lambda x: float_parameter(x, 0.3), + 'TranslateX': lambda x: int_parameter(x, 10), + 'TranslateY': lambda x: int_parameter(x, 10), + 'Rotate': lambda x: int_parameter(x, 30), + 'Solarize': lambda x: 256 - int_parameter(x, 256), + 'Posterize2': lambda x: 4 - int_parameter(x, 4), + 'Contrast': lambda x: float_parameter(x, 1.8) + .1, + 'Color': lambda x: float_parameter(x, 1.8) + .1, + 'Brightness': lambda x: float_parameter(x, 1.8) + .1, + 'Sharpness': lambda x: float_parameter(x, 1.8) + .1, + 'CutoutAbs': lambda x: int_parameter(x, 20) + }) + + def low_high(name, prev_value): + _, low, high = get_augment(name) + return float(prev_value - low) / (high - low) + + policies = f() + new_policies = [] + for policy in policies: + new_policies.append([(name, pr, low_high(name, mapper[name](level))) for name, pr, level in policy]) + return new_policies + + return autoaug + +# @autoaug2fastaa +def autoaug_imagenet_policies(): + return [ + # [('Posterize2', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + #[('Posterize2', 0.6, 7), ('Posterize2', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + # [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('Posterize2', 0.8, 5), ('Equalize', 1.0, 2)], + # [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('Posterize2', 0.4, 6)], + # [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + # [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + # [('Rotate', 0.8, 8), ('Color', 1.0, 0)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + # [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + +class ToPIL(object): + """Convert image from ndarray format to PIL + """ + def __call__(self, img): + x = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB)) + return x + +class ToNDArray(object): + def __call__(self, img): + x = cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR) + return x + +class RandAugment(object): + def __init__(self, n, m): + self.n = n + self.m = m + self.augment_list = rand_augment_list() + self.topil = ToPIL() + + def __call__(self, img): + img = self.topil(img) + ops = random.choices(self.augment_list, k=self.n) + for op, minval, maxval in ops: + if random.random() > random.uniform(0.2, 0.8): + continue + val = (float(self.m) / 30) * float(maxval - minval) + minval + img = op(img, val) + return img + +def 
get_augment(name): + return augment_dict[name] + + +def apply_augment(img, name, level): + augment_fn, low, high = get_augment(name) + return augment_fn(img.copy(), level * (high - low) + low) +class PILGaussianBlur(ImageFilter.Filter): + name = "GaussianBlur" + def __init__(self, radius=2, bounds=None): + self.radius = radius + self.bounds = bounds + def filter(self, image): + if self.bounds: + clips = image.crop(self.bounds).gaussian_blur(self.radius) + image.paste(clips, self.bounds) + return image + else: + return image.gaussian_blur(self.radius) +class GaussianBlur(object): + def __init__(self, radius=2 ): + self.GaussianBlur=PILGaussianBlur(radius) + def __call__(self, img): + img = img.filter( self.GaussianBlur ) + return img +class AugmentationBlock(object): + r""" + AutoAugment Block + + Example + ------- + >>> from autogluon.utils.augment import AugmentationBlock, autoaug_imagenet_policies + >>> aa_transform = AugmentationBlock(autoaug_imagenet_policies()) + """ + def __init__(self, policies): + """ + plicies : list of (name, pr, level) + """ + super().__init__() + self.policies = policies() + self.topil = ToPIL() + self.tond = ToNDArray() + self.Gaussian_blue = PILGaussianBlur(2) + self.policy = [GaussianBlur(),transforms.ColorJitter( 0.1026, 0.0935, 0.8386, 0.1592 ), + transforms.Grayscale(num_output_channels=3)] + # self.colorAug = transforms.RandomApply([transforms.ColorJitter( 0.1026, 0.0935, 0.8386, 0.1592 )], p=0.5) + def __call__(self, img): + img = self.topil(img) + trans = random.choice(self.policy) + if random.random() >= 0.5: + img = trans( img ) + img = self.tond(img) + return img + + +# augBlock = AugmentationBlock( autoaug_imagenet_policies ) +# plt.figure() +# for i in range(20): +# catAug = augBlock( cat ) +# plt.subplot(4,5,i+1) +# plt.imshow(catAug) + +# plt.show() +# im_path = os.path.join('D:/testPy/839_482127.jpg') +# img = Image.open( im_path ).convert('RGB') + +# factor = random.uniform(-0.4, 0.4) +# imgb = T.adjust_brightness(img, 1 + factor) + +# imgc = transforms.ColorJitter( 0.4,0.4,0.4,0.4 )(img) + +# imgd = transforms.RandomHorizontalFlip()(img) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__pycache__/ImgTransforms.cpython-310.pyc b/utils/__pycache__/ImgTransforms.cpython-310.pyc new file mode 100644 index 0000000..2581e41 Binary files /dev/null and b/utils/__pycache__/ImgTransforms.cpython-310.pyc differ diff --git a/utils/__pycache__/ImgTransforms.cpython-36.pyc b/utils/__pycache__/ImgTransforms.cpython-36.pyc new file mode 100644 index 0000000..e7a279e Binary files /dev/null and b/utils/__pycache__/ImgTransforms.cpython-36.pyc differ diff --git a/utils/__pycache__/ImgTransforms.cpython-37.pyc b/utils/__pycache__/ImgTransforms.cpython-37.pyc new file mode 100644 index 0000000..09dbf63 Binary files /dev/null and b/utils/__pycache__/ImgTransforms.cpython-37.pyc differ diff --git a/utils/__pycache__/OCRAttention.cpython-36.pyc b/utils/__pycache__/OCRAttention.cpython-36.pyc new file mode 100644 index 0000000..2def857 Binary files /dev/null and b/utils/__pycache__/OCRAttention.cpython-36.pyc differ diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e444cf4 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 
0000000..52308ea Binary files /dev/null and b/utils/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/__pycache__/__init__.cpython-37.pyc b/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..a101987 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..b390198 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/attention.cpython-310.pyc b/utils/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000..36f2540 Binary files /dev/null and b/utils/__pycache__/attention.cpython-310.pyc differ diff --git a/utils/__pycache__/attention.cpython-36.pyc b/utils/__pycache__/attention.cpython-36.pyc new file mode 100644 index 0000000..8f1dd56 Binary files /dev/null and b/utils/__pycache__/attention.cpython-36.pyc differ diff --git a/utils/__pycache__/attention.cpython-37.pyc b/utils/__pycache__/attention.cpython-37.pyc new file mode 100644 index 0000000..c991476 Binary files /dev/null and b/utils/__pycache__/attention.cpython-37.pyc differ diff --git a/utils/__pycache__/attention.cpython-38.pyc b/utils/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000..33daa99 Binary files /dev/null and b/utils/__pycache__/attention.cpython-38.pyc differ diff --git a/utils/__pycache__/criterion.cpython-310.pyc b/utils/__pycache__/criterion.cpython-310.pyc new file mode 100644 index 0000000..0e2019a Binary files /dev/null and b/utils/__pycache__/criterion.cpython-310.pyc differ diff --git a/utils/__pycache__/criterion.cpython-36.pyc b/utils/__pycache__/criterion.cpython-36.pyc new file mode 100644 index 0000000..c1675fc Binary files /dev/null and b/utils/__pycache__/criterion.cpython-36.pyc differ diff --git a/utils/__pycache__/criterion.cpython-37.pyc b/utils/__pycache__/criterion.cpython-37.pyc new file mode 100644 index 0000000..1987bb6 Binary files /dev/null and b/utils/__pycache__/criterion.cpython-37.pyc differ diff --git a/utils/__pycache__/distributed.cpython-36.pyc b/utils/__pycache__/distributed.cpython-36.pyc new file mode 100644 index 0000000..4786780 Binary files /dev/null and b/utils/__pycache__/distributed.cpython-36.pyc differ diff --git a/utils/__pycache__/distributed.cpython-37.pyc b/utils/__pycache__/distributed.cpython-37.pyc new file mode 100644 index 0000000..e4c2fa9 Binary files /dev/null and b/utils/__pycache__/distributed.cpython-37.pyc differ diff --git a/utils/__pycache__/encoding.cpython-36.pyc b/utils/__pycache__/encoding.cpython-36.pyc new file mode 100644 index 0000000..2b55c09 Binary files /dev/null and b/utils/__pycache__/encoding.cpython-36.pyc differ diff --git a/utils/__pycache__/encoding.cpython-37.pyc b/utils/__pycache__/encoding.cpython-37.pyc new file mode 100644 index 0000000..c76626a Binary files /dev/null and b/utils/__pycache__/encoding.cpython-37.pyc differ diff --git a/utils/__pycache__/logger.cpython-36.pyc b/utils/__pycache__/logger.cpython-36.pyc new file mode 100644 index 0000000..c3bc515 Binary files /dev/null and b/utils/__pycache__/logger.cpython-36.pyc differ diff --git a/utils/__pycache__/loss.cpython-310.pyc b/utils/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000..732b7c6 Binary files /dev/null and b/utils/__pycache__/loss.cpython-310.pyc differ diff --git a/utils/__pycache__/loss.cpython-36.pyc b/utils/__pycache__/loss.cpython-36.pyc new file mode 100644 
index 0000000..76d018e Binary files /dev/null and b/utils/__pycache__/loss.cpython-36.pyc differ diff --git a/utils/__pycache__/loss.cpython-37.pyc b/utils/__pycache__/loss.cpython-37.pyc new file mode 100644 index 0000000..916b3b7 Binary files /dev/null and b/utils/__pycache__/loss.cpython-37.pyc differ diff --git a/utils/__pycache__/lovasz_losses.cpython-310.pyc b/utils/__pycache__/lovasz_losses.cpython-310.pyc new file mode 100644 index 0000000..535cb34 Binary files /dev/null and b/utils/__pycache__/lovasz_losses.cpython-310.pyc differ diff --git a/utils/__pycache__/lovasz_losses.cpython-36.pyc b/utils/__pycache__/lovasz_losses.cpython-36.pyc new file mode 100644 index 0000000..7f2c3d6 Binary files /dev/null and b/utils/__pycache__/lovasz_losses.cpython-36.pyc differ diff --git a/utils/__pycache__/lovasz_losses.cpython-37.pyc b/utils/__pycache__/lovasz_losses.cpython-37.pyc new file mode 100644 index 0000000..e17eaf8 Binary files /dev/null and b/utils/__pycache__/lovasz_losses.cpython-37.pyc differ diff --git a/utils/__pycache__/miou.cpython-310.pyc b/utils/__pycache__/miou.cpython-310.pyc new file mode 100644 index 0000000..1012fee Binary files /dev/null and b/utils/__pycache__/miou.cpython-310.pyc differ diff --git a/utils/__pycache__/miou.cpython-36.pyc b/utils/__pycache__/miou.cpython-36.pyc new file mode 100644 index 0000000..820bb19 Binary files /dev/null and b/utils/__pycache__/miou.cpython-36.pyc differ diff --git a/utils/__pycache__/miou.cpython-37.pyc b/utils/__pycache__/miou.cpython-37.pyc new file mode 100644 index 0000000..fd36a43 Binary files /dev/null and b/utils/__pycache__/miou.cpython-37.pyc differ diff --git a/utils/__pycache__/model_store.cpython-36.pyc b/utils/__pycache__/model_store.cpython-36.pyc new file mode 100644 index 0000000..aa79dce Binary files /dev/null and b/utils/__pycache__/model_store.cpython-36.pyc differ diff --git a/utils/__pycache__/pyt_utils.cpython-36.pyc b/utils/__pycache__/pyt_utils.cpython-36.pyc new file mode 100644 index 0000000..8fbb937 Binary files /dev/null and b/utils/__pycache__/pyt_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/transforms.cpython-310.pyc b/utils/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000..568a9c8 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-310.pyc differ diff --git a/utils/__pycache__/transforms.cpython-36.pyc b/utils/__pycache__/transforms.cpython-36.pyc new file mode 100644 index 0000000..557b2b3 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-36.pyc differ diff --git a/utils/__pycache__/transforms.cpython-37.pyc b/utils/__pycache__/transforms.cpython-37.pyc new file mode 100644 index 0000000..2938194 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-37.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..78964df Binary files /dev/null and b/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/utils/__pycache__/utils.cpython-36.pyc b/utils/__pycache__/utils.cpython-36.pyc new file mode 100644 index 0000000..355a480 Binary files /dev/null and b/utils/__pycache__/utils.cpython-36.pyc differ diff --git a/utils/__pycache__/utils.cpython-37.pyc b/utils/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000..e310e40 Binary files /dev/null and b/utils/__pycache__/utils.cpython-37.pyc differ diff --git a/utils/attention.py b/utils/attention.py new file mode 100644 index 0000000..fdba455 --- /dev/null +++ 
b/utils/attention.py @@ -0,0 +1,276 @@ + +import numpy as np +import torch +import math +import torch.nn as nn +from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ + NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding +from torch.nn import functional as F +from torch.autograd import Variable +import functools + +from torch.nn import BatchNorm2d as BatchNorm2d +from torch.nn import BatchNorm1d as BatchNorm1d + +def conv2d(in_channel, out_channel, kernel_size): + layers = [ + nn.Conv2d(in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False), + BatchNorm2d(out_channel), + nn.ReLU(), + ] + + return nn.Sequential(*layers) + +def conv1d(in_channel, out_channel, kernel_size): + layers = [ + nn.Conv1d(in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False), + BatchNorm1d(out_channel), + nn.ReLU(), + ] + + return nn.Sequential(*layers) + + +class CDGAttention(nn.Module): + def __init__(self, feat_in=512, feat_out=256, num_classes=20, size=[384//16,384//16], kernel_size =7 ): + super(CDGAttention, self).__init__() + h,w = size[0],size[1] + kSize = kernel_size + self.gamma = Parameter(torch.ones(1)) + self.beta = Parameter(torch.ones(1)) + self.rowpool = nn.AdaptiveAvgPool2d((h,1)) + self.colpool = nn.AdaptiveAvgPool2d((1,w)) + self.conv_hgt1 =conv1d(feat_in,feat_out,3) + self.conv_hgt2 =conv1d(feat_in,feat_out,3) + self.conv_hwPred1 = nn.Sequential( + nn.Conv1d(feat_out,num_classes,3,stride=1,padding=1,bias=True), + nn.Sigmoid(), + ) + self.conv_hwPred2 = nn.Sequential( + nn.Conv1d(feat_out,num_classes,3,stride=1,padding=1,bias=True), + nn.Sigmoid(), + ) + self.conv_upDim1 = nn.Sequential( + nn.Conv1d(feat_out,feat_in,kSize,stride=1,padding=kSize//2,bias=True), + nn.Sigmoid(), + ) + self.conv_upDim2 = nn.Sequential( + nn.Conv1d(feat_out,feat_in,kSize,stride=1,padding=kSize//2,bias=True), + nn.Sigmoid(), + ) + self.cmbFea = conv2d( feat_in*3,feat_in,3) + def forward(self,fea): + n,c,h,w = fea.size() + fea_h = self.rowpool(fea).squeeze(3) #n,c,h + fea_w = self.colpool(fea).squeeze(2) #n,c,w + fea_h = self.conv_hgt1(fea_h) #n,c,h + fea_w = self.conv_hgt2(fea_w) + #=========================================================== + fea_hp = self.conv_hwPred1(fea_h) #n,class_num,h + fea_wp = self.conv_hwPred2(fea_w) #n,class_num,w + #=========================================================== + fea_h = self.conv_upDim1(fea_h) + fea_w = self.conv_upDim2(fea_w) + fea_hup = fea_h.unsqueeze(3) + fea_wup = fea_w.unsqueeze(2) + fea_hup = F.interpolate( fea_hup, (h,w), mode='bilinear', align_corners= True ) #n,c,h,w + fea_wup = F.interpolate( fea_wup, (h,w), mode='bilinear', align_corners= True ) #n,c,h,w + fea_hw = self.beta*fea_wup + self.gamma*fea_hup + fea_hw_aug = fea * fea_hw + #=============================================================== + fea = torch.cat([fea, fea_hw_aug, fea_hw], dim = 1 ) + fea = self.cmbFea( fea ) + return fea, fea_hp, fea_wp + +class C2CAttention(nn.Module): + def __init__(self, in_fea, out_fea, num_class ): + super(C2CAttention, self).__init__() + self.in_fea = in_fea + self.out_fea = out_fea + self.num_class = num_class + self.gamma = Parameter(torch.ones(1)) + self.beta = Parameter(torch.ones(1)) + self.bias1 = Parameter( torch.FloatTensor( num_class, num_class )) + self.bias2 = Parameter( torch.FloatTensor( num_class, num_class )) + self.convDwn1 = conv2d( in_fea, out_fea, 1 ) + self.convDwn2 = conv2d( in_fea, out_fea, 1 ) + self.convUp1 = 
nn.Sequential( + nn.AdaptiveAvgPool2d((1,1)), + conv2d( num_class, out_fea, 1 ), + nn.Conv2d(out_fea,in_fea,1,stride=1,padding=0,bias=True), + ) + self.toClass = nn.Sequential( + nn.Conv2d( out_fea, num_class, 1, stride=1, padding = 0, bias = True ), + ) + self.convUp2 = nn.Sequential( + nn.AdaptiveAvgPool2d((1,1)), + conv2d( num_class, out_fea, 1 ), + nn.Conv2d(out_fea,in_fea,1,stride=1,padding=0,bias=True), + ) + self.fea_fuse = conv2d( in_fea*2, in_fea, 1 ) + self.sigmoid = nn.Sigmoid() + self.reset_parameters() + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.bias1) + torch.nn.init.xavier_uniform_(self.bias2) + def forward(self,input_fea): + n, c, h, w = input_fea.size() + fea_ha = self.convDwn1( input_fea ) + fea_wa = self.convDwn2( input_fea ) + cls_ha = self.toClass( fea_ha ) + cls_ha = F.softmax(cls_ha, dim=1) + cls_wa = self.toClass( fea_wa ) + cls_wa = F.softmax(cls_wa, dim=1) + cls_ha = cls_ha.view( n, self.num_class, h*w ) + cls_wa = cls_wa.view( n, self.num_class, h*w ) + cch = F.relu(torch.matmul( cls_ha, cls_ha.transpose( 1, 2 ) )) #class*class + cch = cch + cch = self.sigmoid( cch ) + self.bias1 + ccw = F.relu(torch.matmul( cls_wa, cls_wa. transpose( 1, 2 ) )) #class*class + ccw = ccw + ccw = self.sigmoid( ccw )+ self.bias2 + cls_ha = torch.matmul( cls_ha.transpose(1,2), cch.transpose(1,2) ) + cls_ha = cls_ha.transpose( 1,2).contiguous().view( n, self.num_class, h, w ) + cls_wa = torch.matmul( cls_wa.transpose(1,2), ccw.transpose(1,2) ) + cls_wa = cls_wa.transpose(1,2).contiguous().view( n, self.num_class, h, w ) + fea_ha = self.convUp1( cls_ha ) + fea_wa = self.convUp2( cls_wa ) + fea_hwa = self.gamma*fea_ha + self.beta*fea_wa + fea_hwa_aug = input_fea * fea_hwa #* + fea_fuse = torch.cat( [fea_hwa_aug, input_fea], dim = 1 ) + fea_fuse = self.fea_fuse( fea_fuse ) + return fea_fuse, cch, ccw + +class StatisticAttention(nn.Module): + def __init__(self,fea_in, fea_out, num_classes ): + super(StatisticAttention, self).__init__() + # self.gamma = Parameter(torch.ones(1)) + self.conv_1 = conv2d( fea_in, fea_in//2, 1) #kernel size 3 + self.conv_2 = conv2d( fea_in//2, num_classes, 3 ) + self.conv_pred = nn.Sequential( + nn.Conv2d( num_classes, 1, 3, stride=1, padding=1, bias=True), #kernel size 1 + nn.Sigmoid() + ) + self.conv_fuse = conv2d( fea_in * 2, fea_out, 3 ) + def forward(self,fea): + fea_att = self.conv_1( fea ) + fea_cls = self.conv_2( fea_att ) + fea_stat = self.conv_pred( fea_cls ) + fea_aug = fea * ( 1 - fea_stat ) + fea_fuse = torch.cat( [fea, fea_aug], dim = 1 ) + fea_res = self.conv_fuse( fea_fuse ) + return fea_res, fea_stat + +class PSPModule(nn.Module): + # (1, 2, 3, 6) + def __init__(self, sizes=(1, 3, 7, 11), dimension=2): + super(PSPModule, self).__init__() + self.stages = nn.ModuleList([self._make_stage(size, dimension) for size in sizes]) + + def _make_stage(self, size, dimension=2): + if dimension == 1: + prior = nn.AdaptiveAvgPool1d(output_size=size) + elif dimension == 2: + prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) + elif dimension == 3: + prior = nn.AdaptiveAvgPool3d(output_size=(size, size, size)) + return prior + + def forward(self, feats): + n, c, _, _ = feats.size() + priors = [stage(feats).view(n, c, -1) for stage in self.stages] + center = torch.cat(priors, -1) + return center + +class PCM(Module): + def __init__(self, feat_channels=[256,1024]): + super().__init__() + feat1, feat2 = feat_channels + self.conv_x2 = conv2d( feat1, 256, 1 ) + self.conv_x4 = conv2d( feat2, 256, 1 ) + self.conv_cmb = conv2d( 256+256+3, 256, 
1 ) + self.softmax = Softmax(dim=-1) + self.psp = PSPModule() + self.addCAM = conv2d( 512, 256, 1) + def forward(self, xOrg, stg2, stg4, cam ): + n,c,h,w = stg2.size() + stg2 = self.conv_x2( stg2 ) + stg4 = self.conv_x4( stg4 ) + stg4 = F.interpolate( stg4, (h,w), mode='bilinear', align_corners= True) + stg0 = F.interpolate( xOrg, (h,w), mode='bilinear', align_corners= True) + stgSum = torch.cat([stg0,stg2,stg4],dim=1) + stgSum = self.conv_cmb( stgSum ) + stgPool = self.psp( stgSum ) #(N,c,s) + stgSum = stgSum.view( n, -1, h*w ).transpose(1,2) #(N,h*w,c) + stg_aff = torch.matmul( stgSum, stgPool ) #(N,h*w,c)*(N,c,s)=(N,h*w,s) + stg_aff = ( c ** -0.5 ) * stg_aff + stg_aff = F.softmax( stg_aff, dim = -1 ) #(N,h*w,s) + with torch.no_grad(): + cam_d = F.relu( cam.detach() ) + cam_d = F.interpolate( cam_d, (h,w), mode='bilinear', align_corners= True) + cam_pool = self.psp( cam_d ).transpose(1,2) #(N,s,c) + cam_rv = torch.matmul( stg_aff, cam_pool ).transpose(1,2) + cam_rv=cam_rv.view(n, -1, h, w ) + out = torch.cat([cam, cam_rv], dim=1 ) + out = self.addCAM( out ) + return out + +class GCM(Module): + def __init__(self, feat_channels=512): + super().__init__() + + chHig = feat_channels + self.gamma = Parameter(torch.ones(1)) + self.higC = conv2d( chHig, 256, 3 ) + self.coe = nn.Sequential( + conv2d( 256, 256, 3 ), + nn.AdaptiveAvgPool2d((1,1)) + ) + + def forward(self, fea ): + n,_,h, w = fea.size() + stgHig = self.higC( fea ) + coeHig = self.coe( stgHig ) + sim = stgHig - coeHig + # print( sim.size() ) + simDis = torch.norm( sim, 2, 1, keepdim = True ) + # print( simDis.size() ) + simDimMin = simDis.view( n, -1 ) + simDisMin = torch.min( simDimMin, 1, keepdim = True )[0] + # print( simDisMin.size() ) + simDis = simDis.view( n, -1 ) + weightHig = torch.exp( -( simDis - simDisMin ) / 5 ) + weightHig = weightHig.view(n, -1, h, w ) + upFea = F.interpolate( coeHig, (h,w), mode='bilinear', align_corners=True) + upFea = upFea * weightHig + stgHig = stgHig + self.gamma * upFea + + return weightHig, stgHig + +class LCM(Module): + def __init__(self, feat_channels=[256, 256, 512]): + super().__init__() + + chHig, chLow1, chLow2 = feat_channels + self.beta = Parameter(torch.ones(1)) + self.lowC1 = conv2d( chLow1, 48,3) + self.lowC2 = conv2d( chLow2,128,3) + self.cat1 = conv2d( 256+48, 256, 1 ) + self.cat2 = conv2d( 256+128, 256, 1 ) + + def forward(self, feaHig, feaCeo, feaLow1, feaLow2 ): + n,c,h,w = feaLow1.size() + stgHig = F.interpolate( feaHig, (h,w), mode='bilinear', align_corners=True) + weightLow = F.interpolate( feaCeo, (h,w), mode='bilinear', align_corners=True ) + coeLow = 1 - weightLow + stgLow1 = self.lowC1(feaLow1) + stgLow2 = self.lowC2(feaLow2) + stgLow2 = F.interpolate( stgLow2, (h,w), mode='bilinear', align_corners=True ) + + stgLow1 = self.beta * coeLow * stgLow1 + stgCat = torch.cat( [stgHig, stgLow1], dim = 1 ) + stgCat = self.cat1( stgCat ) + stgLow2 = self.beta * coeLow * stgLow2 + stgCat = torch.cat( [stgCat, stgLow2], dim = 1 ) + stgCat = self.cat2( stgCat ) + return stgCat diff --git a/utils/criterion.py b/utils/criterion.py new file mode 100644 index 0000000..81d682f --- /dev/null +++ b/utils/criterion.py @@ -0,0 +1,114 @@ +import torch.nn as nn +import torch +import numpy as np +import utils.lovasz_losses as L +from torch.nn import functional as F +from torch.nn import Parameter +from .loss import OhemCrossEntropy2d +from dataset.target_generation import generate_edge + + +# class ConsistencyLoss(nn.Module): +# def __init__(self, ignore_index=255): +# super(ConsistencyLoss, 
self).__init__() +# self.ignore_index=ignore_index + +# def forward(self, parsing, edge, label): +# parsing_pre = torch.argmax(parsing, dim=1) +# parsing_pre[label==self.ignore_index]=self.ignore_index +# generated_edge = generate_edge(parsing_pre) +# edge_pre = torch.argmax(edge, dim=1) +# v_generate_edge = generated_edge[label!=255] +# v_edge_pre = edge_pre[label!=255] +# v_edge_pre = v_edge_pre.type(torch.cuda.FloatTensor) +# positive_union = (v_generate_edge==1)&(v_edge_pre==1) # only the positive values count +# return F.smooth_l1_loss(v_generate_edge[positive_union].squeeze(0), v_edge_pre[positive_union].squeeze(0)) + +class CriterionAll(nn.Module): + def __init__(self, ignore_index=255): + super(CriterionAll, self).__init__() + self.ignore_index = ignore_index + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + # self.ConsEdge = ConsistencyLoss(ignore_index=ignore_index) + self.cos_sim = torch.nn.CosineSimilarity(dim=-1) + # self.l2Loss = torch.nn.MSELoss(reduction='mean') + self.l2loss = torch.nn.MSELoss() + + def parsing_loss(self, preds, target, hwgt ): + h, w = target[0].size(1), target[0].size(2) + + pos_num = torch.sum(target[1] == 1, dtype=torch.float) + neg_num = torch.sum(target[1] == 0, dtype=torch.float) + + weight_pos = neg_num / (pos_num + neg_num) + weight_neg = pos_num / (pos_num + neg_num) + weights = torch.tensor([weight_neg, weight_pos]) + loss = 0 + + # loss for parsing + pws = [0.4,1,1,1] + preds_parsing = preds[0] + ind = 0 + tmpLoss = 0 + if isinstance(preds_parsing, list): + for pred_parsing in preds_parsing: + scale_pred = F.interpolate(input=pred_parsing, size=(h, w), + mode='bilinear', align_corners=True) + tmpLoss = self.criterion(scale_pred, target[0]) + scale_pred = F.softmax( scale_pred, dim = 1 ) + tmpLoss += L.lovasz_softmax( scale_pred, target[0], ignore = self.ignore_index ) + tmpLoss *= pws[ind] + loss += tmpLoss + ind+=1 + else: + scale_pred = F.interpolate(input=preds_parsing, size=(h, w), + mode='bilinear', align_corners=True) + loss += self.criterion(scale_pred, target[0]) + # scale_pred = F.softmax( scale_pred, dim = 1 ) + # loss += L.lovasz_softmax( scale_pred, target[0], ignore = self.ignore_index ) + + # loss for edge + tmpLoss = 0 + preds_edge = preds[1] + if isinstance(preds_edge, list): + for pred_edge in preds_edge: + scale_pred = F.interpolate(input=pred_edge, size=(h, w), + mode='bilinear', align_corners=True) + tmpLoss += F.cross_entropy(scale_pred, target[1], + weights.cuda(), ignore_index=self.ignore_index) + else: + scale_pred = F.interpolate(input=preds_edge, size=(h, w), + mode='bilinear', align_corners=True) + tmpLoss += F.cross_entropy(scale_pred, target[1], + weights.cuda(), ignore_index=self.ignore_index) + loss += tmpLoss + # loss for height and width attention + #loss for hwattention + hwLoss = 0 + hgt = hwgt[0] + wgt = hwgt[1] + n,c,h = hgt.size() + w = wgt.size()[2] + hpred = preds[2][0] #fea_h... + scale_hpred = hpred.unsqueeze(3) #n,c,h,1 + scale_hpred = F.interpolate(input=scale_hpred, size=(h,1),mode='bilinear', align_corners=True) + scale_hpred = scale_hpred.squeeze(3) #n,c,h + # hgt = hgt[:,1:,:] + # scale_hpred=scale_hpred[:,1:,:] + hloss = torch.mean( ( hgt - scale_hpred ) * ( hgt - scale_hpred ) ) + wpred = preds[2][1] #fea_w... 
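+ # width branch mirrors the height branch: upsample the per-class width prediction to length w and penalise its squared error against the width ground truth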
+ scale_wpred = wpred.unsqueeze(2) #n,c,1,w + scale_wpred = F.interpolate(input=scale_wpred, size=(1,w),mode='bilinear', align_corners=True) + scale_wpred = scale_wpred.squeeze(2) #n,c,w + # wgt=wgt[:,1:,:] + # scale_wpred = scale_wpred[:,1:,:] + wloss = torch.mean( ( wgt - scale_wpred ) * ( wgt - scale_wpred ) ) + hwLoss = ( hloss + wloss ) * 45 + loss += hwLoss + return loss + + def forward(self, preds, target, hwgt ): + + loss = self.parsing_loss(preds, target, hwgt ) + return loss + \ No newline at end of file diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000..44654d2 --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,165 @@ +import math +import pickle + +import torch +from torch import distributed as dist +from torch.utils.data.sampler import Sampler + + +def get_rank(): + if not dist.is_available(): + return 0 + + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def synchronize(): + if not dist.is_available(): + return + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + + if world_size == 1: + return + + dist.barrier() + + +def get_world_size(): + if not dist.is_available(): + return 1 + + if not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def all_gather(data): + world_size = get_world_size() + + if world_size == 1: + return [data] + + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + local_size = torch.IntTensor([tensor.numel()]).to('cuda') + size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') + tensor = torch.cat((tensor, padding), 0) + + dist.all_gather(tensor_list, tensor) + + data_list = [] + + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in sorted(loss_dict.keys()): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses + + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Code is copy-pasted exactly as in torch.utils.data.distributed. +# FIXME remove this once c10d fixes the bug it has + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
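+ Example (illustrative usage; `dataset` and `num_epochs` are placeholders): + >>> sampler = DistributedSampler(dataset) + >>> loader = torch.utils.data.DataLoader(dataset, batch_size=16, sampler=sampler) + >>> for epoch in range(num_epochs): + ...     sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle each epoch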
+ """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/utils/encode_masks.py b/utils/encode_masks.py new file mode 100644 index 0000000..b033e82 --- /dev/null +++ b/utils/encode_masks.py @@ -0,0 +1,39 @@ +import torch +import torch.nn.functional as fun +import numpy as np +import cv2 +import os +import matplotlib.pyplot as plt + + +def encode_masks(img, colours, num_classes = 22): + + label = np.zeros_like(img) + print(label.shape) + for idx, color in enumerate(colours): + + label[np.sum(img == np.array([[color]]), 2) == 3] = idx + print(label.shape) + onehot = np.eye(num_classes)[label] + return onehot + + +def one_hot(img, colours): + + h, w = img.shape[:2] + x = img.copy() + + + x[np.where()] + + print(x.shape) + return x + +colours = np.array([[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0], [102, 254, 0], [182, 255, 0]]) +print(colours[0]) +img = cv2.imread('/home/vrushank/Downloads/instance-level_human_parsing/Training/Categories/0000010.png') + +label = one_hot(img, colours) +print(label.shape) +plt.imshow(label, cmap = 'gray') +plt.show() diff --git a/utils/encoding.py b/utils/encoding.py new file mode 100644 index 0000000..6e8cb8b --- /dev/null +++ b/utils/encoding.py @@ -0,0 +1,251 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +"""Encoding Data Parallel""" +import threading +import functools +import torch +from torch.autograd import Variable, Function +import torch.cuda.comm as comm +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.parallel_apply import get_a_var +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +torch_ver = 
torch.__version__[:3] + +__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', + 'patch_replication_callback'] + +def allreduce(*inputs): + """Cross GPU all reduce autograd operation for calculate mean and + variance in SyncBN. + """ + return AllReduce.apply(*inputs) + +class AllReduce(Function): + @staticmethod + def forward(ctx, num_inputs, *inputs): + ctx.num_inputs = num_inputs + ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] + inputs = [inputs[i:i + num_inputs] + for i in range(0, len(inputs), num_inputs)] + # sort before reduce sum + inputs = sorted(inputs, key=lambda i: i[0].get_device()) + results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) + outputs = comm.broadcast_coalesced(results, ctx.target_gpus) + return tuple([t for tensors in outputs for t in tensors]) + + @staticmethod + def backward(ctx, *inputs): + inputs = [i.data for i in inputs] + inputs = [inputs[i:i + ctx.num_inputs] + for i in range(0, len(inputs), ctx.num_inputs)] + results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) + outputs = comm.broadcast_coalesced(results, ctx.target_gpus) + return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) + + +class Reduce(Function): + @staticmethod + def forward(ctx, *inputs): + ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] + inputs = sorted(inputs, key=lambda i: i.get_device()) + return comm.reduce_add(inputs) + + @staticmethod + def backward(ctx, gradOutput): + return Broadcast.apply(ctx.target_gpus, gradOutput) + + +class DataParallelModel(DataParallel): + """Implements data parallelism at the module level. + + This container parallelizes the application of the given module by + splitting the input across the specified devices by chunking in the + batch dimension. + In the forward pass, the module is replicated on each device, + and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. + Note that the outputs are not gathered, please use compatible + :class:`encoding.parallel.DataParallelCriterion`. + + The batch size should be larger than the number of GPUs used. It should + also be an integer multiple of the number of GPUs so that each chunk is + the same size (so that each GPU processes the same number of samples). + + Args: + module: module to be parallelized + device_ids: CUDA devices (default: all devices) + + Reference: + Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, + Amit Agrawal. “Context Encoding for Semantic Segmentation. + *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* + + Example:: + + >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) + >>> y = net(x) + """ + def gather(self, outputs, output_device): + return outputs + + def replicate(self, module, device_ids): + modules = super(DataParallelModel, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +class DataParallelCriterion(DataParallel): + """ + Calculate loss in multiple-GPUs, which balance the memory usage for + Semantic Segmentation. + + The targets are splitted across the specified devices by chunking in + the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. + + Reference: + Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, + Amit Agrawal. “Context Encoding for Semantic Segmentation. 
+ *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* + + Example:: + + >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) + >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) + >>> y = net(x) + >>> loss = criterion(y, target) + """ + def forward(self, inputs, *targets, **kwargs): + # input should be already scatterd + # scattering the targets instead + if not self.device_ids: + return self.module(inputs, *targets, **kwargs) + targets, kwargs = self.scatter(targets, kwargs, self.device_ids) + if len(self.device_ids) == 1: + return self.module(inputs, *targets[0], **kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) + return Reduce.apply(*outputs) / len(outputs) + #return self.gather(outputs, self.output_device).mean() + + +def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): + assert len(modules) == len(inputs) + assert len(targets) == len(inputs) + if kwargs_tup: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + + lock = threading.Lock() + results = {} + if torch_ver != "0.3": + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, target, kwargs, device=None): + if torch_ver != "0.3": + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + if not isinstance(input, tuple): + input = (input,) + with torch.cuda.device(device): + output = module(*(input + target), **kwargs) + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, target, + kwargs, device),) + for i, (module, input, target, kwargs, device) in + enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs + + +########################################################################### +# Adapted from Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created + by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead + of calling the callback of any slave copies. 
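+ An illustrative callback on a custom module (names are placeholders): + def __data_parallel_replicate__(self, ctx, copy_id): + self._replica_ctx, self._copy_id = ctx, copy_id  # replicas can share state through ctx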
+ """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..6113a37 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,94 @@ +import os +import sys +import logging + +# from . import pyt_utils +# from utils.pyt_utils import ensure_dir + +_default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') +_default_level = logging.getLevelName(_default_level_name.upper()) + + +class LogFormatter(logging.Formatter): + log_fout = None + date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' + date = '%(asctime)s ' + msg = '%(message)s' + + def format(self, record): + if record.levelno == logging.DEBUG: + mcl, mtxt = self._color_dbg, 'DBG' + elif record.levelno == logging.WARNING: + mcl, mtxt = self._color_warn, 'WRN' + elif record.levelno == logging.ERROR: + mcl, mtxt = self._color_err, 'ERR' + else: + mcl, mtxt = self._color_normal, '' + + if mtxt: + mtxt += ' ' + + if self.log_fout: + self.__set_fmt(self.date_full + mtxt + self.msg) + formatted = super(LogFormatter, self).format(record) + # self.log_fout.write(formatted) + # self.log_fout.write('\n') + # self.log_fout.flush() + return formatted + + self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) + formatted = super(LogFormatter, self).format(record) + + return formatted + + if sys.version_info.major < 3: + def __set_fmt(self, fmt): + self._fmt = fmt + else: + def __set_fmt(self, fmt): + self._style._fmt = fmt + + @staticmethod + def _color_dbg(msg): + return '\x1b[36m{}\x1b[0m'.format(msg) + + @staticmethod + def _color_warn(msg): + return '\x1b[1;31m{}\x1b[0m'.format(msg) + + @staticmethod + def _color_err(msg): + return '\x1b[1;4;31m{}\x1b[0m'.format(msg) + + @staticmethod + def _color_omitted(msg): + return '\x1b[35m{}\x1b[0m'.format(msg) + + @staticmethod + def _color_normal(msg): + return msg + + @staticmethod + def _color_date(msg): + return '\x1b[32m{}\x1b[0m'.format(msg) + + +def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): + logger = logging.getLogger() + logger.setLevel(_default_level) + del logger.handlers[:] + + if log_dir and log_file: + if not os.path.isdir(log_dir): + os.makedirs(log_dir) + LogFormatter.log_fout = True + file_handler = logging.FileHandler(log_file, mode='a') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + 
+ stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) + stream_handler.setLevel(0) + logger.addHandler(stream_handler) + return logger diff --git a/utils/loss.py b/utils/loss.py new file mode 100644 index 0000000..09456e0 --- /dev/null +++ b/utils/loss.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.autograd import Variable +import numpy as np +import scipy.ndimage as nd + + +class OhemCrossEntropy2d(nn.Module): + + def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8): + super(OhemCrossEntropy2d, self).__init__() + self.ignore_label = ignore_label + self.thresh = float(thresh) + # self.min_kept_ratio = float(min_kept_ratio) + self.min_kept = int(min_kept) + self.factor = factor + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label) + + def find_threshold(self, np_predict, np_target): + # downsample 1/8 + factor = self.factor + predict = nd.zoom(np_predict, (1.0, 1.0, 1.0/factor, 1.0/factor), order=1) + target = nd.zoom(np_target, (1.0, 1.0/factor, 1.0/factor), order=0) + + n, c, h, w = predict.shape + min_kept = self.min_kept // (factor*factor) #int(self.min_kept_ratio * n * h * w) + + input_label = target.ravel().astype(np.int32) + input_prob = np.rollaxis(predict, 1).reshape((c, -1)) + + valid_flag = input_label != self.ignore_label + valid_inds = np.where(valid_flag)[0] + label = input_label[valid_flag] + num_valid = valid_flag.sum() + if min_kept >= num_valid: + threshold = 1.0 + elif num_valid > 0: + prob = input_prob[:,valid_flag] + pred = prob[label, np.arange(len(label), dtype=np.int32)] + threshold = self.thresh + if min_kept > 0: + k_th = min(len(pred), min_kept)-1 + new_array = np.partition(pred, k_th) + new_threshold = new_array[k_th] + if new_threshold > self.thresh: + threshold = new_threshold + return threshold + + + def generate_new_target(self, predict, target): + np_predict = predict.data.cpu().numpy() + np_target = target.data.cpu().numpy() + n, c, h, w = np_predict.shape + + threshold = self.find_threshold(np_predict, np_target) + + input_label = np_target.ravel().astype(np.int32) + input_prob = np.rollaxis(np_predict, 1).reshape((c, -1)) + + valid_flag = input_label != self.ignore_label + valid_inds = np.where(valid_flag)[0] + label = input_label[valid_flag] + num_valid = valid_flag.sum() + + if num_valid > 0: + prob = input_prob[:,valid_flag] + pred = prob[label, np.arange(len(label), dtype=np.int32)] + kept_flag = pred <= threshold + valid_inds = valid_inds[kept_flag] + print('Labels: {} {}'.format(len(valid_inds), threshold)) + + label = input_label[valid_inds].copy() + input_label.fill(self.ignore_label) + input_label[valid_inds] = label + new_target = torch.from_numpy(input_label.reshape(target.size())).long().cuda(target.get_device()) + + return new_target + + + def forward(self, predict, target, weight=None): + """ + Args: + predict:(n, c, h, w) + target:(n, h, w) + weight (Tensor, optional): a manual rescaling weight given to each class. 
+ If given, has to be a Tensor of size "nclasses" + """ + assert not target.requires_grad + + input_prob = F.softmax(predict, 1) + target = self.generate_new_target(input_prob, target) + return self.criterion(predict, target) diff --git a/utils/lovasz_losses.py b/utils/lovasz_losses.py new file mode 100644 index 0000000..a3f23a5 --- /dev/null +++ b/utils/lovasz_losses.py @@ -0,0 +1,250 @@ +""" +Lovasz-Softmax and Jaccard hinge loss in PyTorch +Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) +""" + +from __future__ import print_function, division + +import torch +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse as ifilterfalse + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): + """ + IoU for foreground class + binary: 1 foreground, 0 background + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + intersection = ((label == 1) & (pred == 1)).sum() + union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() + if not union: + iou = EMPTY + else: + iou = float(intersection) / float(union) + ious.append(iou) + iou = mean(ious) # mean accross images if per_image + return 100 * iou + + +def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): + """ + Array of IoU for each (non ignored) class + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + iou = [] + for i in range(C): + if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) + intersection = ((label == i) & (pred == i)).sum() + union = ((label == i) | ((pred == i) & (label != ignore))).sum() + if not union: + iou.append(EMPTY) + else: + iou.append(float(intersection) / float(union)) + ious.append(iou) + ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image + return 100 * np.array(ious) + + +# --------------------------- BINARY LOSSES --------------------------- + + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. 
- logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +class StableBCELoss(torch.nn.modules.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + def forward(self, input, target): + neg_abs = - input.abs() + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + + +def binary_xloss(logits, labels, ignore=None): + """ + Binary Cross entropy loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + ignore: void class id + """ + logits, labels = flatten_binary_scores(logits, labels, ignore) + loss = StableBCELoss()(logits, Variable(labels.float())) + return loss + + +# --------------------------- MULTICLASS LOSSES --------------------------- + + +def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) + return loss + + +def lovasz_softmax_flat(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. 
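+ # for each selected class: sort the absolute errors in descending order, dot them with the Lovasz gradient of the identically sorted ground-truth mask, then average over classes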
+ C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes is 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (Variable(fg) - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def flatten_probas(probas, labels, ignore=None): + """ + Flattens predictions in the batch + """ + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + +def xloss(logits, labels, ignore=None): + """ + Cross entropy loss + """ + return F.cross_entropy(logits, Variable(labels), ignore_index=255) + + +# --------------------------- HELPER FUNCTIONS --------------------------- +def isnan(x): + return x != x + + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n diff --git a/utils/miou.py b/utils/miou.py new file mode 100644 index 0000000..286839c --- /dev/null +++ b/utils/miou.py @@ -0,0 +1,204 @@ +import numpy as np +import cv2 +import os +import json +from collections import OrderedDict +import argparse +from PIL import Image as PILImage +from utils.transforms import transform_parsing + +LABELS = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'facemask', 'coat', 'socks', 'pants', 'torso-skin', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe', 'bag', 'others'] + +def get_palette(num_cls): + """ Returns the color map for visualizing the segmentation mask. 
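+ Colours are generated with the PASCAL-VOC-style bit-shift scheme, so the palette is deterministic for any number of classes.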
+ Args: + num_cls: Number of classes + Returns: + The color map + """ + + n = num_cls + palette = [0] * (n * 3) + for j in range(0, n): + lab = j + palette[j * 3 + 0] = 0 + palette[j * 3 + 1] = 0 + palette[j * 3 + 2] = 0 + i = 0 + while lab: + palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) + palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) + palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) + i += 1 + lab >>= 3 + return palette + +def get_confusion_matrix(gt_label, pred_label, num_classes): + """ + Calcute the confusion matrix by given label and pred + :param gt_label: the ground truth label + :param pred_label: the pred label + :param num_classes: the nunber of class + :return: the confusion matrix + """ + index = (gt_label * num_classes + pred_label).astype('int32') + label_count = np.bincount(index) + confusion_matrix = np.zeros((num_classes, num_classes)) + + for i_label in range(num_classes): + for i_pred_label in range(num_classes): + cur_index = i_label * num_classes + i_pred_label + if cur_index < len(label_count): + confusion_matrix[i_label, i_pred_label] = label_count[cur_index] + + return confusion_matrix + + +def compute_mean_ioU(preds, scales, centers, num_classes, datadir, input_size=[473, 473], dataset='val'): + list_path = os.path.join(datadir, dataset + '_id.txt') + val_id = [i_id.strip() for i_id in open(list_path)] + + confusion_matrix = np.zeros((num_classes, num_classes)) + + for i, im_name in enumerate(val_id): + gt_path = os.path.join(datadir, dataset + '_segmentations', im_name + '.png') + gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) + h, w = gt.shape + pred_out = preds[i] + s = scales[i] + c = centers[i] + pred = transform_parsing(pred_out, c, s, w, h, input_size) + + gt = np.asarray(gt, dtype=np.int32) + pred = np.asarray(pred, dtype=np.int32) + + ignore_index = gt != 255 + + gt = gt[ignore_index] + pred = pred[ignore_index] + + confusion_matrix += get_confusion_matrix(gt, pred, num_classes) + + pos = confusion_matrix.sum(1) + res = confusion_matrix.sum(0) + tp = np.diag(confusion_matrix) + + pixel_accuracy = (tp.sum() / pos.sum()) * 100 + mean_accuracy = ((tp / np.maximum(1.0, pos)).mean()) * 100 + IoU_array = (tp / np.maximum(1.0, pos + res - tp)) + IoU_array = IoU_array * 100 + mean_IoU = IoU_array.mean() + print('Pixel accuracy: %f \n' % pixel_accuracy) + print('Mean accuracy: %f \n' % mean_accuracy) + print('Mean IoU: %f \n' % mean_IoU) + name_value = [] + + for i, (label, iou) in enumerate(zip(LABELS, IoU_array)): + name_value.append((label, iou)) + + name_value.append(('Pixel accuracy', pixel_accuracy)) + name_value.append(('Mean accuracy', mean_accuracy)) + name_value.append(('Mean IU', mean_IoU)) + name_value = OrderedDict(name_value) + return name_value, pixel_accuracy, mean_accuracy + +def compute_mean_ioU_file(preds_dir, num_classes, datadir, dataset='val'): + list_path = os.path.join(datadir, dataset + '_id.txt') + val_id = [i_id.strip() for i_id in open(list_path)] + + confusion_matrix = np.zeros((num_classes, num_classes)) + + for i, im_name in enumerate(val_id): + gt_path = os.path.join(datadir, dataset + '_segmentations', im_name + '.png') + gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) + + pred_path = os.path.join(preds_dir, im_name + '.png') + pred = np.asarray(PILImage.open(pred_path)) + + gt = np.asarray(gt, dtype=np.int32) + pred = np.asarray(pred, dtype=np.int32) + + ignore_index = gt != 255 + + gt = gt[ignore_index] + pred = pred[ignore_index] + + confusion_matrix += get_confusion_matrix(gt, pred, num_classes) + + pos = 
confusion_matrix.sum(1) + res = confusion_matrix.sum(0) + tp = np.diag(confusion_matrix) + + pixel_accuracy = (tp.sum() / pos.sum())*100 + mean_accuracy = ((tp / np.maximum(1.0, pos)).mean())*100 + IoU_array = (tp / np.maximum(1.0, pos + res - tp)) + IoU_array = IoU_array*100 + mean_IoU = IoU_array.mean() + print('Pixel accuracy: %f \n' % pixel_accuracy) + print('Mean accuracy: %f \n' % mean_accuracy) + print('Mean IU: %f \n' % mean_IoU) + name_value = [] + + for i, (label, iou) in enumerate(zip(LABELS, IoU_array)): + name_value.append((label, iou)) + + name_value.append(('Pixel accuracy', pixel_accuracy)) + name_value.append(('Mean accuracy', mean_accuracy)) + name_value.append(('Mean IU', mean_IoU)) + name_value = OrderedDict(name_value) + return name_value + +def write_results(preds, scales, centers, datadir, dataset, result_dir, input_size=[473, 473]): + palette = get_palette(20) + if not os.path.exists(result_dir): + os.makedirs(result_dir) + + json_file = os.path.join(datadir, 'annotations', dataset + '.json') + with open(json_file) as data_file: + data_list = json.load(data_file) + data_list = data_list['root'] + for item, pred_out, s, c in zip(data_list, preds, scales, centers): + im_name = item['im_name'] + w = item['img_width'] + h = item['img_height'] + pred = transform_parsing(pred_out, c, s, w, h, input_size) + #pred = pred_out + save_path = os.path.join(result_dir, im_name[:-4]+'.png') + + output_im = PILImage.fromarray(np.asarray(pred, dtype=np.uint8)) + output_im.putpalette(palette) + output_im.save(save_path) + +def get_arguments(): + """Parse all the arguments provided from the CLI. + + Returns: + A list of parsed arguments. + """ + parser = argparse.ArgumentParser(description="DeepLabLFOV NetworkEv") + parser.add_argument("--pred-path", type=str, default='', + help="Path to predicted segmentation.") + parser.add_argument("--gt-path", type=str, default='', + help="Path to the groundtruth dir.") + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_arguments() + palette = get_palette(20) + # im_path = '/ssd1/liuting14/Dataset/LIP/val_segmentations/100034_483681.png' + # #compute_mean_ioU_file(args.pred_path, 20, args.gt_path, 'val') + # im = cv2.imread(im_path, cv2.IMREAD_GRAYSCALE) + # print(im.shape) + # test = np.asarray( PILImage.open(im_path)) + # print(test.shape) + # if im.all()!=test.all(): + # print('different') + # output_im = PILImage.fromarray(np.zeros((100,100), dtype=np.uint8)) + # output_im.putpalette(palette) + # output_im.save('test.png') + pred_dir = '/ssd1/liuting14/exps/lip/snapshots/results/epoch4/' + num_classes = 20 + datadir = '/ssd1/liuting14/Dataset/LIP/' + compute_mean_ioU_file(pred_dir, num_classes, datadir, dataset='val') \ No newline at end of file diff --git a/utils/pyt_utils.py b/utils/pyt_utils.py new file mode 100644 index 0000000..ecb239a --- /dev/null +++ b/utils/pyt_utils.py @@ -0,0 +1,217 @@ +# encoding: utf-8 +import os +import sys +import time +import argparse +from collections import OrderedDict, defaultdict + +import torch +import torch.utils.model_zoo as model_zoo +import torch.distributed as dist + +from .logger import get_logger + +logger = get_logger() + +# colour map +label_colours = [(0,0,0) + # 0=background + ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) + # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle + ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) + # 6=bus, 7=car, 8=cat, 9=chair, 10=cow + ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) + # 11=diningtable, 
12=dog, 13=horse, 14=motorbike, 15=person + ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] + # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor + + +def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1): + tensor = tensor.clone() + dist.reduce(tensor, dst, op) + if dist.get_rank() == dst: + tensor.div_(world_size) + + return tensor + + +def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1, norm=True): + tensor = tensor.clone() + dist.all_reduce(tensor, op) + if norm: + tensor.div_(world_size) + + return tensor + + +def load_model(model, model_file, is_restore=False): + t_start = time.time() + if isinstance(model_file, str): + device = torch.device('cpu') + state_dict = torch.load(model_file, map_location=device) + if 'model' in state_dict.keys(): + state_dict = state_dict['model'] + else: + state_dict = model_file + t_ioend = time.time() + + if is_restore: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = 'module.' + k + new_state_dict[name] = v + state_dict = new_state_dict + + model.load_state_dict(state_dict, strict=False) + ckpt_keys = set(state_dict.keys()) + own_keys = set(model.state_dict().keys()) + missing_keys = own_keys - ckpt_keys + unexpected_keys = ckpt_keys - own_keys + + if len(missing_keys) > 0: + logger.warning('Missing key(s) in state_dict: {}'.format( + ', '.join('{}'.format(k) for k in missing_keys))) + + if len(unexpected_keys) > 0: + logger.warning('Unexpected key(s) in state_dict: {}'.format( + ', '.join('{}'.format(k) for k in unexpected_keys))) + + del state_dict + t_end = time.time() + logger.info( + "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( + t_ioend - t_start, t_end - t_ioend)) + + return model + + +def parse_devices(input_devices): + if input_devices.endswith('*'): + devices = list(range(torch.cuda.device_count())) + return devices + + devices = [] + for d in input_devices.split(','): + if '-' in d: + start_device, end_device = d.split('-')[0], d.split('-')[1] + assert start_device != '' + assert end_device != '' + start_device, end_device = int(start_device), int(end_device) + assert start_device < end_device + assert end_device < torch.cuda.device_count() + for sd in range(start_device, end_device + 1): + devices.append(sd) + else: + device = int(d) + assert device < torch.cuda.device_count() + devices.append(device) + + logger.info('using devices {}'.format( + ', '.join([str(d) for d in devices]))) + + return devices + + +def extant_file(x): + """ + 'Type' for argparse - checks that file exists but does not open. + """ + if not os.path.exists(x): + # Argparse uses the ArgumentTypeError to give a rejection message like: + # error: argument input: x does not exist + raise argparse.ArgumentTypeError("{0} does not exist".format(x)) + return x + + +def link_file(src, target): + if os.path.isdir(target) or os.path.isfile(target): + os.remove(target) + os.system('ln -s {} {}'.format(src, target)) + + +def ensure_dir(path): + if not os.path.isdir(path): + os.makedirs(path) + + +def _dbg_interactive(var, value): + from IPython import embed + embed() + +def decode_labels(mask, num_images=1, num_classes=21): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). + + Returns: + A batch with num_images RGB images of the same size as the input. 
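+ Note: relies on PIL's Image and numpy (np), which are not among this module's imports.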
+ """ + mask = mask.data.cpu().numpy() + n, h, w = mask.shape + assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) + outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) + for i in range(num_images): + img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) + pixels = img.load() + for j_, j in enumerate(mask[i, :, :]): + for k_, k in enumerate(j): + if k < num_classes: + pixels[k_,j_] = label_colours[k] + outputs[i] = np.array(img) + return outputs + +def decode_predictions(preds, num_images=1, num_classes=21): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). + + Returns: + A batch with num_images RGB images of the same size as the input. + """ + if isinstance(preds, list): + preds_list = [] + for pred in preds: + preds_list.append(pred[-1].data.cpu().numpy()) + preds = np.concatenate(preds_list, axis=0) + else: + preds = preds.data.cpu().numpy() + + preds = np.argmax(preds, axis=1) + n, h, w = preds.shape + assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) + outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) + for i in range(num_images): + img = Image.new('RGB', (len(preds[i, 0]), len(preds[i]))) + pixels = img.load() + for j_, j in enumerate(preds[i, :, :]): + for k_, k in enumerate(j): + if k < num_classes: + pixels[k_,j_] = label_colours[k] + outputs[i] = np.array(img) + return outputs + +def inv_preprocess(imgs, num_images, img_mean): + """Inverse preprocessing of the batch of images. + Add the mean vector and convert from BGR to RGB. + + Args: + imgs: batch of input images. + num_images: number of images to apply the inverse transformations on. + img_mean: vector of mean colour values. + + Returns: + The batch of the size num_images with the same spatial dimensions as the input. + """ + imgs = imgs.data.cpu().numpy() + n, c, h, w = imgs.shape + assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) + outputs = np.zeros((num_images, h, w, c), dtype=np.uint8) + for i in range(num_images): + outputs[i] = (np.transpose(imgs[i], (1,2,0)) + img_mean).astype(np.uint8) + return outputs diff --git a/utils/transforms.py b/utils/transforms.py new file mode 100644 index 0000000..da53eee --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 + + +def flip_back(output_flipped, matched_parts): + ''' + ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width) + ''' + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped + + +def transform_parsing(pred, center, scale, width, height, input_size): + + trans = get_affine_transform(center, scale, 0, input_size, inv=1) + target_pred = cv2.warpAffine( + pred, + trans, + (int(width), int(height)), #(int(width), int(height)), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0)) + + return target_pred + + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + print(scale) + scale = np.array([scale, scale]) + + scale_tmp = scale + + src_w = scale_tmp[0] + dst_w = output_size[1] + dst_h = output_size[0] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + + +def crop(img, center, scale, output_size, rot=0): + trans = get_affine_transform(center, scale, rot, output_size) + + dst_img = cv2.warpAffine(img, + trans, + (int(output_size[1]), int(output_size[0])), + flags=cv2.INTER_LINEAR) + + return dst_img diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..0a709ea --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,83 @@ +from PIL import Image +import numpy as np +import torchvision +import torch + +# colour map +COLORS = [[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0], [102, 254, 0], [182, 255, 0]] + # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor + + +def decode_parsing(labels, 
num_images=1, num_classes=22, is_pred=False): + """Decode batch of segmentation masks. + + Args: + mask: result of inference after taking argmax. + num_images: number of images to decode from the batch. + num_classes: number of classes to predict (including background). + + Returns: + A batch with num_images RGB images of the same size as the input. + """ + pred_labels = labels[:num_images].clone().cpu().data + if is_pred: + pred_labels = torch.argmax(pred_labels, dim=1) + n, h, w = pred_labels.size() + + labels_color = torch.zeros([n, 3, h, w], dtype=torch.uint8) + for i, c in enumerate(COLORS): + c0 = labels_color[:, 0, :, :] + c1 = labels_color[:, 1, :, :] + c2 = labels_color[:, 2, :, :] + + c0[pred_labels == i] = c[0] + c1[pred_labels == i] = c[1] + c2[pred_labels == i] = c[2] + + return labels_color + +def inv_preprocess(imgs, num_images): + """Inverse preprocessing of the batch of images. + Add the mean vector and convert from BGR to RGB. + + Args: + imgs: batch of input images. + num_images: number of images to apply the inverse transformations on. + img_mean: vector of mean colour values. + + Returns: + The batch of the size num_images with the same spatial dimensions as the input. + """ + rev_imgs = imgs[:num_images].clone().cpu().data + rev_normalize = NormalizeInverse(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + for i in range(num_images): + rev_imgs[i] = rev_normalize(rev_imgs[i]) + + return rev_imgs + +class NormalizeInverse(torchvision.transforms.Normalize): + """ + Undoes the normalization and returns the reconstructed images in the input domain. + """ + + def __init__(self, mean, std): + mean = torch.as_tensor(mean) + std = torch.as_tensor(std) + std_inv = 1 / (std + 1e-7) + mean_inv = -mean * std_inv + super().__init__(mean=mean_inv, std=std_inv) + + +class AverageMeter: + def __init__(self, name=None): + self.name = name + self.reset() + + def reset(self): + self.sum = self.count = self.avg = 0 + + def update(self, val, n=1): + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/utils/writejson.py b/utils/writejson.py new file mode 100644 index 0000000..69aca6b --- /dev/null +++ b/utils/writejson.py @@ -0,0 +1,21 @@ +import json +import os +import cv2 + +json_file = os.path.join('/ssd1/liuting14/Dataset/LIP', 'annotations', 'test.json') + +with open(json_file) as data_file: + data_json = json.load(data_file) + data_list = data_json['root'] + +for item in data_list: + name = item['im_name'] + im_path = os.path.join('/ssd1/liuting14/Dataset/LIP', 'test_images', name) + im = cv2.imread(im_path, cv2.IMREAD_COLOR) + h, w, c = im.shape + item['img_height'] = h + item['img_width'] = w + item['center'] = [h/2, w/2] + +with open(json_file, "w") as f: + json.dump(data_json, f, indent=2)