-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattentions.py
304 lines (206 loc) · 11.6 KB
/
attentions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encoders import AbsolutePositionalEncoder, RelativePositionalEncoder, T5RelativePositionalEncoder
# Scaled Dot Product Attention using Absolute Positional Encoding
class ScaledDotProductAttention(nn.Module):
    """Plain scaled dot-product attention over 3-D batched tensors.

    Computes ``softmax(Q @ K^T / sqrt(emb_dim)) @ V``. The module itself adds
    no positional term to the scores.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Divisor for the raw scores: square root of the query/key width.
        self.scaling_factor = torch.tensor(emb_dim).sqrt()

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: tensor of shape (batch, query length, emb_dim).
            key: tensor of shape (batch, key length, emb_dim).
            value: tensor of shape (batch, key length, value width).
            mask: optional boolean tensor of shape
                (batch, query length, key length); positions set to True are
                excluded from attention.

        Returns:
            Tuple ``(output, attn_score)`` where ``output`` has shape
            (batch, query length, value width) and ``attn_score`` holds the
            softmax weights of shape (batch, query length, key length).
        """
        # Raw similarity (e), already divided by the scaling factor.
        scores = query.bmm(key.transpose(1, 2)).div(self.scaling_factor)

        # A huge negative score becomes a zero weight after the softmax.
        if mask is not None:
            scores.masked_fill_(mask, -1e18)

        # Normalised attention weights (alpha).
        weights = scores.softmax(dim=-1)

        # Weighted sum of the values (z).
        return weights.bmm(value), weights
# Scaled Dot Product Attention using Relative Positional Encoding
class RelativeScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with additive relative-position terms.

    A position-dependent key term ``a_key`` is added to the scores, and a
    position-dependent value term ``a_value`` is added to the output.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Divisor for the raw scores: square root of the query/key width.
        self.scaling_factor = torch.tensor(emb_dim).sqrt()

    def forward(self, query, key, value, a_key, a_value, mask=None):
        """Attend from ``query`` over ``key``/``value`` with relative terms.

        Args:
            query: tensor of shape (batch, query length, emb_dim).
            key: tensor of shape (batch, key length, emb_dim).
            value: tensor of shape (batch, key length, value width).
            a_key: relative key embeddings of shape
                (query length, key length, emb_dim).
            a_value: relative value embeddings of shape
                (query length, key length, value width).
            mask: optional boolean tensor of shape
                (batch, query length, key length); True positions are
                excluded from attention.

        Returns:
            Tuple ``(output, attn_score)``: the attended values of shape
            (batch, query length, value width) and the softmax weights of
            shape (batch, query length, key length).
        """
        # Content-to-content scores.
        content_scores = query.bmm(key.transpose(1, 2))

        # Content-to-position scores. bmm batches over its first axis, so the
        # query position is moved to the front to pair every query row with
        # its own slice of a_key, then moved back.
        query_major = query.transpose(0, 1).contiguous()
        position_scores = query_major.bmm(a_key.transpose(1, 2)).transpose(0, 1)

        scores = (content_scores + position_scores) / self.scaling_factor

        # A huge negative score becomes a zero weight after the softmax.
        if mask is not None:
            scores.masked_fill_(mask, -1e18)

        # Normalised attention weights (alpha).
        weights = scores.softmax(dim=-1)

        # Weighted values plus the position-dependent value term (z).
        content_out = weights.bmm(value)
        weights_major = weights.transpose(0, 1).contiguous()
        position_out = weights_major.bmm(a_value).transpose(0, 1)

        return content_out + position_out, weights
