export_fp8_attention.py
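"""Export an fp8 (float8_e4m3fnuz) scaled dot product attention module to MLIR
via the shark_turbine AOT exporter, wrapping the fp8 Q/K/V tensors and their
scales as sharktank PlanarQuantizedTensors."""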
import torch
import torch.nn

import sharktank.ops as ops
from shark_turbine import aot
from sharktank.types import PlanarQuantizedTensor
from sharktank.types.layouts import TensorScaledLayout

def make_q_tensor(tensor, scale):
    """Wrap a raw quantized tensor and its scale in a PlanarQuantizedTensor."""
    return PlanarQuantizedTensor(
        name="qq",
        shape=tensor.shape,
        layout=TensorScaledLayout(
            shape=tensor.shape,
            qs=tensor,
            d=torch.scalar_tensor(scale, dtype=torch.float16),
            dtype=torch.float16,
        ),
    )

class AttentionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        s: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qs: torch.Tensor,
        ks: torch.Tensor,
        vs: torch.Tensor,
    ):
        # Re-wrap the raw fp8 values with their scales before attention.
        q = make_q_tensor(q, qs)
        k = make_q_tensor(k, ks)
        v = make_q_tensor(v, vs)
        return ops.scaled_dot_product_attention(q, k, v, a=None)

# Example fp8 inputs: (batch=1, heads=1, seq_len=4096, head_dim=64) plus scalar scales.
q = torch.zeros((1, 1, 4096, 64), dtype=torch.float8_e4m3fnuz, device="cuda:0")
k = torch.zeros((1, 1, 4096, 64), dtype=torch.float8_e4m3fnuz, device="cuda:0")
v = torch.zeros((1, 1, 4096, 64), dtype=torch.float8_e4m3fnuz, device="cuda:0")
s = torch.zeros((), dtype=torch.float32, device="cuda:0")
qs = torch.zeros((), dtype=torch.float32, device="cuda:0")
ks = torch.zeros((), dtype=torch.float32, device="cuda:0")
vs = torch.zeros((), dtype=torch.float32, device="cuda:0")

inputs = {
    "q": q,
    "k": k,
    "v": v,
    "s": s,
    "qs": qs,
    "ks": ks,
    "vs": vs,
}

if __name__ == "__main__":
    mdl = AttentionModel()
    # Temporary: Need a dedicated exporter.
    output = aot.export(
        mdl,
        kwargs=inputs,
    )
    output.print_readable()
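
    # A minimal follow-up sketch, assuming the ExportOutput returned by
    # aot.export also exposes save_mlir(), as in the shark_turbine /
    # iree-turbine AOT API; the output file name here is illustrative only.
    output.save_mlir("fp8_attention.mlir")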