-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSelfAttention.py
33 lines (28 loc) · 1.36 KB
/
SelfAttention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
    """Single head of causal (masked) scaled dot-product self-attention.

    Each position is projected to a key, query and value of width
    ``head_size``; attention weights are restricted to the current and
    earlier positions so no information flows from the future.

    Args:
        n_embd: Size of the last dimension of the input (embedding width).
        head_size: Output width of the key/query/value projections.
        mask_size: Optional side length of a lower-triangular mask that is
            precomputed and stored as the ``tril`` buffer. When ``None``
            (the default) no mask is stored and one is built on demand in
            ``forward``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, head_size, mask_size=None, dropout=0.2):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Storing the mask as a buffer keeps it on the module's state (and
        # device) so the common case does not rebuild it on every call.
        # torch.ones(None, None) is invalid, so with no mask_size we register
        # an empty buffer and fall back to building the mask in forward().
        if mask_size is None:
            self.register_buffer('tril', None)
        else:
            self.register_buffer('tril', torch.tril(torch.ones(mask_size, mask_size)))

    def _causal_mask(self, T, device):
        """Return a (T, T) lower-triangular mask; non-zero means "may attend"."""
        if self.tril is not None and T <= self.tril.shape[-1]:
            return self.tril[:T, :T]
        # No precomputed mask, or the sequence is longer than the stored one:
        # slicing would yield a mask smaller than (T, T), which either fails
        # to broadcast or (for a 1x1 mask) silently disables masking.
        return torch.tril(torch.ones(T, T, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Input of size (batch, time-step, channels).

        Returns:
            Tensor of size (batch, time-step, head size).
        """
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)

        # Attention scores ("affinities"): (q @ k^T) / sqrt(head_size).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        # Future positions get -inf so softmax assigns them zero weight.
        mask = self._causal_mask(T, x.device)
        wei = wei.masked_fill(mask == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)

        # Weighted aggregation of the values.
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out