# Scaled Dot Product Attention using T5 Relative Positional Encoding
class T5ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with an additive T5-style relative bias.

    The bias is added to the scaled scores before masking and softmax.
    """

    def __init__(self, emb_dim):
        super(T5ScaledDotProductAttention, self).__init__()
        # scaling factor 1 / sqrt(dimension of queries and keys)
        self.scaling_factor = torch.sqrt(torch.tensor(emb_dim))

    def forward(self, query, key, value, relative_bias, mask=None):
        """Attend from ``query`` over ``key``/``value`` with a relative bias.

        Args:
            query: tensor of shape (rows, query length, emb_dim), where
                ``rows`` is the batch size or, when called from a multi-head
                wrapper, batch size * number of heads.
            key: tensor of shape (rows, key length, emb_dim).
            value: tensor of shape (rows, key length, value width).
            relative_bias: tensor of shape
                (query length, key length, number of heads).
            mask: optional boolean tensor of shape
                (rows, query length, key length); True positions are
                excluded from attention.

        Returns:
            Tuple ``(output, attn_score)``: attended values of shape
            (rows, query length, value width) and softmax weights of shape
            (rows, query length, key length).
        """
        # Scaled score of the matrix multiplication of query and key (e)
        scores = torch.bmm(query, key.transpose(1, 2)) / self.scaling_factor

        # (query length, key length, heads) -> (heads, query length, key length)
        bias = relative_bias.permute(2, 0, 1)

        # With several heads AND several samples the leading axis of the
        # scores is batch * heads, which a (heads, ...) bias cannot broadcast
        # against. Tile the bias once per sample. This assumes the head-minor
        # row order (row = sample * heads + head) that
        # MultiHeadAttention.reshape_to_ScaledDotProductAttention produces.
        rows, bias_rows = scores.size(0), bias.size(0)
        if bias_rows not in (1, rows) and rows % bias_rows == 0:
            bias = bias.repeat(rows // bias_rows, 1, 1)

        attn_score = scores + bias

        # Masking (optional): a huge negative score becomes a zero weight.
        if mask is not None:
            attn_score.masked_fill_(mask, -1e18)

        # Softmax of the scaled score (alpha)
        attn_score = F.softmax(attn_score, -1)

        # Matrix multiplication of the attention weights and value (z)
        output = torch.bmm(attn_score, value)

        return output, attn_score
# Multi-Head Attention supporting absolute ("abs"), relative ("rel") and T5 relative ("t5") positional encoding
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a selectable positional-encoding variant.

    ``positional_encoding`` chooses the per-head attention kernel:
    ``"abs"`` (plain scaled dot-product), ``"rel"`` (relative key/value
    embeddings produced by ``RelativePositionalEncoder``) or ``"t5"``
    (additive relative bias supplied by the caller to ``forward``).
    """

    def __init__(self, emb_dim, num_heads, positional_encoding="abs", dropout_rate=0.1):
        """Build the projections and the attention kernel.

        Args:
            emb_dim: width of the input (and output) embeddings.
            num_heads: number of attention heads; each head works on
                ``int(emb_dim / num_heads)`` features.
            positional_encoding: one of ``"abs"``, ``"rel"`` or ``"t5"``.
            dropout_rate: dropout probability applied to the final output.

        Raises:
            ValueError: if ``positional_encoding`` is not a supported value.
        """
        super(MultiHeadAttention, self).__init__()

        # Fail at construction time instead of with an unbound local inside
        # forward() when no attention kernel matches.
        if positional_encoding not in ("abs", "rel", "t5"):
            raise ValueError(
                f"positional_encoding must be 'abs', 'rel' or 't5', got {positional_encoding!r}"
            )

        self.head_dim = int(emb_dim / num_heads)
        self.num_heads = num_heads
        self.positional_encoding = positional_encoding
        self.dropout = nn.Dropout(p=dropout_rate)

        # One fused linear layer (head dimension x number of heads) per q, k
        # and v instead of a separate small layer for every head.
        inner_dim = self.head_dim * num_heads
        self.query_proj = nn.Linear(emb_dim, inner_dim)
        self.key_proj = nn.Linear(emb_dim, inner_dim)
        self.value_proj = nn.Linear(emb_dim, inner_dim)
        # Maps the concatenated heads back to the embedding width, so the
        # input side is inner_dim (identical to emb_dim when it divides evenly).
        self.out_proj = nn.Linear(inner_dim, emb_dim)

        if positional_encoding == "abs":
            self.scaled_dot_attn = ScaledDotProductAttention(self.head_dim)
        elif positional_encoding == "rel":
            self.relative_scaled_dot_attn = RelativeScaledDotProductAttention(self.head_dim)
            self.relative_position_k = RelativePositionalEncoder(self.head_dim)
            self.relative_position_v = RelativePositionalEncoder(self.head_dim)
        else:  # "t5"
            self.t5_scaled_dot_attn = T5ScaledDotProductAttention(self.head_dim)

    def reshape_from_feed_forward(self, batch_size, _tensor):
        """Split the last axis into (number of heads, head dimension)."""
        return _tensor.view(batch_size, -1, self.num_heads, self.head_dim)

    def reshape_to_ScaledDotProductAttention(self, batch_size, _tensor):
        """Fold the heads into the leading axis for the 3-D attention kernels.

        Row ``b * num_heads + h`` of the result holds head ``h`` of sample ``b``.
        """
        # before shape: (batch size, input length, number of heads, head dimension)
        # after shape: (batch size, number of heads, input length, head dimension)
        _tensor = _tensor.permute(0, 2, 1, 3)

        # reshape to feed the tensor to ScaledDotProductAttention
        return _tensor.contiguous().view(batch_size * self.num_heads, -1, self.head_dim)

    def reshape_to_concat(self, batch_size, _tensor):
        """Concatenate the heads back into one feature axis."""
        # before shape: (batch size, number of heads, input length, head dimension)
        # after shape: (batch size, input length, number of heads, head dimension)
        _tensor = _tensor.permute(0, 2, 1, 3)
        return _tensor.contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

    def forward(self, query, key, value, mask=None, relative_bias=None, is_dropout=True):
        """Run multi-head attention.

        Args:
            query, key, value: tensors of shape
                (batch size, input length, embedding dimension).
            mask: optional boolean tensor of shape
                (batch size, query length, key length); True positions are
                excluded from attention. It is shared by all heads.
            relative_bias: bias tensor forwarded to the T5 attention kernel;
                required when ``positional_encoding == "t5"``, ignored
                otherwise.
            is_dropout: when True the output goes through the dropout layer.

        Returns:
            Tuple ``(output, attn_score)``: ``output`` has shape
            (batch size, query length, embedding dimension); ``attn_score`` is
            whatever the attention kernel returned, with heads folded into the
            leading axis (batch size * number of heads, query length, key length).

        Raises:
            ValueError: if the "t5" variant is used without ``relative_bias``.
        """
        batch_size = query.size()[0]

        # Linear projections followed by the split into heads:
        # (batch, length, emb) -> (batch * heads, length, head dimension)
        query = self.reshape_to_ScaledDotProductAttention(
            batch_size, self.reshape_from_feed_forward(batch_size, self.query_proj(query))
        )
        key = self.reshape_to_ScaledDotProductAttention(
            batch_size, self.reshape_from_feed_forward(batch_size, self.key_proj(key))
        )
        value = self.reshape_to_ScaledDotProductAttention(
            batch_size, self.reshape_from_feed_forward(batch_size, self.value_proj(value))
        )

        # The kernels see 3-D scores of shape (batch * heads, Lq, Lk), so the
        # mask must be 3-D as well. Repeating every sample's mask num_heads
        # times in a row matches the head-minor row order produced above.
        if mask is not None:
            mask = mask.repeat_interleave(self.num_heads, dim=0)

        if self.positional_encoding == "abs":
            output, attn_score = self.scaled_dot_attn(query, key, value, mask)
        elif self.positional_encoding == "rel":
            seq_len_query = query.size()[1]
            seq_len_key = key.size()[1]
            seq_len_value = value.size()[1]
            a_key = self.relative_position_k(seq_len_query, seq_len_key)
            a_value = self.relative_position_v(seq_len_query, seq_len_value)
            output, attn_score = self.relative_scaled_dot_attn(
                query, key, value, a_key, a_value, mask
            )
        else:  # "t5"
            if relative_bias is None:
                raise ValueError("relative_bias is required when positional_encoding is 't5'")
            output, attn_score = self.t5_scaled_dot_attn(query, key, value, relative_bias, mask)

        # Unfold the leading axis in the same (batch, heads) order it was
        # folded in; viewing it as (heads, batch) would mix samples and heads.
        # shape: (batch size, number of heads, input length, head dimension)
        output = output.view(batch_size, self.num_heads, -1, self.head_dim)

        # shape: (batch size, input length, number of heads * head dimension)
        output = self.reshape_to_concat(batch_size, output)

        # final linear projection
        output = self.out_proj(output)

        if is_dropout:
            output = self.dropout(output)

        return output, attn_score
