Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit

Permalink
fix the implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Max You authored and Max You committed Dec 18, 2023
1 parent 42b3b69 commit 91d7c12
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 124 deletions.
68 changes: 31 additions & 37 deletions lseg_repri.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(self):
parser.add_argument(
'--nshot',
type=int,
default=2
default=0
)
parser.add_argument(
'--fold',
Expand Down Expand Up @@ -323,44 +323,37 @@ def episodic_validate(args):

# place it to the corresponding gpu/cpu/the ith gpu in the cluster
q_label = q_label.to(device)
spprt_imgs = spprt_imgs.to(device)
s_label = s_label.to(device)
if spprt_imgs == []:
spprt_imgs = None # no support images
s_label = None
else:
spprt_imgs = spprt_imgs.to(device)
s_label = s_label.to(device)
qry_img = qry_img.to(device)

# get the final feature tensor of the support images and the query image
f_s, t_s, img_shape = model.extract_features(spprt_imgs.squeeze(0), subcls)
f_s, t_s, _ = model.extract_features(spprt_imgs.squeeze(0) if spprt_imgs else None, subcls)
f_q, _, _ = model.extract_features(qry_img, subcls)
t_s = t_s[-1]

# # normalize the support features
# f_s = [f_s[i].unsqueeze(0).permute(0,2,3,1).reshape(-1, c) for i in range(len(f_s))]
# f_s = [(image_feature / image_feature.norm(dim=-1, keepdim=True)).reshape(c, h, w) for image_feature in f_s]
# f_s = torch.stack(f_s, dim=0)

# # normalize the query features
# f_q = f_q.permute(0,2,3,1).reshape(-1, c)
# f_q = (f_q / f_q.norm(dim=-1, keepdim=True)).reshape(c, h, w).unsqueeze(0)

# # normalize the text features
# t_s = [(text_feature / text_feature.norm(dim=-1, keepdim=True)) for text_feature in t_s]
# t_s = torch.stack(t_s, dim=0)

shot = f_s.size(0)
n_shots[i] = shot
features_s[i, :shot] = f_s.detach() # add the feature tensor of the shots to the container for each pair in the batch
if spprt_imgs:
shot = f_s.size(0)
n_shots[i] = shot
features_s[i, :shot] = f_s.detach() # add the feature tensor of the shots to the container for each pair in the batch
gt_s[i, :shot] = s_label

features_q[i] = f_q.detach() # same for the query but only one shot here
text_s[i] = t_s.detach()

# store the corresponding labels
gt_s[i, :shot] = s_label
text_s[i] = t_s.detach()
gt_q[i, 0] = q_label

# add individual class label in a batch to the container, recall item() only work for tensor that contains one element only
classes.append([class_.item() for class_ in subcls])

# =========== Normalize features along channel dimension ===============
# if args.norm_feat:
features_s = F.normalize(features_s, dim=2)
if args.shot == 0:
features_s = None
else:
features_s = F.normalize(features_s, dim=2)
features_q = F.normalize(features_q, dim=2)
text_s = F.normalize(text_s, dim=2)

Expand Down Expand Up @@ -486,7 +479,7 @@ def test(args):
'module': module,
'image_size': image_size,
'test_num': len(dataset.img_metadata), # total number of test cases
'batch_size_val': 10, # NOTE: this is different than the args.bsz
'batch_size_val': 3, # NOTE: this is different than the args.bsz
'n_runs': args.n_run, # repeat the experiment 1 time
'shot': args.nshot,
'val_loader': iter(dataloader),
Expand Down Expand Up @@ -582,13 +575,14 @@ def hyperparameter_tuning():


if __name__ == "__main__":
# args = Options().parse()
# torch.manual_seed(args.seed)
# args.temp = 20
# args.adapt_iter = 50
# args.fb_updates = [10, 30]
# args.fb_type = 'joe'
# args.cls_lr = 0.025
# test(args)

hyperparameter_tuning()
args = Options().parse()
torch.manual_seed(args.seed)
args.temp = 20
args.adapt_iter = 50
args.fb_updates = [10, 30]
args.fb_type = 'joe'
args.cls_lr = 0.025
args.n_run = 1
test(args)

# hyperparameter_tuning()
51 changes: 5 additions & 46 deletions modules/models/lseg_net_zs.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,13 @@ def __init__(

def extract_features(self, x, class_info):

# merge the batch and support dimensions to be 4-d tensor
# is_support = False
# if x.dim() == 5:
# B, S, C, H, W = x.shape
# x = x.view(-1, C, H, W)
# is_support = True

texts = [self.texts[class_i] for class_i in class_info]
# a list of batch-text encodings text_features[i] is the text embeddings for batch i
text_features = [self.clip_pretrained.encode_text(text.to(device = 'cuda' if torch.cuda.is_available() else 'cpu')) for text in texts]

if x is None:
return None, text_features, None

if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)

Expand All @@ -329,52 +327,13 @@ def extract_features(self, x, class_info):

self.logit_scale = self.logit_scale.to(x.device)

# a list of batch-text encodings text_features[i] is the text embeddings for batch i
text_features = [self.clip_pretrained.encode_text(text.to(x.device)) for text in texts]
image_features = self.scratch.head1(path_1) # [batch or n_task, c, h, w]

# unsqueeze back the batch and shot dimensions
# _, C, H, W = image_features.shape
# if is_support:
# image_features = image_features.reshape((B, S, C, H, W))

imshape = image_features.shape

return image_features, text_features, imshape

def forward(self, x, class_info):
# texts = [self.texts[class_i] for class_i in class_info]

# if self.channels_last == True:
# x.contiguous(memory_format=torch.channels_last)

# layer_1 = self.pretrained.layer1(x)
# layer_2 = self.pretrained.layer2(layer_1)
# layer_3 = self.pretrained.layer3(layer_2)
# layer_4 = self.pretrained.layer4(layer_3)

# layer_1_rn = self.scratch.layer1_rn(layer_1)
# layer_2_rn = self.scratch.layer2_rn(layer_2)
# layer_3_rn = self.scratch.layer3_rn(layer_3)
# layer_4_rn = self.scratch.layer4_rn(layer_4)

# path_4 = self.scratch.refinenet4(layer_4_rn)
# path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
# path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
# path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

# self.logit_scale = self.logit_scale.to(x.device)
# text_features = [self.clip_pretrained.encode_text(text.to(x.device)) for text in texts]

# image_features = self.scratch.head1(path_1)

# imshape = image_features.shape
# image_features = [image_features[i].unsqueeze(0).permute(0,2,3,1).reshape(-1, self.out_c) for i in range(len(image_features))]

# # normalized features
# image_features = [image_feature / image_feature.norm(dim=-1, keepdim=True) for image_feature in image_features]
# text_features = [text_feature / text_feature.norm(dim=-1, keepdim=True) for text_feature in text_features]

image_features, text_features, imshape = self.extract_features(x, class_info)

# seperate the batch into a list of per-image features
Expand Down
88 changes: 47 additions & 41 deletions repri_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,48 +106,36 @@ def __init__(self, args):

def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, text_s: torch.tensor,
gt_s: torch.tensor, gt_q: torch.tensor, subcls: List[int],
callback) -> None:
callback, alpha=.5) -> None:
"""
inputs:
features_s : shape [n_task, shot, c, h, w]
features_q : shape [n_task, 1, c, h, w]
text_s: shape [n_task, texts = 1, c]
text_s: shape [n_task, n_label_text, c]
gt_s : shape [n_task, shot, H, W]
gt_q : shape [n_task, 1, H, W]
alpha: interpolation factor
returns :
prototypes : shape [n_task, c]
bias : shape [n_task]
"""

