Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit

Permalink
solution
Browse files Browse the repository at this point in the history
  • Loading branch information
Max You authored and Max You committed Dec 17, 2023
1 parent cf9793c commit 42b3b69
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions repri_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,29 @@ def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor, te
ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_task, shot, 1, h, w]

# Computing prototypes NOTE: this is the weights of the classifier
# Computing prototypes
fg_mask = (ds_gt_s == 1) # [n_tasks, shot, 1, 240, 240]

# per-shot masked average pooling or per-shot prototype
fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10) # [n_tasks, shot, c]
# ============= NOTE: Proposed Solution with text embedding only Begin =============
fg_prototype = text_s[:, -1, :]
# ============= NOTE: Proposed Solution with text embedding only End ===============

# average the prototype across all shots including the label embeddings
fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
fg_prototype = torch.mean(fg_prototype, dim=1)
# ============= NOTE: Proposed Solution with text embedding + few-shot support images Begin =============
# fg_prototype = text_s[:, -1, :]
# ============= NOTE: Proposed Solution with text embedding + few-shot support images End ===============

# PREVIOUS CODE
# fg_prototype = (features_s * fg_mask).sum(dim=(1, 3, 4))
# fg_prototype /= (fg_mask.sum(dim=(1, 3, 4)) + 1e-10) # [n_task, c]
# fg_prototype = (fg_prototype + text_s[:, -1, :]) / 2

# ============= NOTE: Alternative Solution 1 Start =============

# # per-shot masked average pooling or per-shot prototype
# fg_prototype = (features_s * fg_mask).sum(dim=(3, 4))
# fg_prototype /= (fg_mask.sum(dim=(3, 4)) + 1e-10) # [n_tasks, shot, c]

# # average the prototype across all shots including the label embeddings
# fg_prototype = torch.concat([fg_prototype, text_s[:, -1, :].unsqueeze(1)], dim=1) # [n_tasks, shot + 1, c]
# fg_prototype = torch.mean(fg_prototype, dim=1)

# ============= NOTE: Alternative Solution 1 End =============

# store the initial weight / support prototype before finding the logits
self.prototype = fg_prototype
Expand Down

0 comments on commit 42b3b69

Please sign in to comment.