resnet_moco.py
import torch.nn as nn
from functools import partial
import torchvision.models as models


# SplitBatchNorm: simulate the multi-GPU behavior of BatchNorm on a single GPU
# by splitting along the batch dimension.
# Implementation adapted from https://github.com/davidcpage/cifar10-fast/blob/master/torch_backend.py
class SplitBatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Reshape to (N / num_splits, C * num_splits, H, W) so each split
            # is normalized with its own batch statistics, emulating BatchNorm
            # run independently on num_splits GPUs. N must be divisible by num_splits.
            running_mean_split = self.running_mean.repeat(self.num_splits)
            running_var_split = self.running_var.repeat(self.num_splits)
            outcome = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            # Average the per-split statistics back into the shared running buffers.
            self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0))
            self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0))
            return outcome
        else:
            # Evaluation mode: plain BatchNorm using the aggregated running stats.
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
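
# --- Usage sketch (illustrative addition, not part of the original training code) ---
# A minimal shape check for SplitBatchNorm in training mode: with num_splits=4,
# a batch of 64 is normalized as 4 virtual per-GPU batches of 16 samples each.
# The helper name (_split_bn_demo) and the sizes below are hypothetical.
def _split_bn_demo():
    import torch
    bn = SplitBatchNorm(num_features=8, num_splits=4).train()
    x = torch.randn(64, 8, 16, 16)  # batch size (64) must be divisible by num_splits (4)
    y = bn(x)
    assert y.shape == x.shape       # normalization preserves the input shape
    return y
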
class ModelBase(nn.Module):
    """
    Common CIFAR ResNet recipe.
    Compared with the ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, stride=1
    (ii) removes pool1
    (That stem surgery is kept in the commented-out block below; the active
    code uses the stock torchvision ResNet with an MLP projection head.)
    """
    def __init__(self, feature_dim=128, arch=None, bn_splits=16):
        super(ModelBase, self).__init__()

        # Use SplitBatchNorm to emulate multi-GPU BatchNorm statistics.
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        resnet_arch = models.__dict__[arch]
        self.net = resnet_arch(pretrained=False, num_classes=feature_dim, norm_layer=norm_layer)

        dim_mlp = self.net.fc.in_features  # 512 for ResNet-18
        # Replace the final fc with a 2-layer MLP projection head.
        self.net.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.net.fc)

        # self.net = []
        # for name, module in net.named_children():
        #     if name == 'conv1':
        #         module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        #     if isinstance(module, nn.MaxPool2d):
        #         continue
        #     if isinstance(module, nn.Linear):
        #         self.net.append(nn.Flatten(1))
        #     self.net.append(module)
        # self.net = nn.Sequential(*self.net)

    def forward(self, x):
        x = self.net(x)
        # note: not normalized here
        return x
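
# --- Smoke test (illustrative addition; the arch/size values are assumptions) ---
# Builds the encoder and pushes a CIFAR-sized batch through it to confirm the
# projection head emits feature_dim-dimensional vectors.
if __name__ == "__main__":
    import torch
    model = ModelBase(feature_dim=128, arch='resnet18', bn_splits=8)
    x = torch.randn(16, 3, 32, 32)  # CIFAR-sized inputs; 16 is divisible by bn_splits=8
    feats = model(x)
    print(feats.shape)              # expected: torch.Size([16, 128])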