-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
129 lines (98 loc) · 4.45 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from pytorch_metric_learning import miners, losses
import sys
def binarize(T, nb_classes, device=None):
    """Return a float one-hot encoding of integer labels.

    Args:
        T: label tensor; it is flattened to 1-D and cast to int64.
        nb_classes: number of columns (classes) in the encoding.
        device: device of the result. When None, CUDA is used if available
            (the historical behavior); otherwise ``T``'s own device is used.

    Returns:
        Float tensor of shape (len(T), nb_classes) where row i has a single 1
        in column T[i]. Labels outside [0, nb_classes) produce an all-zero row.
        Unlike sklearn's ``label_binarize``, the result always has
        ``nb_classes`` columns, including when ``nb_classes == 2``.
    """
    if device is None:
        # The old implementation called .cuda() unconditionally and crashed on
        # CPU-only hosts; keep CUDA when present, otherwise stay on T's device.
        device = 'cuda' if torch.cuda.is_available() else T.device
    labels = T.detach().reshape(-1).long()
    in_range = (labels >= 0) & (labels < nb_classes)
    # F.one_hot rejects out-of-range indices, so clamp them into a valid column
    # and zero their rows afterwards.
    one_hot = F.one_hot(labels.clamp(0, max(nb_classes - 1, 0)), num_classes=nb_classes)
    one_hot = one_hot * in_range.unsqueeze(1)
    return one_hot.float().to(device)
def l2_norm(input):
    """L2-normalise each row of a 2-D tensor.

    The per-row norm is reshaped to (-1, 1) before dividing, so ``input`` is
    expected to be 2-D. 1e-12 is added to the squared norm before the square
    root, so an all-zero row comes back as zeros rather than NaN.

    Args:
        input: 2-D tensor whose rows are normalised.

    Returns:
        Tensor with the same shape as ``input``.
    """
    original_shape = input.size()
    squared_norm = input.pow(2).sum(1) + 1e-12
    row_norm = squared_norm.sqrt().view(-1, 1)
    normalised = input / row_norm.expand_as(input)
    return normalised.view(original_shape)
class Proxy_Anchor(torch.nn.Module):
    """Proxy-Anchor loss with one learnable proxy vector per class.

    Args:
        nb_classes: number of classes, i.e. number of proxies.
        sz_embed: dimensionality of the embeddings and proxies.
        mrg: margin subtracted from positive and added to negative cosine
            similarities before scaling.
        alpha: scale applied to the margined cosine similarities.
    """

    def __init__(self, nb_classes, sz_embed, mrg=0.1, alpha=32):
        torch.nn.Module.__init__(self)
        # The proxies used to be moved to CUDA unconditionally, which crashes
        # on CPU-only hosts. Keep CUDA whenever it is available. randn still
        # runs on the CPU first, so the RNG stream consumed matches the old code.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).to(device))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.mrg = mrg
        self.alpha = alpha

    @staticmethod
    def _log1p_masked_sum_exp(logits, mask):
        """Per column, return log(1 + sum of exp(logits) over rows where mask == 1).

        Computed as a logsumexp over {0} and the selected logits. This keeps
        the result finite where exp() of a large ``alpha``-scaled logit would
        overflow to inf. A column with no selected rows gives log(1) = 0.
        """
        selected = logits.masked_fill(mask != 1, float('-inf'))
        # The explicit zero row supplies the "1 +" term and keeps logsumexp
        # finite (and its gradient NaN-free) even when a column is fully masked.
        zero_row = selected.new_zeros(1, selected.size(1))
        return torch.logsumexp(torch.cat([zero_row, selected], dim=0), dim=0)

    def forward(self, X, T):
        """Compute the Proxy-Anchor loss for one batch.

        Args:
            X: batch embeddings; the last dim must equal ``sz_embed``
                (presumably shaped (batch, sz_embed)).
            T: integer class labels of the batch.

        Returns:
            Scalar loss tensor: the mean positive term over proxies that have
            positives in the batch, plus the negative term averaged over all
            ``nb_classes`` proxies.
        """
        P = self.proxies
        cos = F.linear(l2_norm(X), l2_norm(P))  # Calculate cosine similarity
        # Keep the label masks on the same device as the similarity matrix.
        P_one_hot = binarize(T=T, nb_classes=self.nb_classes).to(cos.device)
        N_one_hot = 1 - P_one_hot

        # Proxies with at least one positive sample in the batch; clamped to 1
        # so the division below is always defined.
        num_valid_proxies = max(int((P_one_hot.sum(dim=0) != 0).sum().item()), 1)

        pos_logits = -self.alpha * (cos - self.mrg)
        neg_logits = self.alpha * (cos + self.mrg)
        pos_term = self._log1p_masked_sum_exp(pos_logits, P_one_hot).sum() / num_valid_proxies
        neg_term = self._log1p_masked_sum_exp(neg_logits, N_one_hot).sum() / self.nb_classes
        loss = pos_term + neg_term
        return loss
