Commit 94b0ade (0 parents): 97 changed files with 4,691 additions and 0 deletions.
.gitignore:
sagemaker_job.py
export_model.py
secrets.env
wandb/*
__pycache__/*
README.md:
#### Implementation

#### Dataset

- This model is trained on the CCIHP dataset, which contains 22 class labels.

Please download the ImageNet-pretrained ResNet-101 from [baidu drive](https://pan.baidu.com/s/1NoxI_JetjSVa7uqgVSKdPw) or [Google drive](https://drive.google.com/open?id=1rzLU-wK6rEorCNJfwrmIu5hY2wRMyKTK), and put it into the dataset folder.
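To confirm the download is usable before training, a quick check like the following can help (a sketch for this writeup; the checkpoint filename is an assumption, so adjust it to whatever the downloaded file is called):

```python
import torch

# Assumed save location inside the dataset folder; rename to match your file.
ckpt_path = 'dataset/resnet101-imagenet.pth'
# Assuming the file is a plain state dict of backbone weights.
state_dict = torch.load(ckpt_path, map_location='cpu')
print(len(state_dict), 'tensors; first keys:', list(state_dict)[:3])
```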
#### Training

- Set the necessary arguments and run `train_simplified.py`; a sketch of a possible invocation follows.
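The flag names below are hypothetical and only illustrate the kind of arguments to set (data root, class count, batch size); check the argument parser in `train_simplified.py` for the real ones.

```bash
# Hypothetical flags -- consult train_simplified.py's argument parser.
python train_simplified.py \
    --data-dir /path/to/CCIHP \
    --num-classes 22 \
    --batch-size 8
```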
Citation:

    @InProceedings{Liu_2022_CVPR,
        author    = {Liu, Kunliang and Choi, Ouk and Wang, Jianming and Hwang, Wonjun},
        title     = {CDGNet: Class Distribution Guided Network for Human Parsing},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2022},
        pages     = {4473-4482}
    }
Requirements

PyTorch 1.9.0
torchvision 0.11.0
scipy 1.5.2
cudatoolkit 11.3.1
tensorboardX 2.2
Python 3.7
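One possible way to reproduce this environment (a sketch assuming conda with the pytorch channel; the pins are copied from the list above and may need adjusting, since torchvision 0.11.0 officially pairs with PyTorch 1.10 rather than 1.9.0):

```bash
conda create -n cdgnet python=3.7
conda activate cdgnet
# cudatoolkit 11.3 matches the listed 11.3.1; torchvision left unpinned
# because the pins above may conflict -- pick the pairing your setup needs.
conda install pytorch=1.9.0 torchvision cudatoolkit=11.3 -c pytorch
pip install scipy==1.5.2 tensorboardX==2.2
```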
(The commit also adds a number of binary files and one empty file; their contents are not shown.)
dataset/datasets.py:
import os
import numpy as np
import random
import torch
import cv2
import json
import sys
sys.path.insert(0, '.')
from torch.utils import data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from dataset.target_generation import generate_edge, generate_hw_gt
from utils.transforms import get_affine_transform
from utils.ImgTransforms import AugmentationBlock, autoaug_imagenet_policies
from utils.utils import decode_parsing


# Per-class sample statistics (unused; kept from the original):
# statisticSeg = [30462, 7026, 21054, 2404, 1660, 23165, 1201, 8182, 2178, 16224,
#                 455, 518, 634, 24418, 18539, 20033, 4763, 4832, 8126, 8166]
class LIPDataSet(data.Dataset):
    def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
                 rotation_factor=30, ignore_label=255, transform=None):
        self.root = root
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
        self.crop_size = np.asarray(crop_size)
        self.ignore_label = ignore_label
        self.scale_factor = scale_factor
        self.rotation_factor = rotation_factor
        self.flip_prob = 0.5
        self.flip_pairs = [[0, 5], [1, 4], [2, 3], [11, 14], [12, 13], [10, 15]]
        self.transform = transform
        self.dataset = dataset
        # self.statSeg = np.array(statisticSeg, dtype='float')
        # self.statSeg = self.statSeg / 30462

        list_path = os.path.join(self.root, self.dataset + '_id.txt')
        self.im_list = [i_id.strip() for i_id in open(list_path)]
        # Optional 1-in-5 subsampling of the image list (disabled):
        # if dataset != 'val':
        #     self.im_list = [im for i, im in enumerate(self.im_list) if i % 5 == 0]
        self.number_samples = len(self.im_list)
        # Auto-augment policy block, applied to training images only
        self.augBlock = AugmentationBlock(autoaug_imagenet_policies)

    def __len__(self):
        return self.number_samples
    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Convert an xywh box into a (center, scale) pair, growing the box
        # so that it matches the crop aspect ratio.
        center = np.zeros((2,), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
        return center, scale
    def __getitem__(self, index):
        # Load training image
        im_name = self.im_list[index]
        im_path = os.path.join(self.root, self.dataset + '_images', im_name + '.jpg')
        parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', im_name + '.png')

        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        if self.dataset != 'val':
            im = self.augBlock(im)  # auto-augment on training images only
        h, w, _ = im.shape
        # np.long was removed from NumPy; int64 matches the original intent.
        parsing_anno = np.zeros((h, w), dtype=np.int64)

        # Get center and scale
        center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0

        if self.dataset != 'test':
            parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)

            if self.dataset == 'train' or self.dataset == 'trainval':
                # Random scale, and random rotation 60% of the time
                sf = self.scale_factor
                rf = self.rotation_factor
                s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
                r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
                    if random.random() <= 0.6 else 0

                if random.random() <= self.flip_prob:
                    im = im[:, ::-1, :]
                    parsing_anno = parsing_anno[:, ::-1]
                    center[0] = im.shape[1] - center[0] - 1
                    # Swap left/right part labels after a horizontal flip
                    right_idx = [15, 17, 19]
                    left_idx = [14, 16, 18]
                    for i in range(0, 3):
                        right_pos = np.where(parsing_anno == right_idx[i])
                        left_pos = np.where(parsing_anno == left_idx[i])
                        parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
                        parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]

        trans = get_affine_transform(center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        if self.transform:
            input = self.transform(input)

        meta = {
            'name': im_name,
            'center': center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        if self.dataset != 'train':
            return input, meta
        else:
            label_parsing = cv2.warpAffine(
                parsing_anno,
                trans,
                (int(self.crop_size[1]), int(self.crop_size[0])),
                flags=cv2.INTER_NEAREST,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=255)
            # label_edge = generate_edge(label_parsing)
            hgt, wgt, hwgt = generate_hw_gt(label_parsing)
            label_parsing = torch.from_numpy(label_parsing)
            # label_edge = torch.from_numpy(label_edge)

            return input, label_parsing, hgt, wgt, hwgt, meta
class LIPDataValSet(data.Dataset):
    def __init__(self, root, dataset='val', crop_size=[512, 512], transform=None, flip=False):
        self.root = root
        self.transform = transform
        self.flip = flip
        self.dataset = dataset
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
        self.crop_size = np.asarray(crop_size)

        list_path = os.path.join(self.root, self.dataset + '_id.txt')
        self.val_list = [i_id.strip() for i_id in open(list_path)]
        self.number_samples = len(self.val_list)

    def __len__(self):
        return len(self.val_list)

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Same center/scale conversion as LIPDataSet._xywh2cs
        center = np.zeros((2,), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
        return center, scale
    def __getitem__(self, index):
        val_item = self.val_list[index]
        # Load validation image
        im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w, _ = im.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))
        input = self.transform(input)
        if self.flip:
            # Stack the image with its horizontal flip for test-time augmentation
            flip_input = input.flip(dims=[-1])
            batch_input_im = torch.stack([input, flip_input])
        else:
            batch_input_im = input

        meta = {
            'name': val_item,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return batch_input_im, meta
'''
# Quick visual sanity check. Note: this needs LIPDataSet (the loop unpacks
# the six training-mode outputs); LIPDataValSet only returns (image, meta).
root = '/home/vrushank/Spyne/CCIHP'
dataset = 'train'
data1 = LIPDataSet(root, dataset, crop_size=[512, 512])
loader = DataLoader(data1, batch_size=1, shuffle=True)
for idx, (input, label_parsing, hgt, wgt, hwgt, meta) in enumerate(loader):
    if idx == 0:
        print(input.shape)
        print(label_parsing.shape)
        ip = input.squeeze(0).cpu().numpy()
        label = decode_parsing(label_parsing, num_classes=22)
        print(type(label))
        label = label[0].data.cpu().numpy()
        label = label.transpose((1, 2, 0))
        # label = cv2.cvtColor(label, cv2.COLOR_GRAY2BGR)
        print(ip.shape)
        print(label.shape)
        res = np.concatenate((ip, label), axis=1)
        plt.imshow(res)
        plt.show()
        # print(f'{hgt}: {hgt.shape}')
        # print(f'{wgt}: {wgt.shape}')
        # print(f'{hwgt}: {hwgt.shape}')
    else:
        break
'''
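A minimal live usage sketch of the dataset class (the root path and transform are placeholders; the directory layout follows from the path construction in `__getitem__`, which expects `{root}/train_id.txt`, `{root}/train_images/*.jpg`, and `{root}/train_segmentations/*.png`):

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader

root = '/path/to/CCIHP'   # placeholder; point at the dataset root
transform = T.ToTensor()  # any callable image transform works here

train_set = LIPDataSet(root, 'train', crop_size=[473, 473], transform=transform)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

# Training mode yields image, parsing label, the three distribution
# targets produced by generate_hw_gt, and a metadata dict.
images, labels, hgt, wgt, hwgt, meta = next(iter(train_loader))
print(images.shape, labels.shape, hgt.shape, wgt.shape, hwgt.shape)
```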
dataset/target_generation.py:
import torch
from torch.nn import functional as F


def generate_hw_gt(target, class_num=22):
    # Build the per-class height/width distribution ground truth from a
    # (h, w) parsing mask -- the class distribution targets used by CDGNet.
    h, w = target.shape
    target = torch.from_numpy(target)
    target_c = target.clone()
    target_c[target_c == 255] = 0            # fold the ignore label into background
    target_c = target_c.long().view(h * w, 1)
    target_onehot = torch.zeros(h * w, class_num)
    target_onehot.scatter_(1, target_c, 1)   # (h*w, class_num)
    target_onehot = target_onehot.transpose(0, 1).view(class_num, h, w)
    # Height distribution: per-class pixel count in each row, scaled to [0, 1]
    hgt = torch.sum(target_onehot, dim=2).float()
    hgt[0, :] = 0                            # zero out the background class
    max_h = torch.max(hgt, dim=1)[0].unsqueeze(1)
    hgt = hgt / (max_h + 1e-5)
    # Width distribution: per-class pixel count in each column, scaled to [0, 1]
    wgt = torch.sum(target_onehot, dim=1).float()
    wgt[0, :] = 0
    max_w = torch.max(wgt, dim=1)[0].unsqueeze(1)
    wgt = wgt / (max_w + 1e-5)
    # Joint target: outer product of the two distributions, rescaled to [0, 1]
    hwgt = torch.matmul(hgt.transpose(0, 1), wgt)
    max_hw = torch.max(hwgt.view(-1), dim=0)[0]
    hwgt = hwgt / (max_hw + 1e-5)
    return hgt, wgt, hwgt
def generate_edge(label, edge_width=3):
    # Binary edge map from a parsing label map. Note: this implementation
    # allocates on the GPU, so it requires a CUDA device.
    label = label.type(torch.cuda.FloatTensor)
    if len(label.shape) == 2:
        label = label.unsqueeze(0)
    n, h, w = label.shape
    edge = torch.zeros(label.shape, dtype=torch.float).cuda()
    # Mark pixels whose neighbor in each of four directions carries a
    # different label, ignoring the 255 ignore label on either side.
    edge_right = edge[:, 1:h, :]
    edge_right[(label[:, 1:h, :] != label[:, :h - 1, :])
               & (label[:, 1:h, :] != 255)
               & (label[:, :h - 1, :] != 255)] = 1

    edge_up = edge[:, :, :w - 1]
    edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
            & (label[:, :, :w - 1] != 255)
            & (label[:, :, 1:w] != 255)] = 1

    edge_upright = edge[:, :h - 1, :w - 1]
    edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
                 & (label[:, :h - 1, :w - 1] != 255)
                 & (label[:, 1:h, 1:w] != 255)] = 1

    edge_bottomright = edge[:, :h - 1, 1:w]
    edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
                     & (label[:, :h - 1, 1:w] != 255)
                     & (label[:, 1:h, :w - 1] != 255)] = 1

    # Thicken the 1-pixel edges to roughly edge_width with a box filter.
    # padding=edge_width // 2 preserves the spatial size for odd widths;
    # the original hard-coded padding=1, which is only correct for width 3.
    kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
    with torch.no_grad():
        edge = edge.unsqueeze(1)
        edge = F.conv2d(edge, kernel, stride=1, padding=edge_width // 2)
    edge[edge != 0] = 1
    edge = edge.squeeze()
    return edge
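Two small sanity-check sketches for the functions above (written for this writeup, not part of the commit). The first builds the distribution targets for a toy mask; the second exercises `generate_edge`, which needs a CUDA device since it allocates GPU tensors:

```python
import numpy as np
import torch

# generate_hw_gt on a toy 4x4 mask: class 1 fills the top two rows.
toy = np.zeros((4, 4), dtype=np.int64)
toy[:2, :] = 1
hgt, wgt, hwgt = generate_hw_gt(toy, class_num=2)
print(hgt[1])      # ~[1, 1, 0, 0]: class 1 mass lies in the top rows
print(wgt[1])      # ~[1, 1, 1, 1]: class 1 spans every column
print(hwgt.shape)  # torch.Size([4, 4]): hgt^T @ wgt

# generate_edge on an 8x8 map with a 4x4 square of class 1 (CUDA required).
label = torch.zeros(1, 8, 8)
label[:, 2:6, 2:6] = 1
edge = generate_edge(label, edge_width=3)
print(edge.shape)  # torch.Size([8, 8]) after the trailing squeeze
print(int(edge.sum().item()), 'edge pixels flagged')
```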