diff --git a/models/demos/falcon7b_common/tt/model_utils.py b/models/demos/falcon7b_common/tt/model_utils.py index 3bf7dc0919d..2b068eaeade 100644 --- a/models/demos/falcon7b_common/tt/model_utils.py +++ b/models/demos/falcon7b_common/tt/model_utils.py @@ -50,7 +50,9 @@ def preprocess_weights(weights_to_cache): layout=tt_layout, device=mesh_device, memory_config=model_config[f"{weight_config_str}_MEMCFG"], - mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if type(mesh_device) == ttnn.MeshDevice + else None, cache_file_name=str(path), preprocess=preprocess_weights, ) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 7ec888fa9b3..829d02761a9 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -402,7 +402,9 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=None, - mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if mesh_device.get_num_devices() > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if mesh_device.get_num_devices() > 1 + else None, ), pt_out, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index a834b18e653..14474e541ba 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -31,6 +31,7 @@ MeshToTensor, ) + class ConcatMesh2DToTensor(MeshToTensor): def __init__(self, mesh_device, dims, cluster_shape): self.dims = dims diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index eba08c3cb59..e455197f1cb 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -17,513 +17,460 @@ from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis from models.demos.t3000.llama2_70b.tt.llama_common import ( - check_mesh_device, - extract_pcc_from_log, - generate_rot_emb, - get_rotation_mat, - gather_cos_sin, - precompute_freqs, - MAX_SEQ_LEN, - MAX_SEQ_LEN_LLAMA3, - BASE_URL, - UNIT_TEST_N_LAYER, - UNIT_TEST_LAYER_NUM, - UNIT_TEST_START_POS, - UNIT_TEST_GENERATION_LENGTH, - comp_pcc, - get_rot_transformation_mat, - should_skip_model_load, - check_kv_cache, - num_to_corerange, - ConcatMesh2DToTensor, + check_mesh_device, + extract_pcc_from_log, + generate_rot_emb, + get_rotation_mat, + gather_cos_sin, + precompute_freqs, + MAX_SEQ_LEN, + MAX_SEQ_LEN_LLAMA3, + BASE_URL, + UNIT_TEST_N_LAYER, + UNIT_TEST_LAYER_NUM, + UNIT_TEST_START_POS, + UNIT_TEST_GENERATION_LENGTH, + comp_pcc, + get_rot_transformation_mat, + should_skip_model_load, + check_kv_cache, + num_to_corerange, + ConcatMesh2DToTensor, ) from models.utility_functions import skip_for_grayskull - - class PytorchLlamaAttentionModel(torch.nn.Module): - def __init__(self, hf_reference_model, layer_num, rope_theta): - super().__init__() - self.attention = hf_reference_model.layers[layer_num].attention - self.rope_theta = rope_theta - # Disable dropout - self.attention.eval() - - - configuration = hf_reference_model.params - self.n_heads = configuration.n_heads - hidden_dim = 
configuration.dim - self.head_dim = hidden_dim // self.n_heads - self.max_seq_len = configuration.max_seq_len - - - def prepare_inputs(self, x, start_pos): - """ - Prepare inputs for decode mode. Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + 1] - - - attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) - # attn_mask[:, :, :, : start_pos + 1] = -1e9 - attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) - - - return x, start_pos, freqs_cis, attn_mask - - - def prepare_inputs_prefill(self, x, start_pos): - """ - Prepare inputs for decode mode. Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - seq_len = x.size(1) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + seq_len] - - - attn_mask = torch.full((seq_len, seq_len), float("-inf")) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) - - - return x, start_pos, freqs_cis, attn_mask - - - def forward(self, x, start_pos, freqs_cis, mask): - """ - x: (batch, seq, hidden_dim) - start_pos: int - freqs_cis: ? - mask: ? - - - return: (batch, seq, hidden_dim) - """ - result = self.attention( - x, - start_pos, - freqs_cis, - mask, - ) - return result - - + def __init__(self, hf_reference_model, layer_num, rope_theta): + super().__init__() + self.attention = hf_reference_model.layers[layer_num].attention + self.rope_theta = rope_theta + # Disable dropout + self.attention.eval() + + configuration = hf_reference_model.params + self.n_heads = configuration.n_heads + hidden_dim = configuration.dim + self.head_dim = hidden_dim // self.n_heads + self.max_seq_len = configuration.max_seq_len + + def prepare_inputs(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + 1] + + attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) + # attn_mask[:, :, :, : start_pos + 1] = -1e9 + attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) + + return x, start_pos, freqs_cis, attn_mask + + def prepare_inputs_prefill(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + seq_len = x.size(1) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + seq_len] + + attn_mask = torch.full((seq_len, seq_len), float("-inf")) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) + + return x, start_pos, freqs_cis, attn_mask + + def forward(self, x, start_pos, freqs_cis, mask): + """ + x: (batch, seq, hidden_dim) + start_pos: int + freqs_cis: ? + mask: ? 
+ + + return: (batch, seq, hidden_dim) + """ + result = self.attention( + x, + start_pos, + freqs_cis, + mask, + ) + return result def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_theta, mode="decode"): - assert len(x.size()) == 3 - batch, seq_len, _ = x.shape - - - cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") - - - if mode == "decode": - assert seq_len == 1, "Only supporting decode mode" - x = x.transpose(0, 1).unsqueeze(1) - assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) - - - ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ACT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=shard_tensor_to_2d_mesh_mapper( - llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) - ), - ) - - - batch_size_per_group = llama_attention_model.batch_size_per_device_group - - - rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) - rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) - assert rot_mat.size() == ( - 1, - batch_size_per_group, - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ) - - - shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) - ROT_MAT_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - shard_spec_n_cores_grid, - [ - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ], - ttnn.ShardOrientation.ROW_MAJOR, - ), - ) - rot_mats = ttnn.as_tensor( - rot_mat, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ROT_MAT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - - - attn_masks = None - - - elif mode == "prefill": - assert ( - seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 - ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" - assert batch == 1, "prefill mode only supports batch size 1" - x = x.unsqueeze(0) - assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=shard_tensor_to_2d_mesh_mapper( - llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) - ), - ) - - - cos, sin = precompute_freqs( - llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False - ) - cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) - assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - - - cos_gathereds = ttnn.as_tensor( - cos_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - 
mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - sin_gathereds = ttnn.as_tensor( - sin_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - - - rot_mats = [cos_gathereds, sin_gathereds] - - - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(1, batch, -1, -1) - attn_masks = ttnn.as_tensor( - attn_mask, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - ) - - - return ( - xs, - start_pos, - rot_mats, - attn_masks, - ) - - + assert len(x.size()) == 3 + batch, seq_len, _ = x.shape + + cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") + + if mode == "decode": + assert seq_len == 1, "Only supporting decode mode" + x = x.transpose(0, 1).unsqueeze(1) + assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) + + ACT_MEMCFG = ttnn.create_sharded_memory_config( + shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), + core_grid=ttnn.CoreGrid(y=4, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ACT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) + ), + ) + + batch_size_per_group = llama_attention_model.batch_size_per_device_group + + rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) + rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) + assert rot_mat.size() == ( + 1, + batch_size_per_group, + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ) + + shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) + ROT_MAT_MEMCFG = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + shard_spec_n_cores_grid, + [ + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + rot_mats = ttnn.as_tensor( + rot_mat, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ROT_MAT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + attn_masks = None + + elif mode == "prefill": + assert ( + seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 + ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" + assert batch == 1, "prefill mode only supports batch size 1" + x = x.unsqueeze(0) + assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + 
llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) + ), + ) + + cos, sin = precompute_freqs( + llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False + ) + cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) + assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + + cos_gathereds = ttnn.as_tensor( + cos_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + sin_gathereds = ttnn.as_tensor( + sin_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + rot_mats = [cos_gathereds, sin_gathereds] + + attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(1, batch, -1, -1) + attn_masks = ttnn.as_tensor( + attn_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + ) + + return ( + xs, + start_pos, + rot_mats, + attn_masks, + ) def run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, - cache_path, + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, ): - # Prepare paths and devices - skip_model_load = should_skip_model_load() - - - # Prepare configs - hugging_face_reference_model = Llama.build( - ckpt_dir, - tokenizer_path, - max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, - max_batch_size=batch, - n_layers=UNIT_TEST_N_LAYER, - skip_model_load=skip_model_load, - ).model - hugging_face_reference_model.eval() - state_dict = hugging_face_reference_model.state_dict() - logger.info(state_dict.keys()) - torch.manual_seed(0) - configuration = hugging_face_reference_model.params - - - # PyTorch model -------------------------------------------------------------------- - pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( - hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta - ) - # TT model ------------------------------------------------------------------------- - transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size - - - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), - ) - - - tt_LlamaAttention_model = TtLlamaAttention_galaxy( - mesh_device, - cluster_shape, - state_dict, - BASE_URL, - UNIT_TEST_LAYER_NUM, - model_config, - configuration, - transformation_mats, - cache_path=cache_path, - ) - - - mode 
= "decode" if seq_len == 1 else "prefill" - - - all_tests_pass, all_pccs = True, [] - if mode == "prefill": - generation_start_pos = 0 - generation_length = 1 - else: - generation_start_pos = UNIT_TEST_START_POS - generation_length = UNIT_TEST_GENERATION_LENGTH - - - for i in range(generation_length): - # Prepare input - pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) - pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) - pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) - tt_input = pt_inp_normed.clone() - start_pos = generation_start_pos + i - - - # PyTorch output -------------------------------------------------------------------- - if mode == "prefill": - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( - pt_inp_normed, start_pos - ) - else: - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( - pt_inp_normed, start_pos - ) - - - pytorch_out = pytorch_LlamaAttention_model( - attention_input, - start_pos, - freqs_cis, - attn_mask, - ) - - - # TT hardware execution ------------------------------------------------------------- - attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( - tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode - ) - tt_out = tt_LlamaAttention_model( - attention_input, - rot_mat, - start_pos, - attn_mask, - mode=mode, - ) - # tt_out = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(tt_out.cpu())] - - - tt_out = ttnn.to_torch( - tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) - ) - tt_out = tt_out[:, 0:1, :, :] - tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] - - - does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) - logger.info(f"Output: {output_pcc}") - - - all_pccs.append(extract_pcc_from_log(output_pcc)) - - - if does_pass: - logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") - else: - logger.warning( - f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" - ) - all_tests_pass = False - - - logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") - - - # Check kv cache - # PyTorch output -------------------------------------------------------------------- - pytorch_layer_present = [ - pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - ] - # TT hardware output ---------------------------------------------------------------- - - - # concat the pasts by heads - tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] - - - tt_layer_present_all = [ - ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ - :batch, ... 
- ] - for lp in tt_layer_present_all - ] - - - cache_test_pass = check_kv_cache( - pytorch_layer_present, - tt_layer_present_all, - generation_start_pos, - generation_length, - seq_len, - mode == "prefill", - pcc, - ) - - - all_tests_pass = all_tests_pass and cache_test_pass - - - if all_tests_pass: - logger.info(f"{llama_version} Attention output Passed!") - else: - gc.collect() - logger.warning(f"{llama_version} Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" - - + # Prepare paths and devices + skip_model_load = should_skip_model_load() + + # Prepare configs + hugging_face_reference_model = Llama.build( + ckpt_dir, + tokenizer_path, + max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, + max_batch_size=batch, + n_layers=UNIT_TEST_N_LAYER, + skip_model_load=skip_model_load, + ).model + hugging_face_reference_model.eval() + state_dict = hugging_face_reference_model.state_dict() + logger.info(state_dict.keys()) + torch.manual_seed(0) + configuration = hugging_face_reference_model.params + + # PyTorch model -------------------------------------------------------------------- + pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( + hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta + ) + # TT model ------------------------------------------------------------------------- + transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size + + transformation_mats = ttnn.as_tensor( + transformation_mat_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), + ) + + tt_LlamaAttention_model = TtLlamaAttention_galaxy( + mesh_device, + cluster_shape, + state_dict, + BASE_URL, + UNIT_TEST_LAYER_NUM, + model_config, + configuration, + transformation_mats, + cache_path=cache_path, + ) + + mode = "decode" if seq_len == 1 else "prefill" + + all_tests_pass, all_pccs = True, [] + if mode == "prefill": + generation_start_pos = 0 + generation_length = 1 + else: + generation_start_pos = UNIT_TEST_START_POS + generation_length = UNIT_TEST_GENERATION_LENGTH + + for i in range(generation_length): + # Prepare input + pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) + pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) + pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) + tt_input = pt_inp_normed.clone() + start_pos = generation_start_pos + i + + # PyTorch output -------------------------------------------------------------------- + if mode == "prefill": + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( + pt_inp_normed, start_pos + ) + else: + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( + pt_inp_normed, start_pos + ) + + pytorch_out = pytorch_LlamaAttention_model( + attention_input, + start_pos, + freqs_cis, + attn_mask, + ) + + # TT hardware execution ------------------------------------------------------------- + attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( + tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode + ) + tt_out = tt_LlamaAttention_model( + attention_input, + rot_mat, + start_pos, + attn_mask, + mode=mode, + ) + # tt_out = [ttnn.to_torch(shard) for shard in 
ttnn.get_device_tensors(tt_out.cpu())] + + tt_out = ttnn.to_torch( + tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) + ) + tt_out = tt_out[:, 0:1, :, :] + tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] + + does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) + logger.info(f"Output: {output_pcc}") + + all_pccs.append(extract_pcc_from_log(output_pcc)) + + if does_pass: + logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") + else: + logger.warning( + f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" + ) + all_tests_pass = False + + logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") + + # Check kv cache + # PyTorch output -------------------------------------------------------------------- + pytorch_layer_present = [ + pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + ] + # TT hardware output ---------------------------------------------------------------- + + # concat the pasts by heads + tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] + + tt_layer_present_all = [ + ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ + :batch, ... + ] + for lp in tt_layer_present_all + ] + + cache_test_pass = check_kv_cache( + pytorch_layer_present, + tt_layer_present_all, + generation_start_pos, + generation_length, + seq_len, + mode == "prefill", + pcc, + ) + + all_tests_pass = all_tests_pass and cache_test_pass + + if all_tests_pass: + logger.info(f"{llama_version} Attention output Passed!") + else: + gc.collect() + logger.warning(f"{llama_version} Attention output Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
@skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] + "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] ) @pytest.mark.parametrize( - "llama_version", - (("llama3-tg"),), + "llama_version", + (("llama3-tg"),), ) @pytest.mark.parametrize( - "batch, seq_len, pcc", - [ - (32, 1, 0.9995), - (1, 256, 0.999), - ], - ids=[ - "decode", - "prefill", - ], + "batch, seq_len, pcc", + [ + (32, 1, 0.9995), + (1, 256, 0.999), + ], + ids=[ + "decode", + "prefill", + ], ) @pytest.mark.parametrize( - "max_batch_size, max_context_len", - ( - (32, 2048), - # (16, 8192), - ), - ids=( - "short_context", - # "long_context", - ), + "max_batch_size, max_context_len", + ( + (32, 2048), + # (16, 8192), + ), + ids=( + "short_context", + # "long_context", + ), ) def test_LlamaAttention_inference( - batch, - seq_len, - pcc, - mesh_device, - max_batch_size, - max_context_len, - llama_version, - cluster_shape, - use_program_cache, + batch, + seq_len, + pcc, + mesh_device, + max_batch_size, + max_context_len, + llama_version, + cluster_shape, + use_program_cache, ): - if batch > max_batch_size: - pytest.skip(f"Decode with {batch} users is not supported with large context") - - - if batch == 1 and seq_len > max_context_len: - pytest.skip(f"Prefill with {seq_len=} is not supported with short context") - - - if llama_version == "llama2" and seq_len > 2048: - pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") - - - model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( - llama_version=llama_version, - max_batch_size=max_batch_size, - max_context_len=max_context_len, - ) - check_mesh_device(mesh_device, model_config) - run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, - cache_path, - ) + if batch > max_batch_size: + pytest.skip(f"Decode with {batch} users is not supported with large context") + + if batch == 1 and seq_len > max_context_len: + pytest.skip(f"Prefill with {seq_len=} is not supported with short context") + + if llama_version == "llama2" and seq_len > 2048: + pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") + + model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( + llama_version=llama_version, + max_batch_size=max_batch_size, + max_context_len=max_context_len, + ) + check_mesh_device(mesh_device, model_config) + run_test_LlamaAttention_inference( + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, + ) diff --git a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py index 5d3fae48334..2404016e361 100644 --- a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py @@ -110,8 +110,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=shard_tensor_to_2d_mesh_mapper -(self.mesh_device, self.cluster_shape, (None, 2)), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / attn_norm_sharded_str, ) @@ -121,8 +120,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, 
memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=shard_tensor_to_2d_mesh_mapper -(self.mesh_device, self.cluster_shape, (None, 2)), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index de8c1b3d11d..068ad25c44d 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -12,6 +12,7 @@ ) from ttnn import shard_tensor_to_2d_mesh_mapper + class TtLlamaMLP_galaxy: def __init__( self, diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py index 429673467cf..d5abeb42724 100644 --- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py @@ -26,6 +26,7 @@ ) from ttnn import shard_tensor_to_2d_mesh_mapper + def is_power_of_two(n): if n <= 0: return False diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index a4822361d11..8f9e8c5b5df 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -31,7 +31,10 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co for i in range(100): torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -63,7 +66,10 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d639eae06f2..84abb481e26 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -189,7 +189,8 @@ def create_mesh_device(*args, **kwargs): yield mesh_device finally: close_mesh_device(mesh_device) - + + def synchronize_devices( devices: Union["ttnn.Device", "ttnn.MeshDevice"], queue_id: Optional[int] = ttnn.DefaultQueueId, @@ -293,6 +294,7 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": ] return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) + @contextlib.contextmanager def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor]): """ @@ -330,4 +332,5 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor ttnn.from_torch = _original_from_torch ttnn.to_torch = _original_to_torch + __all__ = []
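
Usage note (supplementary, not part of the patch): the hunks above migrate call sites from the class-based mesh mappers to the functional helpers ttnn.replicate_tensor_to_mesh_mapper, ttnn.shard_tensor_to_mesh_mapper, and shard_tensor_to_2d_mesh_mapper. The sketch below shows the two single-dimension mappers in isolation, mirroring the ttnn.from_torch calls in tests/ttnn/unit_tests/test_multi_device_async.py above; the mesh shape, the open/close helpers, and the tensor sizes are illustrative assumptions and are not taken from this change.

# Minimal sketch, assuming a ttnn build that provides open_mesh_device/close_mesh_device
# and the functional mesh-mapper helpers with the signatures used in the hunks above.
import torch
import ttnn

# Hypothetical 1x2 mesh for illustration; real tests use larger grids (e.g. 8x4).
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2))

torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16)

# Shard the last dimension across the mesh, as in test_multi_device_async.py.
sharded = ttnn.from_torch(
    torch_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3),
)

# Replicate a small tensor to every device, as the model code above does for
# rotation matrices and attention masks.
replicated = ttnn.from_torch(
    torch.ones((1, 1, 32, 32), dtype=torch.bfloat16),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

ttnn.close_mesh_device(mesh_device)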