diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 7c12f4e20..efbba089b 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -83,6 +83,7 @@ def main(): block_seq_stride=args.block_seq_stride, activation_dtype=args.activation_dtype, attention_dtype=args.attention_dtype, + kv_cache_dtype=args.kv_cache_dtype, ) llama_config.fake_quant = args.fake_quant diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 5d338bd74..b8b6026b6 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -288,8 +288,9 @@ def main(): block_seq_stride=args.block_seq_stride, device=device, activation_dtype=args.activation_dtype, - attention_dtype=args.activation_dtype, + attention_dtype=args.attention_dtype, attention_kernel=args.attention_kernel, + kv_cache_dtype=args.kv_cache_dtype, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, fake_quant=args.fake_quant, diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 21f9e9ed4..e9879bfab 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -7,16 +7,36 @@ from sharktank.kernels.base import * import torch +from typing import cast, Optional -from iree.compiler.ir import IntegerType +from iree.compiler.ir import IntegerType, Type +from iree.turbine.support.conversions import ( + TORCH_DTYPE_TO_IREE_TYPE_ASM, + IREE_TYPE_ASM_TO_TORCH_DTYPE, +) +from iree.turbine.runtime.op_reg import AttrArg __all__ = [ "batch_matmul_transpose_b", ] +def batch_matmul_transpose_b( + lhs: torch.Tensor, + rhs: torch.Tensor, + /, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if accum_dtype is None: + accum_dtype = lhs.dtype + return _batch_matmul_transpose_b( + lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] + ) + + @CustomOp.register(library=LIBRARY) -class batch_matmul_transpose_b(CustomOp): +class _batch_matmul_transpose_b(CustomOp): """Generic block scaled matmul with transposed RHS. The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be @@ -25,11 +45,18 @@ class batch_matmul_transpose_b(CustomOp): The kernel will be specialized for all values of N, K and LHS dtype. """ - signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)" + signature = ( + "batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)" + ) + + def eager_execute(self, lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: str): + dtype = IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_dtype] + return torch.matmul(lhs.to(dtype=dtype), rhs.transpose(-1, -2).to(dtype=dtype)) def select(self, ksel: KernelSelection): lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K] rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K] + accum_type_attr = ksel.attr_str(2) # Rank check. 
torch._check( @@ -60,7 +87,8 @@ def select(self, ksel: KernelSelection): ) # Shape batch, m, n c_desc = ksel.return_new_tensor( - [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype + [lhs_batch, lhs_m, rhs_n], + dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v], ) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) @@ -74,12 +102,14 @@ def select(self, ksel: KernelSelection): def generate(self, ksel: KernelSelection, kb: KernelBuilder): lhs = kb.arg_value(0) rhs = kb.arg_value(1) + accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v result_desc = ksel.result_descs[0] # Generate specialization signature and types. - a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type) + a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) - spec_sig = f"L{a_ident}_R{b_ident}" + accum_type = Type.parse(accum_type_str) + spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" cst_zero = "0" if IntegerType.isinstance(accum_type) else "0." diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 7d1e506a0..fafc2a98e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -160,6 +160,9 @@ class LlamaModelConfig: # Either "paged" or "direct". kv_cache_type: str = "paged" + # If None, attention_dtype will be used. + kv_cache_dtype: Optional[torch.dtype] = None + # The device on which to place intermediate state. device: Optional[torch.device] = None diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index be8c66fb4..a961a9dd7 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,7 +275,14 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) - page_table.index_put_(indices=indices, values=cache_partition) + values = ops.to(cache_partition, dtype=page_table.dtype) + if page_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_put_ for f8. + page_table_as_int8 = page_table.view(dtype=torch.int8) + values_int8 = values.view(dtype=torch.int8) + page_table_as_int8.index_put_(indices=indices, values=values_int8) + else: + page_table.index_put_(indices=indices, values=values) return @@ -320,4 +327,11 @@ def write( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) - subblock_table.index_copy_(0, subblock_ids, part_block_view) + part_block = ops.to(part_block_view, dtype=subblock_table.dtype) + if subblock_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. + subblock_table_as_int8 = subblock_table.view(dtype=torch.int8) + part_block_as_int8 = part_block.view(dtype=torch.int8) + subblock_table_as_int8.index_copy_(0, subblock_ids, part_block_as_int8) + else: + subblock_table.index_copy_(0, subblock_ids, part_block) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..79ede9f5c 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -81,14 +81,6 @@ def forward(self, x): # Unconditionally dequantize.
if isinstance(y, QuantizedTensor): y = y.unpack().dequant() - # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. - # We can truncate to fp16 in iree, so we do a cast here - # to account for this in the IR. This is may not be the right - # level to do this, but for now its here. - if not isinstance(y, QuantizedTensor): - if y.dtype == torch.float8_e4m3fnuz: - y = ops.to(y, torch.bfloat16) - return y if qdq_output is not None: y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 71bcd1e8c..76dafb918 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,6 +37,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", attention_scale: Optional[float] = None, softcap: Optional[float] = None, @@ -49,6 +50,7 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv + self.attention_dtype = attention_dtype self.attention_kernel = attention_kernel self.attention_scale = attention_scale self.softcap = softcap @@ -161,13 +163,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Fake quant is already dequantized when stored in the cache. if self.cache_quantizer and not self.fake_quant: xk = self.cache_quantizer.dequantize_raw_tensor( - xk, torch.bfloat16, name="xk_deq" + xk, self.attention_dtype, name="xk_deq" ) xv = self.cache_quantizer.dequantize_raw_tensor( - xv, torch.bfloat16, name="xv_deq" + xv, self.attention_dtype, name="xv_deq" ) if attention_mask is not None: - attention_mask = attention_mask.to(torch.bfloat16) + attention_mask = attention_mask.to(self.attention_dtype) # Transpose into [bs, heads, sl, dim] xq = xq.transpose(1, 2) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index a31913833..bc18cb492 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -117,6 +117,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, + attention_dtype=config.attention_dtype, attention_kernel=self.attention_kernel, fake_quant=self.fake_quant, ) @@ -241,6 +242,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", fake_quant: bool = True, ): @@ -255,6 +257,7 @@ def __init__( head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, + attention_dtype=attention_dtype, attention_kernel=attention_kernel, fake_quant=fake_quant, ), diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f88684273..104349266 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -44,16 +44,16 @@ def qlinear_tensor_scaled( # Now we know that both the x/weight are TensorScaledLayout. There are still # degrees of freedom: # * Either/both can be per-tensor or per-axis scaled (d is 0D or d is nd>0). - # * Either/both can have offsets of not (m is not None). + # * Either/both can have offsets or not (m is not None). 
x_layout: TensorScaledLayout = x.unpack() weight_layout: TensorScaledLayout = weight.unpack() # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if x_layout.qs.dtype == torch.float8_e4m3fnuz: - # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) - else: + if ( + x_layout.qs.dtype != torch.float8_e4m3fnuz + or weight_layout.qs.dtype != torch.float8_e4m3fnuz + ): return NotImplemented # Bias. @@ -160,6 +160,8 @@ def linear_quantized_weight( *, accum_dtype: Optional[torch.dtype], ) -> AnyTensor: + if accum_dtype is not None: + raise NotImplementedError("TODO: implement when accum_dtype is passed") res = matmul(x, weight, transpose_rhs=True) if bias is not None: res = res + bias @@ -170,7 +172,13 @@ def linear_quantized_weight( linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight) -def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): +def _is_dtype_unsigned_integer(dtype: torch.dtype): + return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed + + +def _invoke_mmt_kernel( + lhs: torch.Tensor, rhs: torch.Tensor, *, accum_dtype: torch.dtype +): if debugging.flags.use_custom_iree_kernels: # The custom kernel requires that the lhs and rhs be the same # rank. Broadcast the rhs to match. @@ -187,13 +195,17 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b( - lhs.to(accum_dtype), rhs.to(accum_dtype) - ) - # Squeeze the batch dimension to maintain shape parity with other - # layers. - if len(y_qs.shape) > 2: - y_qs = y_qs.squeeze(0) + if ( + _is_dtype_unsigned_integer(lhs.dtype) + or _is_dtype_unsigned_integer(rhs.dtype) + or _is_dtype_unsigned_integer(accum_dtype) + ): + # TODO: make the kernel work with unsigned types. + y_qs = kernels.batch_matmul_transpose_b( + lhs.to(dtype=accum_dtype), rhs.to(dtype=accum_dtype) + ) + else: + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) else: # FP emulation. y_qs = torch.matmul( diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index e3dba31fa..36ae89cc6 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -32,10 +32,12 @@ def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None) """Parses arguments and does any prescribed global process setup.""" parsed_args = parser.parse_args(args) # Set torch dtypes - for attr in ["activation_dtype", "attention_dtype"]: + for attr in ["activation_dtype", "attention_dtype", "kv_cache_dtype"]: if hasattr(parsed_args, attr): - dtype = getattr(torch, getattr(parsed_args, attr)) - assert isinstance(dtype, torch.dtype) + dtype = getattr(parsed_args, attr) + if dtype is not None: + dtype = getattr(torch, dtype) + assert isinstance(dtype, torch.dtype) setattr(parsed_args, attr, dtype) return parsed_args @@ -100,6 +102,11 @@ def add_model_options(parser: argparse.ArgumentParser): help="DType to use for activations in the model", default="float16", ) + parser.add_argument( + "--kv-cache-dtype", + help="DType to use for the KV cache. 
If not given, the attention dtype will be used", + default=None, + ) parser.add_argument("--device", help="Torch device (or default)") parser.add_argument( diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py index f462d9c00..eb26f5a14 100644 --- a/sharktank/sharktank/utils/create_cache.py +++ b/sharktank/sharktank/utils/create_cache.py @@ -12,6 +12,7 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: raise ValueError("Model does not use paged kv cache, cannot create kv cache") hp = config.hp + dtype = config.kv_cache_dtype or config.attention_dtype return PagedKVCache( transformer_block_count=hp.block_count, attn_head_count=hp.attention_head_count_kv, @@ -19,6 +20,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: cache_partition_count=2, # One for each of K/V. block_seq_stride=config.block_seq_stride, device=config.device, - dtype=config.attention_dtype, + dtype=dtype, shard_count=config.tensor_parallelism_size, ) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 75071e286..608f65c48 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -95,6 +95,7 @@ def __init__( use_attention_mask: bool = False, activation_dtype: str = "float16", attention_dtype: str = "float16", + kv_cache_dtype: Optional[str] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent ) @@ -109,6 +110,7 @@ def __init__( self.use_attention_mask = use_attention_mask self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype + self.kv_cache_dtype = kv_cache_dtype def timeit(func): def wrapper(*args, **kwargs): @@ -189,6 +191,8 @@ def export_to_mlir( f"--attention-dtype={self.attention_dtype}", f"--activation-dtype={self.activation_dtype}", ] + if self.kv_cache_dtype is not None: + export_args.append(f"--kv-cache-dtype={self.kv_cache_dtype}") if skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 208d54782..3ac260265 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -10,11 +10,13 @@ import unittest from parameterized import parameterized - +import pytest import torch from iree.turbine import aot +from iree.turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM from sharktank import kernels +from sharktank.utils.testing import skip class batch_matmul_transpose_b_test(unittest.TestCase): @@ -40,24 +42,111 @@ def testBS32(self, atol, rtol): ref = torch.matmul(a, bT) torch.testing.assert_close(result, ref, atol=atol, rtol=rtol) - def testExportStaticDims(self): + def testArgF8AccumF32(self): + # TODO: make this test not use eager but actually execute with IREE.
+ # Does not compile for llvm-cpu with + # :0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>' + # :0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32> + arg_dtype = torch.float8_e4m3fnuz + a = torch.rand([3, 4, 6]).to(arg_dtype) + b = torch.rand([3, 5, 6]).to(arg_dtype) + accum_dtype = torch.float32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) + + def testArgUi8AccumI32(self): + # TODO: make this test not use eager but actually execute with IREE. + # Does not work with unsigned types. The kernel needs to be adapted. + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + def testArgLhsI8RhsUi8AccumI32(self): + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=torch.uint8) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + def testArgI8AccumI32(self): + arg_dtype = torch.int8 + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + b = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @pytest.mark.xfail( + reason="""No uint32 dtype conversions in IREE Turbine. + Does not work with unsigned types. The kernel needs to be adapted. + The problem is that we reinterpret cast to signless integer types. 
+ Maybe linalg.batch_matmul_transpose_b when promoting from i8 to i32 assumes a + signed type even though i8 is signless.""" + ) + def testArgUi8AccumUi32(self): + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.uint32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=torch.int32), bT.to(dtype=torch.int32)) + ref = ref.to(dtype=accum_dtype) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @parameterized.expand( + [ + (torch.int32, None), + (torch.float8_e4m3fnuz, torch.float32), + ] + ) + def testExportStaticDims( + self, arg_dtype: torch.dtype, accum_dtype: torch.dtype | None + ): class MyModule(torch.nn.Module): def forward(self, a, b): - return kernels.batch_matmul_transpose_b(a, b) + return kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) mod = MyModule() - dtype = torch.int32 ep = torch.export.export( mod, args=( - (torch.rand([4, 16, 2]) * 64).to(dtype), - (torch.rand([4, 8, 2]) * 64).to(dtype), + (torch.rand([4, 16, 2]) * 64).to(arg_dtype), + (torch.rand([4, 8, 2]) * 64).to(arg_dtype), ), ) output = aot.export(ep) output.verify() asm = str(output.mlir_module) - self.assertIn("@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32", asm) + arg_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[arg_dtype] + accum_dtype_asm = arg_dtype_asm + if accum_dtype is not None: + accum_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] + self.assertIn( + ( + "@sharktank_batch_matmul_transpose_b_" + f"L4x16x2x{arg_dtype_asm}_R4x8x2x{arg_dtype_asm}_{accum_dtype_asm}" + ), + asm, + ) if __name__ == "__main__": diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index d59d0a85b..ff60ef8b6 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -4,8 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import unittest - +import pytest import torch from sharktank.ops import replicate, reshard_split, unshard @@ -13,7 +12,16 @@ from sharktank.types import * -def test_paged(): +@pytest.mark.parametrize( + "dtype", + [ + torch.float8_e4m3fnuz, + torch.bfloat16, + torch.float16, + torch.float32, + ], +) +def test_paged(dtype: torch.dtype): bs = 4 seq_length = 24 attn_head_count = 4 @@ -25,7 +33,7 @@ def test_paged(): transformer_block_count=transformer_block_count, attn_head_count=attn_head_count, attn_head_dim=attn_head_dim, - dtype=torch.float32, + dtype=dtype, device=None, ) @@ -36,15 +44,18 @@ def test_paged(): write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) - allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + for t in allocation: + t[...] 
= torch.full(t.shape, 0.0).to(dtype=dtype) # Write a prefill in: write_ones = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + ).to(dtype=dtype) write_twos = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + ).to(dtype=dtype) cache.write( allocation, @@ -72,15 +83,19 @@ def test_paged(): seq_len=write_seq_length, page_ids=write_page_ids, ) - torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) - torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close( + read_ones[0], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) + torch.testing.assert_close( + read_ones[1], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + write_threes = torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0).to( + dtype=dtype ) - write_fours = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + write_fours = torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0).to( + dtype=dtype ) write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) cache.write_timestep( @@ -98,8 +113,16 @@ def test_paged(): page_ids=page_ids, ) - check_concat_0 = torch.concat([write_ones, write_threes], dim=1) - check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + if dtype == torch.float8_e4m3fnuz: + check_concat_0 = torch.concat( + [write_ones.view(torch.int8), write_threes.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + check_concat_1 = torch.concat( + [write_twos.view(torch.int8), write_fours.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + else: + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) torch.testing.assert_close(check_concat_0, read_back[0]) torch.testing.assert_close(check_concat_1, read_back[1]) diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index ad657889d..08164176f 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -4,12 +4,16 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import logging import unittest - import torch +from parameterized import parameterized from sharktank.layers import * from sharktank.types import * +from sharktank.utils.testing import make_rand_torch + +logger = logging.getLogger(__name__) def _randomize_per_axis(t: torch.Tensor, axis: int, offset_range: float = 0.0): @@ -91,6 +95,102 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): print(torch.abs(output - output_ref)) torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1) + @parameterized.expand( + [ + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-2), + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-2), + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-6), + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-6), + (torch.float32, torch.float32, torch.float16, True, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, True, 1e-6), + (torch.float32, torch.float32, torch.float32, True, False, 1e-6), + ], + ) + def testPerTensorScale( + self, + dequantized_dtype: torch.dtype, + quantized_scale_dtype: torch.dtype, + quantized_dtype: torch.dtype, + with_bias: bool, + fake_quant: bool, + atol: float, + ): + """Test a linear layer where each tensor is quantized with its own single + per-tensor scale.""" + ref_dtype = torch.float64 + + x = make_rand_torch([10, 8, 8], dtype=dequantized_dtype) + input_scale = torch.tensor(0.5, dtype=quantized_scale_dtype) + input_quantizer = StaticScaledQuantizer( + name="q_input", scale=input_scale, dtype=quantized_dtype + ) + # We roundtrip through quantization to know that any discrepancies in the + # results come from the quantized linear operation itself, not from the + # inaccuracies of the initial quantization.
+ x_dequantized = input_quantizer.quantize(x).unpack().dequant() + torch.testing.assert_close( + input_quantizer.quantize(x_dequantized).unpack().dequant(), + x_dequantized, + atol=0, + rtol=0, + ) + + weight = make_rand_torch([12, x.shape[2]], dtype=dequantized_dtype) + weight_scale = torch.tensor(0.66, dtype=quantized_scale_dtype) + weight_quantizer = StaticScaledQuantizer( + scale=weight_scale, dtype=quantized_dtype + ) + weight_dequantized = weight_quantizer.quantize(weight).unpack().dequant() + weight_quantized = weight_quantizer.quantize(weight_dequantized, name="weight") + torch.testing.assert_close( + weight_quantizer.quantize(weight_dequantized).unpack().dequant(), + weight_dequantized, + atol=0, + rtol=0, + ) + + if with_bias: + bias = make_rand_torch( + [x.shape[1], weight.shape[0]], dtype=dequantized_dtype + ) + bias_scale = torch.tensor(1.25, dtype=quantized_scale_dtype) + bias_quantizer = StaticScaledQuantizer( + scale=bias_scale, dtype=quantized_dtype + ) + bias_dequantized = bias_quantizer.quantize(bias).unpack().dequant() + bias_quantized = bias_quantizer.quantize(bias_dequantized, name="bias") + torch.testing.assert_close( + bias_quantizer.quantize(bias_dequantized).unpack().dequant(), + bias_dequantized, + atol=0, + rtol=0, + ) + + expected = torch.matmul( + x_dequantized.to(ref_dtype), weight_dequantized.T.to(ref_dtype) + ) + if with_bias: + expected += bias_dequantized.to(ref_dtype) + + theta_tensors = [ + input_quantizer, + weight_quantized, + ] + if with_bias: + theta_tensors += [bias_quantized] + theta = Theta(theta_tensors) + linear = LinearLayer(theta, fake_quant=fake_quant) + actual = linear(x_dequantized) + actual = actual.to(dtype=expected.dtype) + + abs_diff = (expected - actual).abs() + logger.info( + f"abs diff from expected (std, mean, median) = {[float(abs_diff.std()), float(abs_diff.mean()), float(abs_diff.median())]}" + ) + + torch.testing.assert_close(actual, expected, atol=atol, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 6056f9c6c..39d6f4ce8 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -94,7 +94,8 @@ def setUp(self): tensor_parallelism_size=self.tensor_parallelism_size, block_seq_stride=32, activation_dtype="bfloat16", - attention_dtype="float8_e4m3fnuz", + attention_dtype="bfloat16", + kv_cache_dtype="float8_e4m3fnuz", ) self.prefill_args_bs4_128_stride_32_f16 = ( self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp1" @@ -242,9 +243,9 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self): ) @pytest.mark.xfail( - reason="Benchmark inputs not configured yet.", + reason="Fails due to https://github.com/iree-org/iree/issues/20002.", strict=True, - raises=IreeBenchmarkException, + raises=IreeCompileException, ) def testBenchmark8B_fp8_Non_Decomposed(self): output_file_name = self.dir_path_8b / "fp8_torch" diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py index b45696fe4..b8a30c543 100644 --- a/sharktank/tests/models/llama/quark_parity_test.py +++ b/sharktank/tests/models/llama/quark_parity_test.py @@ -11,11 +11,12 @@ import pytest from pathlib import Path import subprocess +from sharktank.utils.testing import TempDirTestBase with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')") -class QuarkParityTest(unittest.TestCase): +class 
QuarkParityTest(TempDirTestBase): def setUp(self): super().setUp() self.path_prefix = Path("/shark-dev/quark_test") @@ -25,7 +26,7 @@ def test_compare_against_quark(self): sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent ) - our_path = self.path_prefix / "ours_prefill.safetensors" + our_path = self._temp_dir / "ours_prefill.safetensors" if os.path.exists(our_path): os.remove(our_path) @@ -58,16 +59,15 @@ def test_compare_against_quark(self): "--fake-quant", "--attention-kernel=torch", "--activation-dtype=bfloat16", - f"--save_intermediates_path={self.path_prefix}/ours", + f"--save_intermediates_path={self._temp_dir / 'ours'}", "--use-hf", "--attention-dtype=bfloat16", + "--kv-cache-dtype=float8_e4m3fnuz", "--skip-decode", "--block-seq-stride=16", ] command = subprocess.list2cmdline(command) - proc = subprocess.run( - command, shell=True, capture_output=True, cwd=sharktank_dir - ) + subprocess.check_call(command, shell=True, cwd=sharktank_dir) ours = dict() with safe_open(our_path, "pytorch") as st:
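Reviewer note, not part of the patch: a minimal eager-mode usage sketch of the new public batch_matmul_transpose_b wrapper added in sharktank/sharktank/kernels/batch_matmul_transpose_b.py, mirroring testArgI8AccumI32 above. The [B, M, K] x [B, N, K] shapes, the accum_dtype keyword, and the default of accumulating in the lhs dtype all come from this diff; the concrete values are illustrative only.

import torch
from sharktank import kernels

# LHS is [B, M, K], RHS is [B, N, K]; the kernel computes lhs @ rhs.transpose(-1, -2).
lhs = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8)
rhs = ((torch.rand([2, 4, 5]) - 0.5) * 255).to(dtype=torch.int8)

# Accumulate in int32; omitting accum_dtype falls back to lhs.dtype.
# In eager mode this goes through eager_execute(), which upcasts both
# operands to the accumulator dtype before torch.matmul.
result = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=torch.int32)
assert result.shape == (2, 3, 4) and result.dtype == torch.int32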