def get_candidate_heads(emb_dim, _num_heads):
    """Pick a fallback head count that divides ``emb_dim`` evenly.

    Collects the divisors of ``emb_dim`` that are smaller than ``emb_dim``
    itself, in ascending order, and returns the one at the middle index.

    Args:
        emb_dim: embedding width the head count has to divide.
        _num_heads: the head count originally requested; accepted for
            call-site symmetry but not used in the choice.

    Returns:
        The middle divisor, or 1 when ``emb_dim`` has no divisor below itself
        (``emb_dim`` <= 1), where indexing the empty list used to raise
        ``IndexError``.
    """
    divisors = [candidate for candidate in range(1, emb_dim) if emb_dim % candidate == 0]
    if not divisors:
        # A single head always works.
        return 1
    return divisors[len(divisors) // 2]
def get_attn_output(input_embedding, selected_attn, selected_pe, _num_heads):
    """Run one freshly initialised attention layer as self-attention.

    The absolute positional encoding is always added to ``input_embedding``
    first; the result is used as query, key and value.

    Args:
        input_embedding: tensor whose axis 0 is the batch, axis 1 the sequence
            and last axis the embedding width.
        selected_attn: ``"scaleddotproduct"`` or ``"multihead"``.
        selected_pe: ``"abs"``, ``"rel"`` or ``"t5"``; selects the attention
            variant.
        _num_heads: requested head count for ``"multihead"``. When it does not
            divide the embedding width, ``get_candidate_heads`` picks a
            replacement.

    Returns:
        Tuple ``(output, attn_score)`` from the selected attention module, or
        ``None`` when ``selected_attn`` / ``selected_pe`` is not recognised.
    """
    emb_dim = input_embedding.size()[-1]

    # input embedding + positional encoding
    positional_encoder = AbsolutePositionalEncoder(emb_dim)
    input_embedding = input_embedding + positional_encoder(input_embedding)

    query = key = value = input_embedding

    batch_size = query.size()[0]
    seq_len_query = query.size()[1]
    seq_len_key = key.size()[1]
    seq_len_value = value.size()[1]

    if selected_pe not in ("abs", "rel", "t5"):
        return None

    if selected_attn == "scaleddotproduct":
        if selected_pe == "abs":
            model = ScaledDotProductAttention(emb_dim)
            return model(query, key, value)

        if selected_pe == "rel":
            relative_position_k = RelativePositionalEncoder(emb_dim)
            relative_position_v = RelativePositionalEncoder(emb_dim)
            a_key = relative_position_k(seq_len_query, seq_len_key)
            a_value = relative_position_v(seq_len_query, seq_len_value)
            model = RelativeScaledDotProductAttention(emb_dim)
            return model(query, key, value, a_key, a_value)

        # "t5": a single attention map, hence a single-head bias
        relative_position_bias = T5RelativePositionalEncoder(1)
        relative_bias = relative_position_bias(seq_len_query, seq_len_key)
        model = T5ScaledDotProductAttention(emb_dim)
        return model(query, key, value, relative_bias)

    if selected_attn == "multihead":
        # Fall back to a head count that divides the embedding width.
        if emb_dim % _num_heads != 0:
            num_heads = get_candidate_heads(emb_dim, _num_heads)
        else:
            num_heads = _num_heads

        if selected_pe == "t5":
            # The bias needs one slice per head actually used by the model,
            # i.e. the adjusted num_heads rather than the requested one.
            relative_position_bias = T5RelativePositionalEncoder(num_heads)
            relative_bias = relative_position_bias(seq_len_query, seq_len_key)
            if batch_size > 1:
                # NOTE(review): assumes the encoder returns
                # (query length, key length, heads) -- confirm against
                # positional_encoders. MultiHeadAttention folds (batch, heads)
                # into one head-minor leading axis, so the bias is tiled along
                # its head axis to supply one slice per (sample, head) row.
                relative_bias = relative_bias.repeat(1, 1, batch_size)
            model = MultiHeadAttention(emb_dim, num_heads, positional_encoding=selected_pe)
            return model(query, key, value, relative_bias=relative_bias)

        # Pass the variant through; omitting it made "rel" silently run the
        # default "abs" attention.
        model = MultiHeadAttention(emb_dim, num_heads, positional_encoding=selected_pe)
        return model(query, key, value)

    return None