moco_wrapper.py
import torch
import torch.nn as nn

from resnet_moco import ModelBase


class ModelMoCo(nn.Module):
def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
super(ModelMoCo, self).__init__()
self.K = K
self.m = m
self.T = T
self.symmetric = symmetric
# create the encoders
self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
# create the queue
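        # the queue stores K negative keys as L2-normalized columns; it is a
        # registered buffer rather than a parameter, so it is saved and moved
        # with the model but never updated by the optimizer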
self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose
ptr = (ptr + batch_size) % self.K # move pointer
self.queue_ptr[0] = ptr
@torch.no_grad()
def _batch_shuffle_single_gpu(self, x):
"""
Batch shuffle, for making use of BatchNorm.
"""
# random shuffle index
idx_shuffle = torch.randperm(x.shape[0]).cuda()
# index for restoring
idx_unshuffle = torch.argsort(idx_shuffle)
return x[idx_shuffle], idx_unshuffle
@torch.no_grad()
def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
"""
Undo batch shuffle.
"""
return x[idx_unshuffle]
def contrastive_loss(self, im_q, im_k):
# compute query features
q = self.encoder_q(im_q) # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # normalize to unit length: NxC
# compute key features
with torch.no_grad(): # no gradient to keys
# shuffle for making use of BN
im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
k = self.encoder_k(im_k_) # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # normalize to unit length: NxC
# undo shuffle
k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
# compute logits
# Einstein sum is more intuitive
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
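        # matrix product of the queries (NxC) with the queue (CxK) of stored keys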
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.T
# labels: positive key indicators
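        # the positive logit sits in column 0, so the target class is 0 for
        # every sample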
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = nn.CrossEntropyLoss().cuda()(logits, labels)
return loss, q, k, logits, labels
def forward(self, im1, im2):
"""
Input:
            im1: a batch of query images
            im2: a batch of key images
Output:
loss, logits, labels
"""
# update the key encoder
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder()
# compute loss
        if self.symmetric:  # symmetric loss over both view orderings
            # logits/labels from the im1->im2 direction are kept for monitoring
            loss_12, q1, k2, logits, labels = self.contrastive_loss(im1, im2)
            loss_21, q2, k1, _, _ = self.contrastive_loss(im2, im1)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            loss, q, k, logits, labels = self.contrastive_loss(im1, im2)
self._dequeue_and_enqueue(k)
return loss, logits, labels
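

# Minimal usage sketch, assuming a CUDA device and a hypothetical DataLoader
# `train_loader` that yields two augmented views (im1, im2) per image
# (e.g. from a two-crop transform); hyperparameters below are illustrative.
if __name__ == '__main__':
    model = ModelMoCo(dim=128, K=4096, m=0.99, T=0.1,
                      arch='resnet18', bn_splits=8, symmetric=True).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.06,
                                momentum=0.9, weight_decay=5e-4)
    # train_loader = ...  # assumed to exist; not defined in this file
    # for im1, im2 in train_loader:
    #     im1, im2 = im1.cuda(non_blocking=True), im2.cuda(non_blocking=True)
    #     loss, logits, labels = model(im1, im2)
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()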