This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Repri #47

Open · wants to merge 31 commits into base: main
16 changes: 10 additions & 6 deletions fewshot_data/common/evaluation.py
@@ -12,10 +12,10 @@ def initialize(cls):
def classify_prediction(cls, pred_mask, gt_mask, query_ignore_idx=None):
# gt_mask = batch.get('query_mask')

# # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
# query_ignore_idx = batch.get('query_ignore_idx')
# Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
# query_ignore_idx = batch.get('query_ignore_idx')
if query_ignore_idx is not None:
assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
# assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
query_ignore_idx *= cls.ignore_index
gt_mask = gt_mask + query_ignore_idx
pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index
@@ -27,10 +27,14 @@ def classify_prediction(cls, pred_mask, gt_mask, query_ignore_idx=None):
if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
_area_inter = torch.tensor([0, 0], device=_pred_mask.device)
else:
_area_inter = torch.histc(_inter, bins=2, min=0, max=1)
# _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
_area_inter = torch.histc(_inter.to(torch.float32), bins=2, min=0, max=1)

area_inter.append(_area_inter)
area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
# area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
# area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
area_pred.append(torch.histc(_pred_mask.to(torch.float32), bins=2, min=0, max=1))
area_gt.append(torch.histc(_gt_mask.to(torch.float32), bins=2, min=0, max=1))
area_inter = torch.stack(area_inter).t()
area_pred = torch.stack(area_pred).t()
area_gt = torch.stack(area_gt).t()
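The change above casts each mask to float32 before calling torch.histc, since torch.histc only accepts floating-point input and raises a RuntimeError on integer or boolean tensors in the PyTorch versions this code targets. A minimal sketch of the per-sample area computation with that cast, assuming binary 0/1 masks and taking the intersection as the prediction restricted to pixels where it agrees with the ground truth (the surrounding code that builds _inter is not shown in full here):

    import torch

    def binary_area_hist(mask):
        # bins=2 over [0, 1] counts background (0) and foreground (1) pixels;
        # the float32 cast is exactly what the change above adds for torch.histc.
        return torch.histc(mask.to(torch.float32), bins=2, min=0, max=1)

    pred_mask = torch.tensor([0, 1, 1, 0])
    gt_mask = torch.tensor([0, 1, 0, 0])
    inter = pred_mask[pred_mask == gt_mask]          # prediction values where it agrees with GT
    area_inter = binary_area_hist(inter)             # tensor([2., 1.])
    area_union = binary_area_hist(pred_mask) + binary_area_hist(gt_mask) - area_inter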
13 changes: 10 additions & 3 deletions fewshot_data/common/logger.py
@@ -12,7 +12,6 @@ class AverageMeter:
def __init__(self, dataset):
self.benchmark = dataset.benchmark
self.class_ids_interest = dataset.class_ids
self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()

if self.benchmark == 'pascal':
self.nclass = 20
@@ -21,8 +20,16 @@ def __init__(self, dataset):
elif self.benchmark == 'fss':
self.nclass = 1000

self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
else:
self.class_ids_interest = torch.tensor(self.class_ids_interest).cpu()
self.intersection_buf = torch.zeros([2, self.nclass]).float().cpu()
self.union_buf = torch.zeros([2, self.nclass]).float().cpu()

self.ones = torch.ones_like(self.union_buf)
self.loss_buf = []
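
The block above falls back to CPU buffers when CUDA is unavailable. A sketch of an equivalent but more compact initialization, assuming nothing beyond what the diff shows: the device is chosen once and every buffer is created directly on it, which avoids duplicating the cuda()/cpu() branches (calling .cpu() on a fresh CPU tensor is a no-op anyway).

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # hypothetical stand-ins for the values the real class derives from `dataset`
    class_ids = [0, 1, 2, 3, 4]
    nclass = 20

    class_ids_interest = torch.tensor(class_ids, device=device)
    intersection_buf = torch.zeros(2, nclass, dtype=torch.float32, device=device)
    union_buf = torch.zeros(2, nclass, dtype=torch.float32, device=device)
    ones = torch.ones_like(union_buf)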

8 changes: 6 additions & 2 deletions fewshot_data/common/utils.py
@@ -28,5 +28,9 @@ def to_cuda(batch):
return batch


def to_cpu(tensor):
return tensor.detach().clone().cpu()
def to_cpu(batch):
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.detach().clone().cpu()
# return tensor.detach().clone().cpu()
return batch
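With this change to_cpu mirrors to_cuda: it walks the batch dict, detaches, clones, and moves every tensor entry to the CPU, and passes non-tensor entries (such as image-name strings) through unchanged. A usage sketch, assuming the module path shown in the file header and a toy batch (already on the CPU here, so the move is a no-op; on a GPU batch it copies back):

    import torch
    from fewshot_data.common.utils import to_cpu

    batch = {'query_img': torch.randn(1, 3, 32, 32),  # tensor entry: detached, cloned, moved to CPU
             'query_name': '2007_000033'}             # non-tensor entry: passed through unchanged
    batch = to_cpu(batch)
    assert batch['query_img'].device.type == 'cpu'
    assert batch['query_name'] == '2007_000033'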
7 changes: 3 additions & 4 deletions fewshot_data/data/dataset.py
@@ -19,12 +19,11 @@ def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=Fals
if imagenet_norm:
cls.img_mean = [0.485, 0.456, 0.406]
cls.img_std = [0.229, 0.224, 0.225]
print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))
else:
cls.img_mean = [0.5] * 3
cls.img_std = [0.5] * 3
print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))

print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))
cls.datapath = datapath
cls.use_original_imgsize = use_original_imgsize
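Moving the print below the if/else simply logs the chosen statistics in both branches. For context, a hedged sketch of how such mean/std pairs are typically consumed when building the image transform; the actual transform construction is not part of this hunk, and the 384x384 size is only an illustrative placeholder for img_size:

    from torchvision import transforms

    img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # imagenet_norm=True branch
    transform = transforms.Compose([
        transforms.Resize((384, 384)),               # placeholder for img_size
        transforms.ToTensor(),
        transforms.Normalize(img_mean, img_std),
    ])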

@@ -39,4 +38,4 @@ def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize)
dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)

return dataloader
return dataloader, dataset
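Returning the dataset alongside the dataloader means every caller of build_dataloader must now unpack two values; the extra handle is presumably what lets evaluation code reach dataset.class_ids and dataset.benchmark, as the AverageMeter above does. A hypothetical call site under that assumption; the enclosing class is assumed to be named FSSDataset, which is not visible in this hunk:

    from fewshot_data.common.logger import AverageMeter
    from fewshot_data.data.dataset import FSSDataset  # assumed class name; not shown in this hunk

    dataloader_val, dataset_val = FSSDataset.build_dataloader(
        benchmark='pascal', bsz=8, nworker=4, fold=0, split='val', shot=1)
    meter = AverageMeter(dataset_val)  # AverageMeter reads dataset.benchmark and dataset.class_ids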
46 changes: 33 additions & 13 deletions fewshot_data/data/pascal.py
@@ -27,19 +27,20 @@ def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize)
self.img_metadata_classwise = self.build_img_metadata_classwise()

def __len__(self):
return len(self.img_metadata) if self.split == 'trn' else 1000
return len(self.img_metadata) if self.split == 'trn' else min(1000, len(self.img_metadata))

def __getitem__(self, idx):
idx %= len(self.img_metadata) # for testing, as n_images < 1000
query_name, support_names, class_sample = self.sample_episode(idx)
query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)
query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names) # cmask stands for class mask

query_img = self.transform(query_img)
if not self.use_original_imgsize:
query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)

if self.shot:
# stack all the support images into one tensor
support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

support_masks = []
@@ -49,28 +50,33 @@ def __getitem__(self, idx):
support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
support_masks.append(support_mask)
support_ignore_idxs.append(support_ignore_idx)

# stack all the support masks into one tensor
support_masks = torch.stack(support_masks)
support_ignore_idxs = torch.stack(support_ignore_idxs)
else:
support_masks = []
support_ignore_idxs = []

# use this batch information for testing
batch = {'query_img': query_img,
'query_mask': query_mask,
'query_name': query_name,
'query_ignore_idx': query_ignore_idx,
'query_mask': query_mask,
'query_name': query_name,
'query_ignore_idx': query_ignore_idx,

'org_query_imsize': org_qry_imsize,
'org_query_imsize': org_qry_imsize,

'support_imgs': support_imgs,
'support_masks': support_masks,
'support_names': support_names,
'support_ignore_idxs': support_ignore_idxs,
'support_imgs': support_imgs,
'support_masks': support_masks,
'support_names': support_names,
'support_ignore_idxs': support_ignore_idxs,

'class_id': torch.tensor(class_sample)}
'class_id': torch.tensor(class_sample)}

return batch

def extract_ignore_idx(self, mask, class_id):
# only get the class of interest here
boundary = (mask / 255).floor()
mask[mask != class_id + 1] = 0
mask[mask == class_id + 1] = 1
@@ -97,22 +103,32 @@ def read_img(self, img_name):
return Image.open(os.path.join(self.img_path, img_name) + '.jpg')

def sample_episode(self, idx):
# return a triple: the query image name (1 image), the support image names (self.shot of them), and the class label

# recall img_metadata is a list of tuples
query_name, class_sample = self.img_metadata[idx]

support_names = []
if self.shot:
while True: # keep sampling support set if query == support

# sample an image that contains the query class, without replacement
support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
if query_name != support_name: support_names.append(support_name)
if len(support_names) == self.shot: break

return query_name, support_names, class_sample

def build_class_ids(self):
# each fold covers a different subset of the classes (nclass // nfolds classes per fold)
nclass_trn = self.nclass // self.nfolds

# fold i always holds out the same class_ids_val; the trn split uses the complement
# note that class ids start from 0 => that is why 1 is subtracted in build_img_metadata
class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)]
class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]


# return the fold of interest only.
if self.split == 'trn':
return class_ids_trn
else:
@@ -124,13 +140,15 @@ def read_metadata(split, fold_id):
fold_n_metadata = os.path.join('fewshot_data/data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
with open(fold_n_metadata, 'r') as f:
fold_n_metadata = f.read().split('\n')[:-1]

# data.split('__')[0] is the image name and int(data.split('__')[1]) is the class id, shifted down by 1 here
fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
return fold_n_metadata

img_metadata = []
if self.split == 'trn': # For training, read image-metadata of "the other" folds
for fold_id in range(self.nfolds):
if fold_id == self.fold: # Skip validation fold
if fold_id == self.fold: # Skip the validation fold; keep the remaining three folds
continue
img_metadata += read_metadata(self.split, fold_id)
elif self.split == 'val': # For validation, read image-metadata of "current" fold
@@ -140,9 +158,11 @@ def read_metadata(split, fold_id):

print('Total (%s) images are : %d' % (self.split, len(img_metadata)))

# this returns a list of (image name, class id) tuples
return img_metadata

def build_img_metadata_classwise(self):
# collect all images of each class under its class id key
img_metadata_classwise = {}
for class_id in range(self.nclass):
img_metadata_classwise[class_id] = []
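The comments added to sample_episode describe its contract: given an index, return the query image name, self.shot support names drawn from the same class, and that class id. A standalone sketch of that sampling loop with hypothetical metadata; the real class builds img_metadata and img_metadata_classwise from the fold split files shown above:

    import numpy as np

    img_metadata = [('2007_000033', 0), ('2007_000042', 0)]               # (image name, class id)
    img_metadata_classwise = {0: ['2007_000033', '2007_000042', '2008_000123']}
    shot = 1

    query_name, class_sample = img_metadata[0]
    support_names = []
    while True:  # resample whenever the draw collides with the query image
        support_name = np.random.choice(img_metadata_classwise[class_sample], 1, replace=False)[0]
        if query_name != support_name:
            support_names.append(support_name)
        if len(support_names) == shot:
            break
    print(query_name, support_names, class_sample)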