From 8300bc8f89af27c6c782d36ecb72fa419d4e7875 Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 1 Feb 2025 16:13:49 -0800 Subject: [PATCH 01/29] restore custom matmul kernel --- sharktank/sharktank/ops/qlinear_impls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f88684273..f4f7ac0ca 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -50,10 +50,10 @@ def qlinear_tensor_scaled( # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if x_layout.qs.dtype == torch.float8_e4m3fnuz: - # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) - else: + if ( + x_layout.qs.dtype != torch.float8_e4m3fnuz + or weight_layout.qs.dtype != torch.float8_e4m3fnuz + ): return NotImplemented # Bias. From a98a332658700f1e4a5c576d96c1c98088bab8b8 Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 1 Feb 2025 18:33:28 -0800 Subject: [PATCH 02/29] not mergeable as-is --- sharktank/sharktank/kernels/batch_matmul_transpose_b.py | 9 ++++----- sharktank/sharktank/layers/linear.py | 4 ++-- sharktank/sharktank/ops/qlinear_impls.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 21f9e9ed4..a55d6654b 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -8,7 +8,7 @@ import torch -from iree.compiler.ir import IntegerType +from iree.compiler.ir import IntegerType, FloatType __all__ = [ "batch_matmul_transpose_b", @@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection): lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})", ) # Shape batch, m, n - c_desc = ksel.return_new_tensor( - [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype - ) + c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) specialize_all_known_dims(c_desc) @@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): result_desc = ksel.result_descs[0] # Generate specialization signature and types. - a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type) + a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) + accum_type = FloatType.parse("f32") spec_sig = f"L{a_ident}_R{b_ident}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..dae126767 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -85,8 +85,8 @@ def forward(self, x): # We can truncate to fp16 in iree, so we do a cast here # to account for this in the IR. This is may not be the right # level to do this, but for now its here. 
- if not isinstance(y, QuantizedTensor): - if y.dtype == torch.float8_e4m3fnuz: + if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor): + if x.unpack().qs.dtype == torch.float8_e4m3fnuz: y = ops.to(y, torch.bfloat16) return y if qdq_output is not None: diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f4f7ac0ca..df6d74b15 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -93,6 +93,7 @@ def qlinear_tensor_scaled( # Fall back to automatic fusion based on integer, high precision matmul. y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype) + return y_qs # Offset correction. By applying the offset correction in post, it is # set up to fuse with its consumer, which is already doing additional @@ -187,9 +188,8 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b( - lhs.to(accum_dtype), rhs.to(accum_dtype) - ) + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs) + return y_qs # Squeeze the batch dimension to maintain shape parity with other # layers. if len(y_qs.shape) > 2: From 80fee981b8d6fc096390530640e84f0da5e68f4d Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 4 Feb 2025 02:05:36 +0000 Subject: [PATCH 03/29] Make batch_matmul_transpose_b accept accumulation dtype This argument can be skipped. Then the result dtype is inferred from the arguments. This requires https://github.com/iree-org/iree-turbine/pull/451 --- .../kernels/batch_matmul_transpose_b.py | 39 ++++++++++++--- sharktank/sharktank/ops/qlinear_impls.py | 4 +- .../kernels/batch_matmul_transpose_b_test.py | 50 ++++++++++++++++++- 3 files changed, 82 insertions(+), 11 deletions(-) diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index a55d6654b..179cb4679 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -7,16 +7,36 @@ from sharktank.kernels.base import * import torch +from typing import cast, Optional -from iree.compiler.ir import IntegerType, FloatType +from iree.compiler.ir import IntegerType, Type +from iree.turbine.support.conversions import ( + TORCH_DTYPE_TO_IREE_TYPE_ASM, + IREE_TYPE_ASM_TO_TORCH_DTYPE, +) +from iree.turbine.runtime.op_reg import AttrArg __all__ = [ "batch_matmul_transpose_b", ] +def batch_matmul_transpose_b( + lhs: torch.Tensor, + rhs: torch.Tensor, + /, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if accum_dtype is None: + accum_dtype = lhs.dtype + return _batch_matmul_transpose_b( + lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] + ) + + @CustomOp.register(library=LIBRARY) -class batch_matmul_transpose_b(CustomOp): +class _batch_matmul_transpose_b(CustomOp): """Generic block scaled matmul with transposed RHS. The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be @@ -25,11 +45,14 @@ class batch_matmul_transpose_b(CustomOp): The kernel will be specialized for all values of N, K and LHS dtype. 
""" - signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)" + signature = ( + "batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)" + ) def select(self, ksel: KernelSelection): lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K] rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K] + accum_type_attr = ksel.attr_str(2) # Rank check. torch._check( @@ -59,7 +82,10 @@ def select(self, ksel: KernelSelection): lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})", ) # Shape batch, m, n - c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32) + c_desc = ksel.return_new_tensor( + [lhs_batch, lhs_m, rhs_n], + dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v], + ) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) specialize_all_known_dims(c_desc) @@ -72,13 +98,14 @@ def select(self, ksel: KernelSelection): def generate(self, ksel: KernelSelection, kb: KernelBuilder): lhs = kb.arg_value(0) rhs = kb.arg_value(1) + accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v result_desc = ksel.result_descs[0] # Generate specialization signature and types. a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) - accum_type = FloatType.parse("f32") - spec_sig = f"L{a_ident}_R{b_ident}" + accum_type = Type.parse(accum_type_str) + spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" cst_zero = "0" if IntegerType.isinstance(accum_type) else "0." diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index df6d74b15..1b8c12d57 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -93,7 +93,6 @@ def qlinear_tensor_scaled( # Fall back to automatic fusion based on integer, high precision matmul. y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype) - return y_qs # Offset correction. By applying the offset correction in post, it is # set up to fuse with its consumer, which is already doing additional @@ -188,8 +187,7 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b(lhs, rhs) - return y_qs + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) # Squeeze the batch dimension to maintain shape parity with other # layers. 
if len(y_qs.shape) > 2: diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 208d54782..b2c8d45f2 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -10,11 +10,12 @@ import unittest from parameterized import parameterized - +import pytest import torch from iree.turbine import aot from sharktank import kernels +from sharktank.utils.testing import skip class batch_matmul_transpose_b_test(unittest.TestCase): @@ -40,6 +41,25 @@ def testBS32(self, atol, rtol): ref = torch.matmul(a, bT) torch.testing.assert_close(result, ref, atol=atol, rtol=rtol) + @pytest.mark.xfail( + reason="""Does not compile for llvm-cpu with + :0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>' + :0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32> + """ + ) + def testArgF8AccumF32(self): + arg_dtype = torch.float8_e4m3fnuz + a = torch.rand([3, 4, 6]).to(arg_dtype) + b = torch.rand([3, 5, 6]).to(arg_dtype) + accum_dtype = torch.float32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) + def testExportStaticDims(self): class MyModule(torch.nn.Module): def forward(self, a, b): @@ -57,7 +77,33 @@ def forward(self, a, b): output = aot.export(ep) output.verify() asm = str(output.mlir_module) - self.assertIn("@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32", asm) + self.assertIn( + "@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32_i32", asm + ) + + def testExportArgF8AccumF32(self): + accum_dtype = torch.float32 + arg_type = torch.float8_e4m3fnuz + + class MyModule(torch.nn.Module): + def forward(self, a, b): + return kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + (torch.rand([4, 16, 2])).to(arg_type), + (torch.rand([4, 8, 2])).to(arg_type), + ), + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn( + "@sharktank_batch_matmul_transpose_b_L4x16x2xf8E4M3FNUZ_R4x8x2xf8E4M3FNUZ_f32", + asm, + ) if __name__ == "__main__": From 9ff020e1adb56b6c56acb0a44e9011367ff3aab3 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 4 Feb 2025 15:44:56 +0000 Subject: [PATCH 04/29] Merge batch_matmul_transpose_b export tests into 1 --- .../kernels/batch_matmul_transpose_b_test.py | 48 ++++++++----------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index b2c8d45f2..a9a490a8e 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -14,6 +14,7 @@ import torch from iree.turbine import aot +from iree.turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM from sharktank import kernels from sharktank.utils.testing import skip @@ -60,31 +61,15 @@ def testArgF8AccumF32(self): ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) 
torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) - def testExportStaticDims(self): - class MyModule(torch.nn.Module): - def forward(self, a, b): - return kernels.batch_matmul_transpose_b(a, b) - - mod = MyModule() - dtype = torch.int32 - ep = torch.export.export( - mod, - args=( - (torch.rand([4, 16, 2]) * 64).to(dtype), - (torch.rand([4, 8, 2]) * 64).to(dtype), - ), - ) - output = aot.export(ep) - output.verify() - asm = str(output.mlir_module) - self.assertIn( - "@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32_i32", asm - ) - - def testExportArgF8AccumF32(self): - accum_dtype = torch.float32 - arg_type = torch.float8_e4m3fnuz - + @parameterized.expand( + [ + (torch.int32, None), + (torch.float8_e4m3fnuz, torch.float32), + ] + ) + def testExportStaticDims( + self, arg_dtype: torch.dtype, accum_dtype: torch.dtype | None + ): class MyModule(torch.nn.Module): def forward(self, a, b): return kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) @@ -93,15 +78,22 @@ def forward(self, a, b): ep = torch.export.export( mod, args=( - (torch.rand([4, 16, 2])).to(arg_type), - (torch.rand([4, 8, 2])).to(arg_type), + (torch.rand([4, 16, 2]) * 64).to(arg_dtype), + (torch.rand([4, 8, 2]) * 64).to(arg_dtype), ), ) output = aot.export(ep) output.verify() asm = str(output.mlir_module) + arg_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[arg_dtype] + accum_dtype_asm = arg_dtype_asm + if accum_dtype is not None: + accum_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] self.assertIn( - "@sharktank_batch_matmul_transpose_b_L4x16x2xf8E4M3FNUZ_R4x8x2xf8E4M3FNUZ_f32", + ( + "@sharktank_batch_matmul_transpose_b_" + f"L4x16x2x{arg_dtype_asm}_R4x8x2x{arg_dtype_asm}_{accum_dtype_asm}" + ), asm, ) From 781c8e8bd86182379ef8fa18bb4c6aa7480779bc Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 8 Feb 2025 16:08:09 +0000 Subject: [PATCH 05/29] Add exception to qlinear to not use the kernel when unsigned ints --- sharktank/sharktank/ops/qlinear_impls.py | 24 +++++++- .../kernels/batch_matmul_transpose_b_test.py | 57 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index 1b8c12d57..5875ed30f 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -170,7 +170,17 @@ def linear_quantized_weight( linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight) -def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): +def _is_dtype_unsigned_integer(dtype: torch.dtype): + return ( + not dtype.is_complex() + and not dtype.is_floating_point() + and not dtype.is_signed() + ) + + +def _invoke_mmt_kernel( + lhs: torch.Tensor, rhs: torch.Tensor, *, accum_dtype: torch.dtype +): if debugging.flags.use_custom_iree_kernels: # The custom kernel requires that the lhs and rhs be the same # rank. Broadcast the rhs to match. @@ -187,7 +197,17 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) + if ( + _is_dtype_unsigned_integer(lhs.dtype) + or _is_dtype_unsigned_integer(rhs.dtype) + or _is_dtype_unsigned_integer(accum_dtype) + ): + # TODO: make the kernel work with unsigned types. 
+ y_qs = kernels.batch_matmul_transpose_b( + lhs.to(dtype=accum_dtype), rhs.to(dtype=accum_dtype) + ) + else: + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) # Squeeze the batch dimension to maintain shape parity with other # layers. if len(y_qs.shape) > 2: diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index a9a490a8e..81e32db36 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -61,6 +61,63 @@ def testArgF8AccumF32(self): ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) + @pytest.mark.xfail( + reason="Does not work with unsigned types. The kernel needs to be adapted." + ) + def testArgUi8AccumI32(self): + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @pytest.mark.xfail( + reason="Does not work with unsigned types. The kernel needs to be adapted." + ) + def testArgLhsI8RhsUi8AccumI32(self): + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=torch.uint8) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + def testArgI8AccumI32(self): + arg_dtype = torch.int8 + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + b = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @pytest.mark.xfail( + reason="""No uint32 dtype conversions in IREE Turbine. + Does not work with unsigned types. The kernel needs to be adapted. + The problem is that we reinterpret cast to signless integer types. 
+ Maybe linalg.batch_matmul_transpose_b when promoting from i8 to i32 assumes a + signed type even though i8 is signless.""" + ) + def testArgUi8AccumUi32(self): + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.uint32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=torch.int32), bT.to(dtype=torch.int32)) + ref = ref.to(dtype=accum_dtype) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + @parameterized.expand( [ (torch.int32, None), From 9f1c3d40f2aa7a40abd4ab6f562648f7b35e53dc Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 8 Feb 2025 16:17:22 +0000 Subject: [PATCH 06/29] Small fix --- sharktank/sharktank/ops/qlinear_impls.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index 5875ed30f..dd658b5ce 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -171,11 +171,7 @@ def linear_quantized_weight( def _is_dtype_unsigned_integer(dtype: torch.dtype): - return ( - not dtype.is_complex() - and not dtype.is_floating_point() - and not dtype.is_signed() - ) + return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed def _invoke_mmt_kernel( From 82b032aead6df9e4200a8fa2ce472d98e20301a7 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 10 Feb 2025 23:19:15 +0000 Subject: [PATCH 07/29] Add eager execution to circamvent failure to compile for llvm-cpu --- .../kernels/batch_matmul_transpose_b.py | 4 ++++ .../kernels/batch_matmul_transpose_b_test.py | 18 ++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 179cb4679..e9879bfab 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -49,6 +49,10 @@ class _batch_matmul_transpose_b(CustomOp): "batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)" ) + def eager_execute(self, lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: str): + dtype = IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_dtype] + return torch.matmul(lhs.to(dtype=dtype), rhs.transpose(-1, -2).to(dtype=dtype)) + def select(self, ksel: KernelSelection): lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K] rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K] diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 81e32db36..3ac260265 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -42,13 +42,11 @@ def testBS32(self, atol, rtol): ref = torch.matmul(a, bT) torch.testing.assert_close(result, ref, atol=atol, rtol=rtol) - @pytest.mark.xfail( - reason="""Does not compile for llvm-cpu with - :0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>' - :0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32> - """ - ) def testArgF8AccumF32(self): + # TODO: make this test not use eager but actually execute with IREE. 
+ # Does not compile for llvm-cpu with + # :0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>' + # :0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32> arg_dtype = torch.float8_e4m3fnuz a = torch.rand([3, 4, 6]).to(arg_dtype) b = torch.rand([3, 5, 6]).to(arg_dtype) @@ -61,10 +59,9 @@ def testArgF8AccumF32(self): ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) - @pytest.mark.xfail( - reason="Does not work with unsigned types. The kernel needs to be adapted." - ) def testArgUi8AccumI32(self): + # TODO: make this test not use eager but actually execute with IREE. + # Does not work with unsigned types. The kernel needs to be adapted. arg_dtype = torch.uint8 a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) @@ -75,9 +72,6 @@ def testArgUi8AccumI32(self): ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) torch.testing.assert_close(result, ref, atol=0, rtol=0) - @pytest.mark.xfail( - reason="Does not work with unsigned types. The kernel needs to be adapted." - ) def testArgLhsI8RhsUi8AccumI32(self): a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8) b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=torch.uint8) From de7009425862ca2afa38d215ad12a5749f7b93b0 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 11 Feb 2025 21:24:52 +0000 Subject: [PATCH 08/29] Convert dtype when writing into the cache --- sharktank/sharktank/layers/kv_cache.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index be8c66fb4..e5e5dbcd8 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,7 +275,9 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) - page_table.index_put_(indices=indices, values=cache_partition) + page_table.index_put_( + indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype) + ) return @@ -320,4 +322,6 @@ def write( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) - subblock_table.index_copy_(0, subblock_ids, part_block_view) + subblock_table.index_copy_( + 0, subblock_ids, ops.to(part_block_view, dtype=subblock_table.dtype) + ) From ae89b55586c5038bacbd1a82d662ddfae451fad5 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Wed, 12 Feb 2025 18:23:51 -0600 Subject: [PATCH 09/29] Fix attention_dtype flag for paged_llm_v1 Signed-off-by: aviator19941 --- sharktank/sharktank/examples/paged_llm_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 5d338bd74..11c834b21 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -288,7 +288,7 @@ def main(): block_seq_stride=args.block_seq_stride, device=device, activation_dtype=args.activation_dtype, - attention_dtype=args.activation_dtype, + attention_dtype=args.attention_dtype, attention_kernel=args.attention_kernel, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, From 5bf4636e1fe4363c5b103d85152bebe340051cb7 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin 
Date: Fri, 14 Feb 2025 00:40:10 +0000 Subject: [PATCH 10/29] KV cache workaround for Torch not supporting torch.Tensor.index_copy_ for f8 --- sharktank/sharktank/layers/kv_cache.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index e5e5dbcd8..8ca3cfd5d 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -322,6 +322,9 @@ def write( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) + if subblock_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. + subblock_table = subblock_table.view(dtype=torch.int8) subblock_table.index_copy_( 0, subblock_ids, ops.to(part_block_view, dtype=subblock_table.dtype) ) From fe5c881759d2ad52812324922e5a1f4367755f9d Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 13 Feb 2025 18:10:38 -0800 Subject: [PATCH 11/29] Fix kv_cache index_put_ issue --- sharktank/sharktank/layers/kv_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 8ca3cfd5d..38e1e1a95 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,6 +275,8 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) + # Workaround for Torch not supporting torch.Tensor.index_put_ for f8. + page_table = page_table.view(dtype=torch.int8) page_table.index_put_( indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype) ) From b4be2a889c65fa4dfe9ef5e20657c89cbed73ad0 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 13 Feb 2025 20:48:08 -0800 Subject: [PATCH 12/29] Revert "Fix kv_cache index_put_ issue" This reverts commit fe5c881759d2ad52812324922e5a1f4367755f9d. --- sharktank/sharktank/layers/kv_cache.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 38e1e1a95..8ca3cfd5d 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,8 +275,6 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) - # Workaround for Torch not supporting torch.Tensor.index_put_ for f8. - page_table = page_table.view(dtype=torch.int8) page_table.index_put_( indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype) ) From 462ddc43753d5e83899e9c32c50ce85ce078c799 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 14 Feb 2025 05:06:11 +0000 Subject: [PATCH 13/29] Fix KV cache index_copy_ f8 workaround --- sharktank/sharktank/layers/kv_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 8ca3cfd5d..dcef55728 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -322,9 +322,9 @@ def write( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) + part_block = ops.to(part_block_view, dtype=subblock_table.dtype) if subblock_table.dtype == torch.float8_e4m3fnuz: # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. 
subblock_table = subblock_table.view(dtype=torch.int8) - subblock_table.index_copy_( - 0, subblock_ids, ops.to(part_block_view, dtype=subblock_table.dtype) - ) + part_block = part_block.view(dtype=torch.int8) + subblock_table.index_copy_(0, subblock_ids, part_block) From 4dc2ac269453deeb4cdd2aab8076d33148e0b0b8 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 14 Feb 2025 05:18:25 +0000 Subject: [PATCH 14/29] In linear for (Tensor, QuantizedTensor) raise if accum_dtype is given --- sharktank/sharktank/ops/qlinear_impls.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index dd658b5ce..b88703aeb 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -160,6 +160,8 @@ def linear_quantized_weight( *, accum_dtype: Optional[torch.dtype], ) -> AnyTensor: + if accum_dtype is not None: + raise NotImplementedError("TODO: implement when is passed accum_dtype") res = matmul(x, weight, transpose_rhs=True) if bias is not None: res = res + bias From 13bfc689744f93b21e1a5119c0b68d0fc7f5b5e0 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 16:49:30 +0000 Subject: [PATCH 15/29] Fix KV cache f8 --- sharktank/sharktank/layers/kv_cache.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index dcef55728..a961a9dd7 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,9 +275,14 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) - page_table.index_put_( - indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype) - ) + values = ops.to(cache_partition, dtype=page_table.dtype) + if page_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. + page_table_as_int8 = page_table.view(dtype=torch.int8) + values_int8 = values.view(dtype=torch.int8) + page_table_as_int8.index_put_(indices=indices, values=values_int8) + else: + page_table.index_put_(indices=indices, values=values) return @@ -325,6 +330,8 @@ def write( part_block = ops.to(part_block_view, dtype=subblock_table.dtype) if subblock_table.dtype == torch.float8_e4m3fnuz: # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. 
- subblock_table = subblock_table.view(dtype=torch.int8) - part_block = part_block.view(dtype=torch.int8) - subblock_table.index_copy_(0, subblock_ids, part_block) + subblock_table_as_int8 = subblock_table.view(dtype=torch.int8) + part_block_as_int8 = part_block.view(dtype=torch.int8) + subblock_table_as_int8.index_copy_(0, subblock_ids, part_block_as_int8) + else: + subblock_table.index_copy_(0, subblock_ids, part_block) From 8b2044558b6f87b268a831f058476d8a1dcaba81 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 16:50:39 +0000 Subject: [PATCH 16/29] Remove unused HF dataset --- sharktank/sharktank/utils/hf_datasets.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index c898113c5..301612c26 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -93,23 +93,6 @@ def alias_dataset(from_name: str, to_name: str): # Dataset definitions ################################################################################ -Dataset( - "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", - ( - RemoteFile( - "gguf", - "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", - "meta-llama-3.1-8b-instruct.f16.gguf", - ), - RemoteFile( - "tokenizer_config.json", - "NousResearch/Meta-Llama-3-8B-Instruct", - "tokenizer_config.json", - extra_filenames=["tokenizer.json"], - ), - ), -).alias_to("llama3_8B_fp16") - Dataset( "QuantFactory/Llama-3-8B_q4_1_gguf", ( From 9f160f1d88a7d8e16c82b387e8639a134977d532 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 17:49:05 +0000 Subject: [PATCH 17/29] Add KV cache dtype different from attention dtype --- sharktank/sharktank/examples/export_paged_llm_v1.py | 1 + sharktank/sharktank/examples/paged_llm_v1.py | 1 + sharktank/sharktank/layers/configs/llm_configs.py | 3 +++ sharktank/sharktank/layers/paged_llama_attention_block.py | 8 +++++--- sharktank/sharktank/models/llama/llama.py | 3 +++ sharktank/sharktank/utils/cli.py | 7 ++++++- sharktank/sharktank/utils/create_cache.py | 3 ++- 7 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 5a5f6d40c..f4b34aea3 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -83,6 +83,7 @@ def main(): block_seq_stride=args.block_seq_stride, activation_dtype=args.activation_dtype, attention_dtype=args.attention_dtype, + kv_cache_dtype=args.kv_cache_dtype, ) llama_config.fake_quant = args.fake_quant diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 11c834b21..b8b6026b6 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -290,6 +290,7 @@ def main(): activation_dtype=args.activation_dtype, attention_dtype=args.attention_dtype, attention_kernel=args.attention_kernel, + kv_cache_dtype=args.kv_cache_dtype, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, fake_quant=args.fake_quant, diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 162599c7a..baa738e82 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -158,6 +158,9 @@ class LlamaModelConfig: # Either "paged" or "direct". 
kv_cache_type: str = "paged" + # If None will use attention_dtype. + kv_cache_dtype: Optional[torch.dtype] = None + # The device on which to place intermediate state. device: Optional[torch.device] = None diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 69b011cc4..99216b3d9 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,6 +37,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", attention_scale: Optional[float] = None, softcap: Optional[float] = None, @@ -49,6 +50,7 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv + self.attention_dtype = attention_dtype self.attention_kernel = attention_kernel self.attention_scale = attention_scale self.softcap = softcap @@ -158,13 +160,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Fake quant is already dequantized when stored in the cache. if self.cache_quantizer and not self.fake_quant: xk = self.cache_quantizer.dequantize_raw_tensor( - xk, torch.bfloat16, name="xk_deq" + xk, self.attention_dtype, name="xk_deq" ) xv = self.cache_quantizer.dequantize_raw_tensor( - xv, torch.bfloat16, name="xv_deq" + xv, self.attention_dtype, name="xv_deq" ) if attention_mask is not None: - attention_mask = attention_mask.to(torch.bfloat16) + attention_mask = attention_mask.to(self.attention_dtype) # Transpose into [bs, heads, sl, dim] xq = xq.transpose(1, 2) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index a31913833..bc18cb492 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -117,6 +117,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, + attention_dtype=config.attention_dtype, attention_kernel=self.attention_kernel, fake_quant=self.fake_quant, ) @@ -241,6 +242,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", fake_quant: bool = True, ): @@ -255,6 +257,7 @@ def __init__( head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, + attention_dtype=attention_dtype, attention_kernel=attention_kernel, fake_quant=fake_quant, ), diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index e3dba31fa..282283ea7 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -32,7 +32,7 @@ def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None) """Parses arguments and does any prescribed global process setup.""" parsed_args = parser.parse_args(args) # Set torch dtypes - for attr in ["activation_dtype", "attention_dtype"]: + for attr in ["activation_dtype", "attention_dtype", "kv_cache_dtype"]: if hasattr(parsed_args, attr): dtype = getattr(torch, getattr(parsed_args, attr)) assert isinstance(dtype, torch.dtype) @@ -100,6 +100,11 @@ def add_model_options(parser: argparse.ArgumentParser): help="DType to use for activations in the model", default="float16", ) + parser.add_argument( + "--kv-cache-dtype", + help="DType to use for the KV cache. 
If not given will be attention dtype", + default=None, + ) parser.add_argument("--device", help="Torch device (or default)") parser.add_argument( diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py index f462d9c00..eb26f5a14 100644 --- a/sharktank/sharktank/utils/create_cache.py +++ b/sharktank/sharktank/utils/create_cache.py @@ -12,6 +12,7 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: raise ValueError("Model does not use paged kv cache, cannot create kv cache") hp = config.hp + dtype = config.kv_cache_dtype or config.attention_dtype return PagedKVCache( transformer_block_count=hp.block_count, attn_head_count=hp.attention_head_count_kv, @@ -19,6 +20,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: cache_partition_count=2, # One for each of K/V. block_seq_stride=config.block_seq_stride, device=config.device, - dtype=config.attention_dtype, + dtype=dtype, shard_count=config.tensor_parallelism_size, ) From 77a84436cbe99d8b620bc34b50518bf381195b35 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 17:50:09 +0000 Subject: [PATCH 18/29] Add more KV cache tests for various dtypes --- sharktank/tests/layers/kv_cache_test.py | 57 +++++++++++++++++-------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index d59d0a85b..ff60ef8b6 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -4,8 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import unittest - +import pytest import torch from sharktank.ops import replicate, reshard_split, unshard @@ -13,7 +12,16 @@ from sharktank.types import * -def test_paged(): +@pytest.mark.parametrize( + "dtype", + [ + torch.float8_e4m3fnuz, + torch.bfloat16, + torch.float16, + torch.float32, + ], +) +def test_paged(dtype: torch.dtype): bs = 4 seq_length = 24 attn_head_count = 4 @@ -25,7 +33,7 @@ def test_paged(): transformer_block_count=transformer_block_count, attn_head_count=attn_head_count, attn_head_dim=attn_head_dim, - dtype=torch.float32, + dtype=dtype, device=None, ) @@ -36,15 +44,18 @@ def test_paged(): write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) - allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + for t in allocation: + t[...] 
= torch.full(t.shape, 0.0).to(dtype=dtype) # Write a prefill in: write_ones = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + ).to(dtype=dtype) write_twos = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + ).to(dtype=dtype) cache.write( allocation, @@ -72,15 +83,19 @@ def test_paged(): seq_len=write_seq_length, page_ids=write_page_ids, ) - torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) - torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close( + read_ones[0], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) + torch.testing.assert_close( + read_ones[1], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + write_threes = torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0).to( + dtype=dtype ) - write_fours = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + write_fours = torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0).to( + dtype=dtype ) write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) cache.write_timestep( @@ -98,8 +113,16 @@ def test_paged(): page_ids=page_ids, ) - check_concat_0 = torch.concat([write_ones, write_threes], dim=1) - check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + if dtype == torch.float8_e4m3fnuz: + check_concat_0 = torch.concat( + [write_ones.view(torch.int8), write_threes.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + check_concat_1 = torch.concat( + [write_twos.view(torch.int8), write_fours.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + else: + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) torch.testing.assert_close(check_concat_0, read_back[0]) torch.testing.assert_close(check_concat_1, read_back[1]) From 1ea608a115f0764569bdc1d008f6b8c46b119c93 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 17:53:17 +0000 Subject: [PATCH 19/29] Remove some unwanted corner casehandlings in linear layer --- sharktank/sharktank/layers/linear.py | 4 ---- sharktank/sharktank/ops/qlinear_impls.py | 6 +----- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index dae126767..0381abfe9 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -85,10 +85,6 @@ def forward(self, x): # We can truncate to fp16 in iree, so we do a cast here # to account for this in the IR. This is may not be the right # level to do this, but for now its here. - if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor): - if x.unpack().qs.dtype == torch.float8_e4m3fnuz: - y = ops.to(y, torch.bfloat16) - return y if qdq_output is not None: y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index b88703aeb..7015c4162 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -44,7 +44,7 @@ def qlinear_tensor_scaled( # Now we know that both the x/weight are TensorScaledLayout. 
There are still # degrees of freedom: # * Either/both can be per-tensor or per-axis scaled (d is 0D or d is nd>0). - # * Either/both can have offsets of not (m is not None). + # * Either/both can have offsets or not (m is not None). x_layout: TensorScaledLayout = x.unpack() weight_layout: TensorScaledLayout = weight.unpack() @@ -206,10 +206,6 @@ def _invoke_mmt_kernel( ) else: y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) - # Squeeze the batch dimension to maintain shape parity with other - # layers. - if len(y_qs.shape) > 2: - y_qs = y_qs.squeeze(0) else: # FP emulation. y_qs = torch.matmul( From 6f0c98bfebcc4af6df661e0b4b5b87ce3b0678e2 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 17:53:47 +0000 Subject: [PATCH 20/29] Add more linear layer tests --- sharktank/tests/layers/linear_test.py | 89 ++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index ad657889d..562782899 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -5,11 +5,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest - import torch +from parameterized import parameterized from sharktank.layers import * from sharktank.types import * +from sharktank.utils.testing import make_rand_torch def _randomize_per_axis(t: torch.Tensor, axis: int, offset_range: float = 0.0): @@ -91,6 +92,92 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): print(torch.abs(output - output_ref)) torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1) + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-6), + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-6), + (torch.float32, torch.float32, torch.float16, True, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, True, 1e-6), + (torch.float32, torch.float32, torch.float32, True, False, 1e-6), + ], + ) + def testPerTensorScale( + self, + dequantized_dtype: torch.dtype, + quantized_scale_dtype: torch.dtype, + quantized_dtype: torch.dtype, + with_bias: bool, + fake_quant: bool, + atol: float, + ): + """Test a linear layer where each tensor being quantized with a single + different scale.""" + x = make_rand_torch([10, 8, 2], dtype=dequantized_dtype) + input_scale = torch.tensor(0.5, dtype=quantized_scale_dtype) + input_quantizer = StaticScaledQuantizer( + name="q_input", scale=input_scale, dtype=quantized_dtype + ) + # We roundtrip through quantization to know that any discrepancies in the + # results come from the quantized linear operation itself. Not form the + # inaccuracies of the initial quantization. 
+ x_dequantized = input_quantizer.quantize(x).unpack().dequant() + torch.testing.assert_close( + input_quantizer.quantize(x_dequantized).unpack().dequant(), + x_dequantized, + atol=0, + rtol=0, + ) + + weight = make_rand_torch([12, x.shape[2]], dtype=dequantized_dtype) + weight_scale = torch.tensor(0.66, dtype=quantized_scale_dtype) + weight_quantizer = StaticScaledQuantizer( + scale=weight_scale, dtype=quantized_dtype + ) + weight_dequantized = weight_quantizer.quantize(weight).unpack().dequant() + weight_quantized = weight_quantizer.quantize(weight_dequantized, name="weight") + torch.testing.assert_close( + weight_quantizer.quantize(weight_dequantized).unpack().dequant(), + weight_dequantized, + atol=0, + rtol=0, + ) + + if with_bias: + bias = make_rand_torch( + [x.shape[1], weight.shape[0]], dtype=dequantized_dtype + ) + bias_scale = torch.tensor(1.25, dtype=quantized_scale_dtype) + bias_quantizer = StaticScaledQuantizer( + scale=bias_scale, dtype=quantized_dtype + ) + bias_dequantized = bias_quantizer.quantize(bias).unpack().dequant() + bias_quantized = bias_quantizer.quantize(bias_dequantized, name="bias") + torch.testing.assert_close( + bias_quantizer.quantize(bias_dequantized).unpack().dequant(), + bias_dequantized, + atol=0, + rtol=0, + ) + + expected = torch.matmul(x_dequantized, weight_dequantized.T) + if with_bias: + expected += bias_dequantized + + theta_tensors = [ + input_quantizer, + weight_quantized, + ] + if with_bias: + theta_tensors += [bias_quantized] + theta = Theta(theta_tensors) + linear = LinearLayer(theta, fake_quant=fake_quant) + actual = linear(x_dequantized) + + torch.testing.assert_close( + actual.to(dtype=expected.dtype), expected, atol=atol, rtol=0 + ) + if __name__ == "__main__": unittest.main() From 664a847d69a1a688403a3e7fe1f1078bacb3d196 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 17:55:41 +0000 Subject: [PATCH 21/29] Refactor quark parity test to use tmp dir --- sharktank/tests/models/llama/quark_parity_test.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py index b45696fe4..2dfd73fe7 100644 --- a/sharktank/tests/models/llama/quark_parity_test.py +++ b/sharktank/tests/models/llama/quark_parity_test.py @@ -11,11 +11,12 @@ import pytest from pathlib import Path import subprocess +from sharktank.utils.testing import TempDirTestBase with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')") -class QuarkParityTest(unittest.TestCase): +class QuarkParityTest(TempDirTestBase): def setUp(self): super().setUp() self.path_prefix = Path("/shark-dev/quark_test") @@ -25,7 +26,7 @@ def test_compare_against_quark(self): sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent ) - our_path = self.path_prefix / "ours_prefill.safetensors" + our_path = self._temp_dir / "ours_prefill.safetensors" if os.path.exists(our_path): os.remove(our_path) @@ -58,16 +59,14 @@ def test_compare_against_quark(self): "--fake-quant", "--attention-kernel=torch", "--activation-dtype=bfloat16", - f"--save_intermediates_path={self.path_prefix}/ours", + f"--save_intermediates_path={self._temp_dir / 'ours'}", "--use-hf", "--attention-dtype=bfloat16", "--skip-decode", "--block-seq-stride=16", ] command = subprocess.list2cmdline(command) - proc = subprocess.run( - command, shell=True, capture_output=True, cwd=sharktank_dir - ) + subprocess.check_call(command, shell=True, 
cwd=sharktank_dir) ours = dict() with safe_open(our_path, "pytorch") as st: From 9816c35ba87a48dd146a5516abfbd0f511b0f7a6 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 19:12:41 +0000 Subject: [PATCH 22/29] Fix KV cache dtype CLI arg parsing --- sharktank/sharktank/utils/cli.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 282283ea7..36ae89cc6 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -34,8 +34,10 @@ def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None) # Set torch dtypes for attr in ["activation_dtype", "attention_dtype", "kv_cache_dtype"]: if hasattr(parsed_args, attr): - dtype = getattr(torch, getattr(parsed_args, attr)) - assert isinstance(dtype, torch.dtype) + dtype = getattr(parsed_args, attr) + if dtype is not None: + dtype = getattr(torch, dtype) + assert isinstance(dtype, torch.dtype) setattr(parsed_args, attr, dtype) return parsed_args From c17629e23e415f51645fee121f3335525b117a26 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 21:55:07 +0000 Subject: [PATCH 23/29] Change doc example to not use the removed Llama dataset --- docs/shortfin/llm/user/llama_serving.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/shortfin/llm/user/llama_serving.md b/docs/shortfin/llm/user/llama_serving.md index 474e843b6..ce3058e70 100644 --- a/docs/shortfin/llm/user/llama_serving.md +++ b/docs/shortfin/llm/user/llama_serving.md @@ -94,13 +94,13 @@ mkdir -p $EXPORT_DIR ## 2. Download and compile the model -### Download `llama3_8b_fp16.gguf` +### Download `llama3_8B_f16.gguf` We will use the `hf_datasets` module in `sharktank` to download a LLama3.1 8b f16 model. 
```bash -python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR +python -m sharktank.utils.hf_datasets llama3_8B_f16 --local-dir $EXPORT_DIR ``` > [!NOTE] From 9b7dfdf5d307806347b41f89e7478d76a9a71526 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 20 Feb 2025 23:04:37 +0000 Subject: [PATCH 24/29] Add KV cache dtype to benchmark --- sharktank/sharktank/utils/export_artifacts.py | 4 ++++ sharktank/tests/models/llama/benchmark_amdgpu_test.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 75071e286..608f65c48 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -95,6 +95,7 @@ def __init__( use_attention_mask: bool = False, activation_dtype: str = "float16", attention_dtype: str = "float16", + kv_cache_dtype: Optional[str] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -109,6 +110,7 @@ def __init__( self.use_attention_mask = use_attention_mask self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype + self.kv_cache_dtype = kv_cache_dtype def timeit(func): def wrapper(*args, **kwargs): @@ -189,6 +191,8 @@ def export_to_mlir( f"--attention-dtype={self.attention_dtype}", f"--activation-dtype={self.activation_dtype}", ] + if self.kv_cache_dtype is not None: + export_args.append(f"--kv-cache-dtype={self.kv_cache_dtype}") if skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 6056f9c6c..7cc51a099 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -94,7 +94,8 @@ def setUp(self): tensor_parallelism_size=self.tensor_parallelism_size, block_seq_stride=32, activation_dtype="bfloat16", - attention_dtype="float8_e4m3fnuz", + attention_dtype="bfloat16", + kv_cache_dtype="float8_e4m3fnuz", ) self.prefill_args_bs4_128_stride_32_f16 = ( self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp1" From 55c87017c8d81d511242c50a81a00eeeda928da7 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 21 Feb 2025 00:30:06 +0000 Subject: [PATCH 25/29] Change testBenchmark8B_fp8_Non_Decomposed xfail reason to compilation error --- sharktank/tests/models/llama/benchmark_amdgpu_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 7cc51a099..39d6f4ce8 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -243,9 +243,9 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self): ) @pytest.mark.xfail( - reason="Benchmark inputs not configured yet.", + reason="Fails due to https://github.com/iree-org/iree/issues/20002.", strict=True, - raises=IreeBenchmarkException, + raises=IreeCompileException, ) def testBenchmark8B_fp8_Non_Decomposed(self): output_file_name = self.dir_path_8b / "fp8_torch" From 740bb807f0200cc99bd464b6244092dadf99b159 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 21 Feb 2025 00:57:50 +0000 Subject: [PATCH 26/29] Put back in the llama3_8B_fp16 HF dataset --- docs/shortfin/llm/user/llama_serving.md | 4 ++-- 
sharktank/sharktank/utils/hf_datasets.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/shortfin/llm/user/llama_serving.md b/docs/shortfin/llm/user/llama_serving.md index ce3058e70..474e843b6 100644 --- a/docs/shortfin/llm/user/llama_serving.md +++ b/docs/shortfin/llm/user/llama_serving.md @@ -94,13 +94,13 @@ mkdir -p $EXPORT_DIR ## 2. Download and compile the model -### Download `llama3_8B_f16.gguf` +### Download `llama3_8b_fp16.gguf` We will use the `hf_datasets` module in `sharktank` to download a LLama3.1 8b f16 model. ```bash -python -m sharktank.utils.hf_datasets llama3_8B_f16 --local-dir $EXPORT_DIR +python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR ``` > [!NOTE] diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index 301612c26..c898113c5 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -93,6 +93,23 @@ def alias_dataset(from_name: str, to_name: str): # Dataset definitions ################################################################################ +Dataset( + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + ( + RemoteFile( + "gguf", + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + "meta-llama-3.1-8b-instruct.f16.gguf", + ), + RemoteFile( + "tokenizer_config.json", + "NousResearch/Meta-Llama-3-8B-Instruct", + "tokenizer_config.json", + extra_filenames=["tokenizer.json"], + ), + ), +).alias_to("llama3_8B_fp16") + Dataset( "QuantFactory/Llama-3-8B_q4_1_gguf", ( From 02215d5b9fd55b7d9205eabadf6e1d4bc9e128e3 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 21 Feb 2025 14:19:35 +0000 Subject: [PATCH 27/29] Remove left behind comment --- sharktank/sharktank/layers/linear.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index 0381abfe9..79ede9f5c 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -81,10 +81,6 @@ def forward(self, x): # Unconditionally dequantize. if isinstance(y, QuantizedTensor): y = y.unpack().dequant() - # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. - # We can truncate to fp16 in iree, so we do a cast here - # to account for this in the IR. This is may not be the right - # level to do this, but for now its here. 
if qdq_output is not None: y = qdq_output.quantize(y).unpack().dequant() return y From 40f993a5b16ba93afe1f1f49f01ad3313d582e06 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 21 Feb 2025 15:41:35 +0000 Subject: [PATCH 28/29] Make quark parity test use f8 KV cache --- sharktank/tests/models/llama/quark_parity_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py index 2dfd73fe7..b8a30c543 100644 --- a/sharktank/tests/models/llama/quark_parity_test.py +++ b/sharktank/tests/models/llama/quark_parity_test.py @@ -62,6 +62,7 @@ def test_compare_against_quark(self): f"--save_intermediates_path={self._temp_dir / 'ours'}", "--use-hf", "--attention-dtype=bfloat16", + "--kv-cache-dtype=float8_e4m3fnuz", "--skip-decode", "--block-seq-stride=16", ] From 85053eff1c6c214ebe8465ef6f981ae0055cf62c Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 21 Feb 2025 16:38:56 +0000 Subject: [PATCH 29/29] Add more bf16 qlinear tests and make ref dtype be f64 --- sharktank/sharktank/ops/qlinear_impls.py | 2 +- sharktank/tests/layers/linear_test.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index 7015c4162..104349266 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -161,7 +161,7 @@ def linear_quantized_weight( accum_dtype: Optional[torch.dtype], ) -> AnyTensor: if accum_dtype is not None: - raise NotImplementedError("TODO: implement when is passed accum_dtype") + raise NotImplementedError("TODO: implement when accum_dtype is passed") res = matmul(x, weight, transpose_rhs=True) if bias is not None: res = res + bias diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index 562782899..08164176f 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import logging import unittest import torch from parameterized import parameterized @@ -12,6 +13,8 @@ from sharktank.types import * from sharktank.utils.testing import make_rand_torch +logger = logging.getLogger(__name__) + def _randomize_per_axis(t: torch.Tensor, axis: int, offset_range: float = 0.0): # Applies a randomized per-axis scale and offset to a tensor. 
@@ -94,6 +97,8 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): @parameterized.expand( [ + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-2), + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-2), (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-6), (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-6), (torch.float32, torch.float32, torch.float16, True, False, 1e-6), @@ -113,7 +118,9 @@ def testPerTensorScale( ): """Test a linear layer where each tensor being quantized with a single different scale.""" - x = make_rand_torch([10, 8, 2], dtype=dequantized_dtype) + ref_dtype = torch.float64 + + x = make_rand_torch([10, 8, 8], dtype=dequantized_dtype) input_scale = torch.tensor(0.5, dtype=quantized_scale_dtype) input_quantizer = StaticScaledQuantizer( name="q_input", scale=input_scale, dtype=quantized_dtype @@ -160,9 +167,11 @@ def testPerTensorScale( rtol=0, ) - expected = torch.matmul(x_dequantized, weight_dequantized.T) + expected = torch.matmul( + x_dequantized.to(ref_dtype), weight_dequantized.T.to(ref_dtype) + ) if with_bias: - expected += bias_dequantized + expected += bias_dequantized.to(ref_dtype) theta_tensors = [ input_quantizer, @@ -173,11 +182,15 @@ def testPerTensorScale( theta = Theta(theta_tensors) linear = LinearLayer(theta, fake_quant=fake_quant) actual = linear(x_dequantized) + actual = actual.to(dtype=expected.dtype) - torch.testing.assert_close( - actual.to(dtype=expected.dtype), expected, atol=atol, rtol=0 + abs_diff = (expected - actual).abs() + logger.info( + f"abs diff from expected (std, mean, median) = {[float(abs_diff.std()), float(abs_diff.mean()), float(abs_diff.median())]}" ) + torch.testing.assert_close(actual, expected, atol=atol, rtol=0) + if __name__ == "__main__": unittest.main()
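
Taken together, patches 03 and 07 leave `kernels.batch_matmul_transpose_b` with an optional `accum_dtype` keyword that fixes the result dtype (defaulting to the LHS dtype when omitted) and an eager fallback that computes the matmul in that dtype. A minimal sketch of the resulting call pattern, with illustrative shapes (the RHS is passed as [B, N, K] and transposed internally):

```python
import torch
from sharktank import kernels

# int8 arguments accumulating into int32, mirroring testArgI8AccumI32.
a = torch.randint(-128, 128, (4, 16, 8), dtype=torch.int8)  # [B, M, K]
b = torch.randint(-128, 128, (4, 12, 8), dtype=torch.int8)  # [B, N, K]

y = kernels.batch_matmul_transpose_b(a, b, accum_dtype=torch.int32)
assert y.shape == (4, 16, 12) and y.dtype == torch.int32

# When accum_dtype is omitted, the result dtype follows the LHS dtype.
y32 = kernels.batch_matmul_transpose_b(a.to(torch.int32), b.to(torch.int32))
assert y32.dtype == torch.int32
```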
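
The patch 06 "Small fix" drops the call parentheses in `_is_dtype_unsigned_integer` because `is_complex`, `is_floating_point`, and `is_signed` are properties on `torch.dtype`, not methods. A quick standalone check of the predicate's intended behavior (redefined here for illustration; it mirrors the private helper in `sharktank.ops.qlinear_impls`):

```python
import torch

def is_dtype_unsigned_integer(dtype: torch.dtype) -> bool:
    # dtype.is_* are bool properties, so no parentheses.
    return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed

assert is_dtype_unsigned_integer(torch.uint8)
assert not is_dtype_unsigned_integer(torch.int8)
assert not is_dtype_unsigned_integer(torch.int32)
assert not is_dtype_unsigned_integer(torch.float8_e4m3fnuz)
```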
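
The f8 KV-cache changes (patches 10, 13, and 15) all rest on the same workaround: `index_copy_` and `index_put_` currently reject `float8_e4m3fnuz`, so both the table and the values are reinterpreted as int8 for the indexed write while sharing storage with the f8 view. A standalone sketch of the idea, assuming a Torch build that provides the f8 dtype (names and shapes here are illustrative, not the cache's real layout):

```python
import torch

f8 = torch.float8_e4m3fnuz
table = torch.zeros(8, 16).to(f8)         # stands in for the page/subblock table
values = torch.full((2, 16), 1.0).to(f8)  # stands in for a cache partition view
ids = torch.tensor([1, 5])

# Reinterpret (not convert) both sides as int8, then scatter in place.
table.view(dtype=torch.int8).index_copy_(0, ids, values.view(dtype=torch.int8))

# The int8 view shares storage, so the f8 table now holds the written rows.
assert torch.equal(table[1].view(torch.int8), values[0].view(torch.int8))
```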
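
On the user-facing side, patches 17 and 22 add and then fix a `--kv-cache-dtype` flag so the cache dtype can differ from the attention dtype, with `create_paged_kv_cache` falling back to `attention_dtype` when it is unset. A sketch of how the flag is parsed, assuming `add_model_options` registers no required arguments:

```python
import argparse
import torch
from sharktank.utils import cli

parser = argparse.ArgumentParser()
cli.add_model_options(parser)

args = cli.parse(
    parser, args=["--attention-dtype=bfloat16", "--kv-cache-dtype=float8_e4m3fnuz"]
)
assert args.attention_dtype == torch.bfloat16
assert args.kv_cache_dtype == torch.float8_e4m3fnuz

# When the flag is omitted it stays None; the paged cache then uses
# config.kv_cache_dtype or config.attention_dtype.
args = cli.parse(parser, args=[])
assert args.kv_cache_dtype is None
```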