From 91d7c1201d99245a562be8908f06e5c02bbf41d9 Mon Sep 17 00:00:00 2001
From: Max You
Date: Sun, 17 Dec 2023 22:11:18 -0500
Subject: [PATCH] fix the implementation

---
 lseg_repri.py                 | 68 ++++++++++++---------------
 modules/models/lseg_net_zs.py | 51 ++------------------
 repri_classifier.py           | 88 +++++++++++++++++++----------------
 3 files changed, 83 insertions(+), 124 deletions(-)

diff --git a/lseg_repri.py b/lseg_repri.py
index 29c3a0d..1f782ae 100644
--- a/lseg_repri.py
+++ b/lseg_repri.py
@@ -202,7 +202,7 @@ def __init__(self):
         parser.add_argument(
             '--nshot',
             type=int,
-            default=2
+            default=0
         )
         parser.add_argument(
             '--fold',
@@ -323,44 +323,37 @@ def episodic_validate(args):
 
             # place it to the corresponding gpu/cpu/the ith gpu in the cluster
             q_label = q_label.to(device)
-            spprt_imgs = spprt_imgs.to(device)
-            s_label = s_label.to(device)
+            if len(spprt_imgs) == 0:
+                spprt_imgs = None # no support images (zero-shot)
+                s_label = None
+            else:
+                spprt_imgs = spprt_imgs.to(device)
+                s_label = s_label.to(device)
             qry_img = qry_img.to(device)
 
             # get the final feature tensor of the support images and the query image
-            f_s, t_s, img_shape = model.extract_features(spprt_imgs.squeeze(0), subcls)
+            f_s, t_s, _ = model.extract_features(spprt_imgs.squeeze(0) if spprt_imgs is not None else None, subcls)
             f_q, _, _ = model.extract_features(qry_img, subcls)
             t_s = t_s[-1]
 
-            # # normalize the support features
-            # f_s = [f_s[i].unsqueeze(0).permute(0,2,3,1).reshape(-1, c) for i in range(len(f_s))]
-            # f_s = [(image_feature / image_feature.norm(dim=-1, keepdim=True)).reshape(c, h, w) for image_feature in f_s]
-            # f_s = torch.stack(f_s, dim=0)
-
-            # # normalize the query features
-            # f_q = f_q.permute(0,2,3,1).reshape(-1, c)
-            # f_q = (f_q / f_q.norm(dim=-1, keepdim=True)).reshape(c, h, w).unsqueeze(0)
-
-            # # normalize the text features
-            # t_s = [(text_feature / text_feature.norm(dim=-1, keepdim=True)) for text_feature in t_s]
-            # t_s = torch.stack(t_s, dim=0)
-
-            shot = f_s.size(0)
-            n_shots[i] = shot
-            features_s[i, :shot] = f_s.detach() # add the feature tensor of the shots to the container for each pair in the batch
+            if spprt_imgs is not None:
+                shot = f_s.size(0)
+                n_shots[i] = shot
+                features_s[i, :shot] = f_s.detach() # add the feature tensor of the shots to the container for each pair in the batch
+                gt_s[i, :shot] = s_label
+
             features_q[i] = f_q.detach() # same for the query but only one shot here
-            text_s[i] = t_s.detach()
-
-            # store the corresponding labels
-            gt_s[i, :shot] = s_label
+            text_s[i] = t_s.detach()
             gt_q[i, 0] = q_label
 
             # add individual class label in a batch to the container, recall item() only work for tensor that contains one element only
             classes.append([class_.item() for class_ in subcls])
 
         # =========== Normalize features along channel dimension ===============
-        # if args.norm_feat:
-        features_s = F.normalize(features_s, dim=2)
+        if args.shot == 0:
+            features_s = None
+        else:
+            features_s = F.normalize(features_s, dim=2)
         features_q = F.normalize(features_q, dim=2)
         text_s = F.normalize(text_s, dim=2)
@@ -486,7 +479,7 @@ def test(args):
         'module': module,
         'image_size': image_size,
         'test_num': len(dataset.img_metadata), # total number of test cases
-        'batch_size_val': 10, # NOTE: this is different than the args.bsz
+        'batch_size_val': 3, # NOTE: this is different from args.bsz
         'n_runs': args.n_run, # repeat the experiment 1 time
         'shot': args.nshot,
         'val_loader': iter(dataloader),
@@ -582,13 +575,14 @@ def hyperparameter_tuning():
 
 
 if __name__ == "__main__":
-    # args = Options().parse()
-    # torch.manual_seed(args.seed)
-    # args.temp = 20
-    # args.adapt_iter = 50
-    # args.fb_updates = [10, 30]
-    # args.fb_type = 'joe'
-    # args.cls_lr = 0.025
-    # test(args)
-
-    hyperparameter_tuning()
\ No newline at end of file
+    args = Options().parse()
+    torch.manual_seed(args.seed)
+    args.temp = 20
+    args.adapt_iter = 50
+    args.fb_updates = [10, 30]
+    args.fb_type = 'joe'
+    args.cls_lr = 0.025
+    args.n_run = 1
+    test(args)
+
+    # hyperparameter_tuning()
\ No newline at end of file
diff --git a/modules/models/lseg_net_zs.py b/modules/models/lseg_net_zs.py
index 73676eb..374c7c6 100644
--- a/modules/models/lseg_net_zs.py
+++ b/modules/models/lseg_net_zs.py
@@ -300,15 +300,13 @@ def __init__(
 
 
     def extract_features(self, x, class_info):
-        # merge the batch and support dimensions to be 4-d tensor
-        # is_support = False
-        # if x.dim() == 5:
-        #     B, S, C, H, W = x.shape
-        #     x = x.view(-1, C, H, W)
-        #     is_support = True
-
         texts = [self.texts[class_i] for class_i in class_info]
 
+        # a list of per-batch text encodings: text_features[i] is the text embeddings for batch i
+        text_features = [self.clip_pretrained.encode_text(text.to(device='cuda' if torch.cuda.is_available() else 'cpu')) for text in texts]
+        if x is None:
+            return None, text_features, None
+
         if self.channels_last == True:
             x.contiguous(memory_format=torch.channels_last)
@@ -329,52 +327,13 @@ def extract_features(self, x, class_info):
 
         self.logit_scale = self.logit_scale.to(x.device)
 
-        # a list of batch-text encodings text_features[i] is the text embeddings for batch i
-        text_features = [self.clip_pretrained.encode_text(text.to(x.device)) for text in texts]
 
         image_features = self.scratch.head1(path_1) # [batch or n_task, c, h, w]
-
-        # unsqueeze back the batch and shot dimensions
-        # _, C, H, W = image_features.shape
-        # if is_support:
-        #     image_features = image_features.reshape((B, S, C, H, W))
 
         imshape = image_features.shape
 
         return image_features, text_features, imshape
 
 
     def forward(self, x, class_info):
-        # texts = [self.texts[class_i] for class_i in class_info]
-
-        # if self.channels_last == True:
-        #     x.contiguous(memory_format=torch.channels_last)
-
-        # layer_1 = self.pretrained.layer1(x)
-        # layer_2 = self.pretrained.layer2(layer_1)
-        # layer_3 = self.pretrained.layer3(layer_2)
-        # layer_4 = self.pretrained.layer4(layer_3)
-
-        # layer_1_rn = self.scratch.layer1_rn(layer_1)
-        # layer_2_rn = self.scratch.layer2_rn(layer_2)
-        # layer_3_rn = self.scratch.layer3_rn(layer_3)
-        # layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-        # path_4 = self.scratch.refinenet4(layer_4_rn)
-        # path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
-        # path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
-        # path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
-        # self.logit_scale = self.logit_scale.to(x.device)
-        # text_features = [self.clip_pretrained.encode_text(text.to(x.device)) for text in texts]
-
-        # image_features = self.scratch.head1(path_1)
-
-        # imshape = image_features.shape
-        # image_features = [image_features[i].unsqueeze(0).permute(0,2,3,1).reshape(-1, self.out_c) for i in range(len(image_features))]
-
-        # # normalized features
-        # image_features = [image_feature / image_feature.norm(dim=-1, keepdim=True) for image_feature in image_features]
-        # text_features = [text_feature / text_feature.norm(dim=-1, keepdim=True) for text_feature in text_features]
-
         image_features, text_features, imshape = self.extract_features(x, class_info)
 
         # seperate the batch into a list of per-image features
diff --git a/repri_classifier.py b/repri_classifier.py
index 77d30f4..2518666 100644
--- a/repri_classifier.py
+++ b/repri_classifier.py
@@ -106,48 +106,36 @@ def __init__(self, args):
     def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, text_s: torch.tensor,
                         gt_s: torch.tensor, gt_q: torch.tensor, subcls: List[int],
-                        callback) -> None:
+                        callback, alpha=0.5) -> None:
         """
         inputs:
             features_s : shape [n_task, shot, c, h, w]
             features_q : shape [n_task, 1, c, h, w]
-            text_s: shape [n_task, texts = 1, c]
+            text_s : shape [n_task, n_label_text, c]
             gt_s : shape [n_task, shot, H, W]
             gt_q : shape [n_task, 1, H, W]
-
+            alpha : interpolation factor between the text embedding and the support prototype
         returns :
             prototypes : shape [n_task, c]
            bias : shape [n_task]
         """
         # DownSample support masks
-        n_task, shot, c, h, w = features_s.size()
-        ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
-        ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_task, shot, 1, h, w]
-
-        # Computing prototypes
-        fg_mask = (ds_gt_s == 1) # [n_tasks, shot, 1, 240, 240]
-
-        # ============= NOTE: Proposed Solution with text embedding only Begin =============
-        fg_prototype = text_s[:, -1, :]
-        # ============= NOTE: Proposed Solution with text embedding only End ===============
-
-        # ============= NOTE: Proposed Solution with text embedding + few-shot support images Begin =============
-        # fg_prototype = text_s[:, -1, :]
-        # ============= NOTE: Proposed Solution with text embedding + few-shot support images End ===============
-
-
-        # ============= NOTE: Alternative Solution 1 Start =============
-
-        # # per-shot masked average pooling or per-shot prototype
-        # fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
-        # fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10) # [n_tasks, shot, c]
-
-        # # average the prototype across all shots including the label embeddings
-        # fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
-        # fg_prototype = torch.mean(fg_prototype, dim=1)
-
-        # ============= NOTE: Alternative Solution 1 End =============
+        n_task, _, c, _, _ = features_q.size()
+        if features_s is None: # zero-shot: use the text embedding only
+            fg_prototype = text_s[:, -1, :]
+        else: # text embedding + few-shot support images
+            ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
+            ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_task, shot, 1, h, w]
+
+            # Compute the foreground mask of each support image for each task
+            fg_mask = (ds_gt_s == 1) # [n_tasks, shot, 1, h, w]
+            # masked average pooling over all support images gives the visual prototype
+            fg_prototype = (features_s * fg_mask).sum(dim=(1, 3, 4))
+            fg_prototype /= (fg_mask.sum(dim=(1, 3, 4)) + 1e-10) # [n_tasks, c]
+
+            # interpolate between the class embedding and the masked-average support features
+            fg_prototype = alpha * text_s[:, -1, :] + (1 - alpha) * fg_prototype
 
         # store the initial weight / support prototype before finding the logits
         self.prototype = fg_prototype
@@ -162,6 +150,11 @@ def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, te
         if callback is not None:
             self.update_callback(callback, 0, features_s, features_q, subcls, gt_s, gt_q)
 
+        # LEGACY CODE
+        # # average the prototype across all shots including the label embeddings
+        # fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
+        # fg_prototype = torch.mean(fg_prototype, dim=1) # average over the shots
+
 
     def get_logits(self, features: torch.tensor) -> torch.tensor:
         """
@@ -308,7 +301,7 @@ def RePRI(self,
         Performs RePRI inference
 
         inputs:
-            features_s : shape [n_tasks, shot, c, h, w]
+            features_s : shape [n_tasks, shot, c, h, w] or None
             features_q : shape [n_tasks, shot, c, h, w]
             gt_s : shape [n_tasks, shot, h, w]
             gt_q : shape [n_tasks, shot, h, w]
@@ -339,36 +332,49 @@ def RePRI(self,
         optimizer = torch.optim.SGD([self.prototype, self.bias], lr=self.lr)
 
         # downsample the groundth truth query and support
-        ds_gt_q = F.interpolate(gt_q.float(), size=features_s.size()[-2:], mode='nearest').long()
-        ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long()
-
+        ds_gt_q = F.interpolate(gt_q.float(), size=features_q.size()[-2:], mode='nearest').long()
         valid_pixels_q = (ds_gt_q != 255).float() # [n_tasks, shot, h, w]
-        valid_pixels_s = (ds_gt_s != 255).float() # [n_tasks, shot, h, w]
-        one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes) # [n_tasks, shot, num_classes, h, w]
+        ds_gt_s = None
+        valid_pixels_s = None
+        one_hot_gt_s = None
+        if features_s is not None:
+            ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long()
+            valid_pixels_s = (ds_gt_s != 255).float() # [n_tasks, shot, h, w]
+            one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes) # [n_tasks, shot, num_classes, h, w]
 
         # self.adapt_iter is a hyperparameter to optimise
         for iteration in range(1, self.adapt_iter):
+
+            proba_s = None
+            if features_s is not None:
+                logits_s = self.get_logits(features_s) # [n_tasks, shot, num_class, h, w]
+                proba_s = self.get_probas(logits_s)
 
-            logits_s = self.get_logits(features_s) # [n_tasks, shot, num_class, h, w]
             logits_q = self.get_logits(features_q) # [n_tasks, 1, num_class, h, w]
             proba_q = self.get_probas(logits_q)
-            proba_s = self.get_probas(logits_s)
 
             d_kl, cond_entropy, marginal = self.get_entropies(valid_pixels_q, proba_q, reduction='none')
-            ce = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s, reduction='none')
-            loss = l1 * ce + l2 * d_kl + l3 * cond_entropy
+
+            if proba_s is not None and one_hot_gt_s is not None and valid_pixels_s is not None:
+                ce = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s, reduction='none')
+                loss = l1 * ce + l2 * d_kl + l3 * cond_entropy
+            else:
+                loss = d_kl + cond_entropy
 
             optimizer.zero_grad()
             loss.sum(0).backward()
             optimizer.step()
 
             # Update FB_param
-            if (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type) and (l2.sum().item() != 0):
+            # for few-shot
+            if (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type) and features_s is not None and (l2.sum().item() != 0):
                 deltas = self.compute_FB_param(features_q, gt_q).cpu()
                 l2 += 1
+            elif (iteration + 1) in self.FB_param_update and ('oracle' not in self.FB_param_type): # for zero-shot
+                deltas = self.compute_FB_param(features_q, gt_q).cpu()
 
             if callback is not None and (iteration + 1) % self.visdom_freq == 0:
                 self.update_callback(callback, iteration, features_s, features_q, subcls, gt_s, gt_q)
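
For reference, a minimal standalone sketch of the prototype initialisation this patch introduces: the zero-shot path falls back to the class text embedding alone, while the few-shot path blends it with a masked-average-pooled support prototype. The function name and the random tensors below are illustrative stand-ins, not part of the patched code.

import torch
import torch.nn.functional as F

def init_prototype(features_s, text_s, gt_s, alpha=0.5):
    # features_s: [n_task, shot, c, h, w] or None; text_s: [n_task, n_label_text, c]
    # gt_s: [n_task, shot, H, W], foreground pixels labelled 1
    if features_s is None:
        return text_s[:, -1, :]                            # zero-shot: text embedding only
    ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
    fg_mask = (ds_gt_s.long().unsqueeze(2) == 1)           # [n_task, shot, 1, h, w]
    proto = (features_s * fg_mask).sum(dim=(1, 3, 4))      # pool over shots and space
    proto = proto / (fg_mask.sum(dim=(1, 3, 4)) + 1e-10)   # [n_task, c]
    return alpha * text_s[:, -1, :] + (1 - alpha) * proto  # blend text and visual cues

n_task, shot, c, h, w = 2, 3, 512, 30, 30
proto = init_prototype(torch.randn(n_task, shot, c, h, w),
                       torch.randn(n_task, 2, c),
                       torch.randint(0, 2, (n_task, shot, 8 * h, 8 * w)))
assert proto.shape == (n_task, c)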