# DownSample support masks
n_task, shot, c, h, w = features_s.size()
ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_task, shot, 1, h, w]

# Computing prototypes
fg_mask = (ds_gt_s == 1) # [n_tasks, shot, 1, 240, 240]

# ============= NOTE: Proposed Solution with text embedding only Begin =============
fg_prototype = text_s[:, -1, :]
# ============= NOTE: Proposed Solution with text embedding only End ===============

# ============= NOTE: Proposed Solution with text embedding + few-shot support images Begin =============
# fg_prototype = text_s[:, -1, :]
# ============= NOTE: Proposed Solution with text embedding + few-shot support images End ===============


# ============= NOTE: Alternative Solution 1 Start =============

# # per-shot masked average pooling or per-shot prototype
# fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
# fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10) # [n_tasks, shot, c]

# # average the prototype across all shots including the label embeddings
# fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
# fg_prototype = torch.mean(fg_prototype, dim=1)

# ============= NOTE: Alternative Solution 1 End =============
n_task, _, c, _, _ = features_q.size()
if features_s is None: # zero-shot with text embedding only
fg_prototype = text_s[:, -1, :]
else: # text embedding + few-shot segmentation
ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_task, shot, 1, h, w]

# Compute the mask region of each support for each task
fg_mask = (ds_gt_s == 1) # [n_tasks, shot, 1, h, w]
# find the prototype features from all the support images
fg_prototype = (features_s * fg_mask).sum(dim=(1, 3, 4))
fg_prototype /= (fg_mask.sum(dim=(1, 3, 4)) + 1e-10) # [n_tasks, shot, c]