# We use the PyTorch Metric Learning library for the following code.
# Please refer to "https://github.com/KevinMusgrave/pytorch-metric-learning" for details.
class Proxy_NCA(torch.nn.Module):
    """Proxy-NCA loss delegated to pytorch-metric-learning's ``ProxyNCALoss``.

    Args:
        nb_classes: number of classes, forwarded as ``num_classes``.
        sz_embed: embedding dimensionality, forwarded as ``embedding_size``.
        scale: forwarded as ``softmax_scale``.
    """

    def __init__(self, nb_classes, sz_embed, scale=32):
        super(Proxy_NCA, self).__init__()
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.scale = scale
        # Move the wrapped loss (and any parameters it holds) to CUDA only when
        # it exists; the previous unconditional .cuda() failed on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.loss_func = losses.ProxyNCALoss(
            num_classes=self.nb_classes,
            embedding_size=self.sz_embed,
            softmax_scale=self.scale,
        ).to(device)

    def forward(self, embeddings, labels):
        """Return the library loss for ``embeddings`` and their ``labels``."""
        loss = self.loss_func(embeddings, labels)
        return loss
class MultiSimilarityLoss(torch.nn.Module):
    """Multi-Similarity loss plus its miner, from pytorch-metric-learning.

    All hyper-parameters are keyword-only and default to the values that were
    previously hard-coded, so ``MultiSimilarityLoss()`` behaves as before.

    Args:
        thresh: forwarded to the loss as ``base``.
        epsilon: forwarded to ``MultiSimilarityMiner`` as ``epsilon``.
        scale_pos: forwarded to the loss as ``alpha``.
        scale_neg: forwarded to the loss as ``beta``.
    """

    def __init__(self, *, thresh=0.5, epsilon=0.1, scale_pos=2, scale_neg=50):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = thresh
        self.epsilon = epsilon
        self.scale_pos = scale_pos
        self.scale_neg = scale_neg
        self.miner = miners.MultiSimilarityMiner(epsilon=self.epsilon)
        # Keyword arguments make the (scale_pos, scale_neg, thresh) ->
        # (alpha, beta, base) mapping explicit instead of positional.
        self.loss_func = losses.MultiSimilarityLoss(
            alpha=self.scale_pos, beta=self.scale_neg, base=self.thresh
        )

    def forward(self, embeddings, labels):
        """Mine pairs with the miner, then evaluate the loss on those pairs."""
        hard_pairs = self.miner(embeddings, labels)
        loss = self.loss_func(embeddings, labels, hard_pairs)
        return loss
class ContrastiveLoss(nn.Module):
    """Contrastive loss from pytorch-metric-learning.

    Args:
        margin: forwarded to the library loss as ``neg_margin``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, margin=0.5, **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_func = losses.ContrastiveLoss(neg_margin=margin)

    def forward(self, embeddings, labels):
        """Return the library contrastive loss over all pairs in the batch."""
        return self.loss_func(embeddings, labels)
class TripletLoss(nn.Module):
    """Triplet margin loss on semi-hard triplets, from pytorch-metric-learning.

    Args:
        margin: shared by the ``TripletMarginMiner`` and ``TripletMarginLoss``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, margin=0.1, **kwargs):
        super().__init__()
        self.margin = margin
        self.miner = miners.TripletMarginMiner(margin, type_of_triplets='semihard')
        self.loss_func = losses.TripletMarginLoss(margin=margin)

    def forward(self, embeddings, labels):
        """Mine semi-hard triplets, then compute the loss on them only."""
        mined_triplets = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined_triplets)
class NPairLoss(nn.Module):
    """N-pairs loss from pytorch-metric-learning with ``normalize_embeddings=False``.

    Args:
        l2_reg: forwarded to the library loss as ``l2_reg_weight``.
    """

    def __init__(self, l2_reg=0):
        super().__init__()
        self.l2_reg = l2_reg
        # NOTE(review): ``l2_reg_weight`` and ``normalize_embeddings`` are
        # keyword arguments of older pytorch-metric-learning releases; confirm
        # that the installed version still accepts them.
        self.loss_func = losses.NPairsLoss(l2_reg_weight=l2_reg, normalize_embeddings=False)

    def forward(self, embeddings, labels):
        """Return the library N-pairs loss for the batch."""
        return self.loss_func(embeddings, labels)