diff --git a/repri_classifier.py b/repri_classifier.py
index 09d5cf6..77d30f4 100644
--- a/repri_classifier.py
+++ b/repri_classifier.py
@@ -125,21 +125,29 @@ def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, te
         ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
         ds_gt_s = ds_gt_s.long().unsqueeze(2)  # [n_task, shot, 1, h, w]
 
-        # Computing prototypes NOTE: this is the weights of the classifier
+        # Computing prototypes
         fg_mask = (ds_gt_s == 1)  # [n_tasks, shot, 1, 240, 240]
 
-        # per-shot masked average pooling or per-shot prototype
-        fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
-        fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10)  # [n_tasks, shot, c]
+        # ============= NOTE: Proposed Solution with text embedding only Begin =============
+        fg_prototype = text_s[:, -1, :]
+        # ============= NOTE: Proposed Solution with text embedding only End ===============
 
-        # average the prototype across all shots including the label embeddings
-        fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1)  # [n_tasks, shot + 1, c]
-        fg_prototype = torch.mean(fg_prototype, dim=1)
+        # ============= NOTE: Proposed Solution with text embedding + few-shot support images Begin =============
+        # fg_prototype = text_s[:, -1, :]
+        # ============= NOTE: Proposed Solution with text embedding + few-shot support images End ===============
 
-        # PREVIOUS CODE
-        # fg_prototype = (features_s * fg_mask).sum(dim=(1, 3, 4))
-        # fg_prototype /= (fg_mask.sum(dim=(1, 3, 4)) + 1e-10)  # [n_task, c]
-        # fg_prototype = (fg_prototype + text_s[:, -1, :]) / 2
+
+        # ============= NOTE: Alternative Solution 1 Start =============
+
+        # # per-shot masked average pooling or per-shot prototype
+        # fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
+        # fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10)  # [n_tasks, shot, c]
+
+        # # average the prototype across all shots including the label embeddings
+        # fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1)  # [n_tasks, shot + 1, c]
+        # fg_prototype = torch.mean(fg_prototype, dim=1)
+
+        # ============= NOTE: Alternative Solution 1 End =============
 
         # store the initial weight / support prototype before finding the logits
         self.prototype = fg_prototype