# interpolation between the class embedding and the masked average support features
fg_prototype = alpha * text_s[:, -1, :] + (1-alpha) * fg_prototype

# store the initial weight / support prototype before finding the logits
self.prototype = fg_prototype
Expand All @@ -162,6 +150,11 @@ def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, te
if callback is not None:
self.update_callback(callback, 0, features_s, features_q, subcls, gt_s, gt_q)

# LEGACY CODE
# # average the prototype across all shots including the label embeddings
# fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
# fg_prototype = torch.mean(fg_prototype, dim=1) # average over the shots

def get_logits(self, features: torch.tensor) -> torch.tensor:

"""
Expand Down Expand Up @@ -308,7 +301,7 @@ def RePRI(self,
Performs RePRI inference
inputs:
features_s : shape [n_tasks, shot, c, h, w]
features_s : shape [n_tasks, shot, c, h, w] or None
features_q : shape [n_tasks, shot, c, h, w]
gt_s : shape [n_tasks, shot, h, w]
gt_q : shape [n_tasks, shot, h, w]
Expand Down Expand Up @@ -339,36 +332,49 @@ def RePRI(self,
optimizer = torch.optim.SGD([self.prototype, self.bias], lr=self.lr)

# downsample the groundth truth query and support
ds_gt_q = F.interpolate(gt_q.float(), size=features_s.size()[-2:], mode='nearest').long()
ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long()

ds_gt_q = F.interpolate(gt_q.float(), size=features_q.size()[-2:], mode='nearest').long()
valid_pixels_q = (ds_gt_q != 255).float() # [n_tasks, shot, h, w]
valid_pixels_s = (ds_gt_s != 255).float() # [n_tasks, shot, h, w]

one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes) # [n_tasks, shot, num_classes, h, w]
ds_gt_s = None
valid_pixels_s = None
one_hot_gt_s = None
if features_s is not None:
ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long()
valid_pixels_s = (ds_gt_s != 255).float() # [n_tasks, shot, h, w]
one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes) # [n_tasks, shot, num_classes, h, w]

# self.adapt_iter is a hyperparameter to optimise
for iteration in range(1, self.adapt_iter):

proba_s = None
if features_s is not None:
logits_s = self.get_logits(features_s) # [n_tasks, shot, num_class, h, w]
proba_s = self.get_probas(logits_s)

logits_s = self.get_logits(features_s) # [n_tasks, shot, num_class, h, w]
logits_q = self.get_logits(features_q) # [n_tasks, 1, num_class, h, w]
proba_q = self.get_probas(logits_q)
proba_s = self.get_probas(logits_s)

d_kl, cond_entropy, marginal = self.get_entropies(valid_pixels_q,
proba_q,
reduction='none')
ce = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s, reduction='none')
loss = l1 * ce + l2 * d_kl + l3 * cond_entropy

if proba_s and one_hot_gt_s and valid_pixels_s:
ce = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s, reduction='none')
loss = l1 * ce + l2 * d_kl + l3 * cond_entropy
else:
loss = d_kl + cond_entropy

optimizer.zero_grad()
loss.sum(0).backward()
optimizer.step()

# Update FB_param
if (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type) and (l2.sum().item() != 0):
# for few-shot
if (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type) and features_s is not None and (l2.sum().item() != 0):
deltas = self.compute_FB_param(features_q, gt_q).cpu()
l2 += 1
elif (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type): # for 0 shot
deltas = self.compute_FB_param(features_q, gt_q).cpu()

if callback is not None and (iteration + 1) % self.visdom_freq == 0:
self.update_callback(callback, iteration, features_s, features_q, subcls, gt_s, gt_q)
Expand Down

0 comments on commit 91d7c12

Please sign in to comment.