diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py index 28eb9cadf55..6926df48f7a 100644 --- a/models/common/rmsnorm.py +++ b/models/common/rmsnorm.py @@ -80,7 +80,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) if self.is_distributed: diff --git a/models/common/tests/test_rmsnorm.py b/models/common/tests/test_rmsnorm.py index 1828a6702e4..1933b0798ab 100644 --- a/models/common/tests/test_rmsnorm.py +++ b/models/common/tests/test_rmsnorm.py @@ -12,7 +12,7 @@ os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.rmsnorm import RMSNorm as TtRMSNorm from models.utility_functions import ( @@ -130,7 +130,7 @@ def test_rmsnorm_multidevice(t3k_mesh_device, is_sharded, use_program_cache, res device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/demos/falcon7b_common/tests/test_falcon_mlp.py b/models/demos/falcon7b_common/tests/test_falcon_mlp.py index cf741ff67d1..6e8f2328eef 100644 --- a/models/demos/falcon7b_common/tests/test_falcon_mlp.py +++ b/models/demos/falcon7b_common/tests/test_falcon_mlp.py @@ -6,7 +6,7 @@ import torch from loguru import logger import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_mlp import TtFalconMLPDecode, TtFalconMLPPrefill from models.demos.falcon7b_common.tt.model_config import get_model_config from models.demos.falcon7b_common.tests.test_utils import load_hf_model, tt_from_torch, get_num_devices @@ -79,7 +79,7 @@ def run_test_FalconMLP_inference( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_out = tt_FalconMLP_model(tt_mlp_input) diff --git a/models/demos/falcon7b_common/tests/test_utils.py b/models/demos/falcon7b_common/tests/test_utils.py index 3e7e29fe478..b8bf2caa254 100644 --- a/models/demos/falcon7b_common/tests/test_utils.py +++ b/models/demos/falcon7b_common/tests/test_utils.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from transformers import FalconForCausalLM from models.utility_functions import tt_tensors_to_torch_tensors @@ -20,14 +20,14 @@ def initialize_kv_cache(configuration, num_layers, batch_size, max_seq_len, mesh dtype=ttnn.bfloat16, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_v_cache = tt_from_torch( v_cache, dtype=ttnn.bfloat16, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) kv_cache += ((tt_k_cache, tt_v_cache),) return kv_cache @@ -106,7 +106,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, 
layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) if model_config["PREFILL_OPTIMIZED_MODE"] and seq_len in [2048, 128, 1024]: @@ -121,7 +121,7 @@ def get_rand_falcon_inputs( dtype=ttnn.bfloat4_b, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) for attn_mask in attn_masks ] @@ -131,7 +131,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) # Generate kvcache for each layer @@ -145,14 +145,14 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_v_cache = tt_from_torch( tt_v_cache.unsqueeze(1), dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -169,7 +169,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=2), ) attention_mask_bool = torch.zeros(global_batch, 1, q_len, kv_len, dtype=bool) @@ -200,7 +200,7 @@ def get_rand_falcon_inputs( device=mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=device_shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=device_shard_dim), ) if not model_config["l1_sharded"]: # Tilize attn masks @@ -227,14 +227,14 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_v_cache = tt_from_torch( tt_v_cache.unsqueeze(1), dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/falcon7b_common/tt/falcon_attention.py b/models/demos/falcon7b_common/tt/falcon_attention.py index ea1c2740148..54af7c56102 100644 --- a/models/demos/falcon7b_common/tt/falcon_attention.py +++ b/models/demos/falcon7b_common/tt/falcon_attention.py @@ -9,7 +9,7 @@ from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import ( nearest_32, @@ -155,7 +155,7 @@ def __init__( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # optimized version can utilize single float value for softmax @@ -175,7 +175,7 @@ def __init__( device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["ATTN_OPTIMIZED_MEMCFG"], - 
mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["ATTN_OUTPUT_TENSORS"][seq_len] = tt_tensors @@ -553,7 +553,7 @@ def __init__( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward( diff --git a/models/demos/falcon7b_common/tt/falcon_causallm.py b/models/demos/falcon7b_common/tt/falcon_causallm.py index 09702d4a94f..6194690c3d0 100644 --- a/models/demos/falcon7b_common/tt/falcon_causallm.py +++ b/models/demos/falcon7b_common/tt/falcon_causallm.py @@ -6,7 +6,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_lm_head import falcon_lm_head_matmul_2d from models.demos.falcon7b_common.tt.falcon_model import TtFalconModelShared from models.demos.falcon7b_common.tt.model_utils import ( @@ -123,7 +123,7 @@ def __init__( device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["LM_HEAD_MM_INPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.lm_head_weights = get_weights_cached( diff --git a/models/demos/falcon7b_common/tt/falcon_mlp.py b/models/demos/falcon7b_common/tt/falcon_mlp.py index 7694e2d4ea8..d6884d7b59e 100644 --- a/models/demos/falcon7b_common/tt/falcon_mlp.py +++ b/models/demos/falcon7b_common/tt/falcon_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.model_utils import ( get_falcon_default_core_grid, get_weights_cached, @@ -176,7 +176,7 @@ def _load_mlp_padded_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) mlp_padding_tensors[seq_len] = tt_padding self.model_config["MLP_PREFILL_PADDING_TENSORS"] = mlp_padding_tensors @@ -191,7 +191,7 @@ def _allocate_output_mlp_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["MLP_OUTPUT_TENSORS"] = out_tt @@ -344,7 +344,7 @@ def _load_mlp_padded_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["MLP_DECODE_PADDING_TENSORS"] = tt_paddings diff --git a/models/demos/falcon7b_common/tt/falcon_model.py b/models/demos/falcon7b_common/tt/falcon_model.py index fa2932cb0c8..d79cdee51d8 100644 --- a/models/demos/falcon7b_common/tt/falcon_model.py +++ b/models/demos/falcon7b_common/tt/falcon_model.py @@ -7,7 +7,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.falcon7b_common.tt.model_utils import get_weights_cached, layernorm @@ -134,7 +134,7 @@ def model_preprocessing(self, 
llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) for attention_mask_slice in attention_mask_ ] @@ -156,7 +156,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Repeat attn masks for all heads tt_attention_mask = ttnn.repeat( @@ -177,7 +177,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), ) elif llm_mode == "decode": assert batch_size % 32 == 0, "For decode, batch_size must be multiple of 32!" @@ -210,7 +210,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if not self.model_config["l1_sharded"]: # Tilize attn masks @@ -226,7 +226,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) else: raise NotImplementedError(f"Llm mode {llm_mode} is not supported! 
Must be one of prefill or decode.") diff --git a/models/demos/falcon7b_common/tt/model_utils.py b/models/demos/falcon7b_common/tt/model_utils.py index b7ce657bd69..2b068eaeade 100644 --- a/models/demos/falcon7b_common/tt/model_utils.py +++ b/models/demos/falcon7b_common/tt/model_utils.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import is_wormhole_b0 @@ -50,7 +50,9 @@ def preprocess_weights(weights_to_cache): layout=tt_layout, device=mesh_device, memory_config=model_config[f"{weight_config_str}_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if type(mesh_device) == ttnn.MeshDevice + else None, cache_file_name=str(path), preprocess=preprocess_weights, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index dc395842338..451e70be45c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -15,7 +15,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -108,7 +108,7 @@ def test_llama_class_embedding_inference( layout=layout, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info(f"TT Input tensor shape: {tt_input_tensor.shape}") diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index c38dd5ccb26..eea8858a6fb 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -14,7 +14,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 15490b6ba41..7bdd8059769 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -106,7 +106,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) for _ in range(2) ] @@ -170,7 +170,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask = ttnn.from_torch(
full_text_mask_expand[b : b + 1], @@ -178,7 +178,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model( tt_tensor_x, @@ -209,7 +209,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -224,7 +224,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask = ttnn.reshape( tt_full_text_mask, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 7c59a9630de..eec8b4f7bd1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -193,7 +193,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], @@ -201,7 +201,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], @@ -209,7 +209,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) rot_mats = get_prefill_rot_mat( @@ -253,7 +253,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats, _ = get_single_rot_mat( @@ -275,7 +275,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -290,7 +290,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = 
ttnn.reshape( tt_full_text_mask_expand_1NSH, @@ -309,7 +309,7 @@ def test_llama_cross_attention_transformer_text_inference( device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.to_layout(tt_full_text_mask_expand_11SD, ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 7516354af66..ff8d79180c0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -97,7 +97,7 @@ def test_llama_cross_attention_transformer_block_inference( layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) for _ in range(2) ] @@ -161,7 +161,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], @@ -169,7 +169,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], @@ -177,7 +177,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) tt_out = tt_model( tt_tensor_x, @@ -207,7 +207,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -222,7 +222,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, @@ -241,7 +241,7 @@ def test_llama_cross_attention_transformer_block_inference( device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.to_layout(tt_full_text_mask_expand_11SD, ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index 3d9e6977145..03be0a437a9 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -96,7 +96,7 @@ def 
test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_c dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(attention_input, mask=tt_mask) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 23096202e29..f21ad59bda2 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -104,7 +104,7 @@ def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_progra dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(attention_input, mask=tt_mask) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index c6b65ef7f9d..8013df2f2da 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -75,7 +75,7 @@ def test_llama_mlp_inference(batch, num_chunks, mesh_device, use_program_cache, tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index 502736ac790..5a22ffb3b84 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -135,7 +135,7 @@ def test_llama_image_transformer_inference( dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) with torch.no_grad(): diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index d52d9f415f3..56daf2540c0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -74,7 +74,7 @@ def test_layernorm_inference(mesh_device, use_program_cache, reset_seeds, ensure tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index c5262bf2235..aaa2c76d20e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, 
replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -128,7 +128,7 @@ def test_llama_positional_embedding_inference( layout=layout, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_input_tensor = ttnn.to_layout(tt_input_tensor, ttnn.ROW_MAJOR_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 4ba64dd76ff..c00c27773bf 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -98,7 +98,7 @@ def test_llama_conv2d_inference( layout=layout, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info(f"TT Input tensor shape: {tt_input_tensor.shape}") diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index 52d6e2cc19a..534ffa2c407 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -100,7 +100,7 @@ def test_llama_attention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats = {"prefill": transformation_mats_prefill} @@ -129,7 +129,7 @@ def test_llama_attention_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_model = TtLlamaAttention( diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index a370011383d..2e0c9551054 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -102,7 +102,7 @@ def test_llama_decoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats = {"prefill": transformation_mats_prefill} @@ -127,7 +127,7 @@ def test_llama_decoder_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Initialize TT model diff --git a/models/demos/llama3/tests/test_llama_embedding.py b/models/demos/llama3/tests/test_llama_embedding.py index 71d56a3a7f4..2b28a51944b 100644 --- a/models/demos/llama3/tests/test_llama_embedding.py +++ b/models/demos/llama3/tests/test_llama_embedding.py @@ -67,7 +67,7 @@ def test_llama_embedding(max_seq_len, batch_size,
mesh_device, use_program_cache tt_input = ttnn.from_torch( pt_input.squeeze(1), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py index 710ee9498c5..37770024ce1 100644 --- a/models/demos/llama3/tests/test_llama_mlp.py +++ b/models/demos/llama3/tests/test_llama_mlp.py @@ -75,7 +75,7 @@ def test_llama_mlp_inference(seq_len, batch_size, mesh_device, use_program_cache device=mesh_device, mesh_mapper=ttnn.ShardTensor2dMesh( mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape - ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` + ), # When both dims are None, the mapper used is `ttnn.replicate_tensor_to_mesh_mapper` dtype=ttnn.bfloat8_b, memory_config=( ( diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index 667764a2304..6e6bfcca2e3 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -160,7 +160,7 @@ def test_llama_model_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Load TTNN model diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index d1c1bee93b0..d9c064c2ddc 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -63,7 +63,7 @@ def __init__( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) user_selection_matrix = torch.eye(8, 8) user_selection_matrix = torch.nn.functional.pad(user_selection_matrix, (0, 24), "constant", 0) # (8, 32) @@ -74,7 +74,7 @@ def __init__( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.dtype = dtype @@ -128,7 +128,7 @@ def __init__( self.wqkv_bias_prefill = ttnn.as_tensor( qkv_bias, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -153,7 +153,7 @@ def __init__( bias_tensor = ttnn.as_tensor( qkv_bias_decode, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -277,7 +277,7 @@ def init_kv_cache(self, configuration, weight_cache_path): layout=self.model_config["ATTN_W_LAYOUT_TILE"], device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=( f"{weight_cache_path}/kvcache_{k_or_v.shape}" if weight_cache_path and not configuration.dummy_weights diff --git a/models/demos/llama3/tt/llama_common.py 
b/models/demos/llama3/tt/llama_common.py index dd6873ed8b3..829d02761a9 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -219,14 +219,14 @@ def get_prefill_rot_mat(head_dim, mesh_device, seq_len, theta, scale_factor, ori dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) sin_gathereds = ttnn.from_torch( sin_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -280,13 +280,13 @@ def get_single_rot_mat( device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ) @@ -402,7 +402,9 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=None, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if mesh_device.get_num_devices() > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if mesh_device.get_num_devices() > 1 + else None, ), pt_out, ) @@ -413,7 +415,7 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), pt_out, ) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 8f49cd04299..1f473a473dc 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -113,7 +113,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tokens_embd = self.embd(tokens) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) @@ -127,7 +127,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: tt_page_table = None @@ -138,7 +138,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: tt_chunk_page_table = None @@ -172,7 +172,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): tokens, device=None, 
dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tokens = ttnn.unsqueeze_to_4D(tokens) diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 533768df5b5..3a4a414ca5f 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensor2dMesh +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensor2dMesh from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin from models.utility_functions import nearest_32 @@ -56,14 +56,14 @@ def __init__( device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) self.sin_matrix = ttnn.from_torch( sin_matrix, device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) batch_grid = ttnn.num_cores_to_corerangeset(batch_size, self.core_grid, row_wise=True) @@ -107,7 +107,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) def get_both_trans_mats(self): @@ -133,7 +133,7 @@ def get_rot_idxs(self, position_idxs, on_host=False): position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) else: # On device rot_idxs = ttnn.as_tensor( @@ -142,7 +142,7 @@ position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, device=self.device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) return rot_idxs diff --git a/models/demos/llama3/tt/lm_head.py b/models/demos/llama3/tt/lm_head.py index a79f8856e66..628ca3e093d 100644 --- a/models/demos/llama3/tt/lm_head.py +++ b/models/demos/llama3/tt/lm_head.py @@ -87,7 +87,7 @@ def __init__( ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, diff --git a/models/demos/llama3/tt/multimodal/llama_class_embedding.py b/models/demos/llama3/tt/multimodal/llama_class_embedding.py index 6bb57822953..fd3d8defe4c 100644 --- a/models/demos/llama3/tt/multimodal/llama_class_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_class_embedding.py @@ -8,7 +8,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper class TtLlamaClassEmbedding(LightweightModule): @@ -37,7 +37,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, -
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def forward(self, x): diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index f5ff04f7e3e..ea1cf85685e 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -11,7 +11,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaConv2dPatch(LightweightModule): @@ -56,7 +56,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if bias else None @@ -76,7 +76,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.compute_kernel_config = ttnn.init_device_compute_kernel_config( @@ -102,7 +102,7 @@ def forward(self, x: torch.Tensor): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) out = ttnn.linear( diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index ef312334bcf..e2a8164c695 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -72,7 +72,7 @@ def __init__( self.wq = ttnn.as_tensor( self.state_dict[wq_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -82,7 +82,7 @@ def __init__( self.wk = ttnn.as_tensor( self.state_dict[wk_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -92,7 +92,7 @@ def __init__( self.wv = ttnn.as_tensor( self.state_dict[wv_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -102,7 +102,7 @@ def __init__( self.wo = ttnn.as_tensor( self.state_dict[wo_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 28ee6e810ed..65248a2d619 100644 --- 
a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -93,7 +93,7 @@ def __init__( lm_head_torch[split], dtype=type, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, cache_file_name=cache_name(name, suffix, split), @@ -254,7 +254,7 @@ def setup_cache(self, max_batch_size): layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) for _ in range(2) ] diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py index 441ccda766b..f96d39bb3d9 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py @@ -74,9 +74,9 @@ def shuffle_weight(weight): dtype=type, device=self.mesh_device, mesh_mapper=( - ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim) if dim is not None - else ttnn.ReplicateTensorToMesh(self.mesh_device) + else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index e09ae041595..5d7ad4620b7 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -74,7 +74,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_attn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) @@ -109,7 +109,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_ffwd"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tt/multimodal/llama_image_attention.py b/models/demos/llama3/tt/multimodal/llama_image_attention.py index c518793f83e..6721b100732 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_image_attention.py @@ -118,7 +118,7 @@ def pad_head_dim(weight, heads_out=True): dim=-1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -132,7 +132,7 @@ def pad_head_dim(weight, heads_out=True): -1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, layout=ttnn.TILE_LAYOUT, diff --git 
a/models/demos/llama3/tt/multimodal/llama_image_block.py b/models/demos/llama3/tt/multimodal/llama_image_block.py index 9ab361aed26..257ee5763c6 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_block.py +++ b/models/demos/llama3/tt/multimodal/llama_image_block.py @@ -79,7 +79,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_attn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) @@ -87,7 +87,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_ffn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py index 0d56f310eaf..212e558c8f7 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py +++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py @@ -41,9 +41,9 @@ def __init__( dtype=type, device=self.mesh_device, mesh_mapper=( - ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim) if dim is not None - else ttnn.ReplicateTensorToMesh(self.mesh_device) + else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/multimodal/llama_layernorm.py b/models/demos/llama3/tt/multimodal/llama_layernorm.py index a20c4764ad1..737b16290af 100644 --- a/models/demos/llama3/tt/multimodal/llama_layernorm.py +++ b/models/demos/llama3/tt/multimodal/llama_layernorm.py @@ -42,7 +42,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name / "weight", - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) self.bias = ttnn.as_tensor( @@ -52,7 +52,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name / "bias", - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) if model_config: diff --git a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py index af80b24b862..fea3542b12a 100644 --- a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper TILE_SIZE = 32 @@ -48,7 +48,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) padded_gated_embeddings, self.ar_mapping = self.generate_padded_gated_embeddings( gated_positional_embedding, 
gated_positional_embedding_gate @@ -59,7 +59,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Add batch and ntok dimensions @@ -72,7 +72,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def generate_padded_gated_embeddings(self, gated_embedding, gate): diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 9ef2aadddac..a97f20d264e 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaTilePositionEmbedding(LightweightModule): @@ -56,7 +56,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if self.gated: @@ -67,7 +67,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def generate_padded_embeddings(self, embedding: torch.Tensor, num_tiles, width): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py index dfe441ee039..fa8d30f6e18 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py @@ -37,7 +37,7 @@ def pad_seq_one_tile(x, mesh_device): device=mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) @@ -239,7 +239,7 @@ def forward(self, images, ar): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) x = ttnn.reshape(x, (1, bsz * num_concurrent_media, -1, dim)) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 7fc9d630102..a1632100a07 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -243,7 +243,7 @@ def compute_vision_tokens_masks( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) padded_masks = _pad_masks( # torch.Size([1, 512, 1, 4]) @@ -314,7 +314,7 @@ def prepare_inputs_prefill( 
dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) tt_xattn_mask = ttnn.typecast(tt_xattn_mask, ttnn.bfloat4_b) @@ -333,7 +333,7 @@ def prepare_inputs_prefill( dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) tt_full_text_mask_expand_1NSH = ttnn.typecast(tt_full_text_mask_expand_1NSH, ttnn.bfloat4_b) @@ -345,7 +345,7 @@ def prepare_inputs_prefill( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), ) if isinstance(cross_page_table, torch.Tensor): @@ -356,7 +356,7 @@ def prepare_inputs_prefill( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: assert cross_attention_masks is None and full_text_row_masked_out_mask is None @@ -385,7 +385,7 @@ def prepare_inputs_prefill( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return ( @@ -503,7 +503,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_position_id = torch.maximum( @@ -535,7 +535,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) full_text_mask = torch.cat(full_text_mask, dim=1).unsqueeze(0) @@ -553,7 +553,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) full_text_mask_expand_11SD = full_text_mask @@ -569,7 +569,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if isinstance(page_table, torch.Tensor): @@ -578,7 +578,7 @@ def prepare_decode_inputs_host( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if isinstance(cross_page_table, torch.Tensor): @@ -587,7 +587,7 @@ def prepare_decode_inputs_host( cross_page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return ( diff --git a/models/demos/qwen/demo/demo.py b/models/demos/qwen/demo/demo.py 
index 3474877333d..a076904b1ed 100644 --- a/models/demos/qwen/demo/demo.py +++ b/models/demos/qwen/demo/demo.py @@ -283,7 +283,7 @@ def run_qwen_demo( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, ) profiler.end(f"prepare_rot_mat_for_prefill", iteration=batch_idx) @@ -371,7 +371,7 @@ def run_qwen_demo( torch.tensor([start_pos]), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(pt_decode_input, current_pos_tensor, rot_mat=current_rot_mat) @@ -389,7 +389,7 @@ def run_qwen_demo( tt_out_tok = ttnn.from_torch( torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, ) profiler.end(f"prepare_first_decode_token_{batch_idx}") @@ -419,7 +419,7 @@ def run_qwen_demo( current_pos = ttnn.from_torch( torch.tensor(decoding_pos, dtype=torch.int32), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.int32, ) @@ -467,12 +467,12 @@ def run_qwen_demo( current_pos_reset = ttnn.from_torch( torch.tensor(decoding_pos, dtype=torch.int32), dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if tt_model.args.num_devices > 1 else None, ) tt_out_tok_reset = ttnn.from_torch( torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if tt_model.args.num_devices > 1 else None, ) ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos) diff --git a/models/demos/qwen/tests/test_lm_head.py b/models/demos/qwen/tests/test_lm_head.py index b62acd9284b..a703fae02ba 100644 --- a/models/demos/qwen/tests/test_lm_head.py +++ b/models/demos/qwen/tests/test_lm_head.py @@ -65,7 +65,7 @@ def test_qwen_lm_head_inference(mesh_device, seq_len, use_program_cache, reset_s tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat8_b, memory_config=model_args.model_config["LM_HEAD_INPUT_MEMCFG"], layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/qwen/tests/test_qwen_attention.py b/models/demos/qwen/tests/test_qwen_attention.py index 18ec68dba7f..c47242ebce7 100644 --- a/models/demos/qwen/tests/test_qwen_attention.py +++ b/models/demos/qwen/tests/test_qwen_attention.py @@ -88,7 +88,7 @@ def test_qwen_attention_inference(mesh_device, use_program_cache, reset_seeds, e torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attention_input = model_args.prepare_inputs_ttnn_decode( diff --git a/models/demos/qwen/tests/test_qwen_decoder.py b/models/demos/qwen/tests/test_qwen_decoder.py index 
ff86c59320c..7095f670d9d 100644 --- a/models/demos/qwen/tests/test_qwen_decoder.py +++ b/models/demos/qwen/tests/test_qwen_decoder.py @@ -89,7 +89,7 @@ def test_qwen_decoder_inference(mesh_device, use_program_cache, reset_seeds, ens torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) decode_input = model_args.prepare_inputs_ttnn_decode( diff --git a/models/demos/qwen/tests/test_qwen_embedding.py b/models/demos/qwen/tests/test_qwen_embedding.py index 1768ba78e37..e41a436a327 100644 --- a/models/demos/qwen/tests/test_qwen_embedding.py +++ b/models/demos/qwen/tests/test_qwen_embedding.py @@ -61,7 +61,7 @@ def test_qwen_embedding(mesh_device, use_program_cache, reset_seeds, ensure_gc): tt_input = ttnn.from_torch( pt_input.squeeze(1), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) diff --git a/models/demos/qwen/tests/test_qwen_mlp.py b/models/demos/qwen/tests/test_qwen_mlp.py index 911e79aa407..1aabd937d01 100644 --- a/models/demos/qwen/tests/test_qwen_mlp.py +++ b/models/demos/qwen/tests/test_qwen_mlp.py @@ -75,7 +75,7 @@ def test_qwen_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat8_b, memory_config=model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"] if mode == "decode" diff --git a/models/demos/qwen/tests/test_qwen_model.py b/models/demos/qwen/tests/test_qwen_model.py index c07b626f571..fa492e842d3 100644 --- a/models/demos/qwen/tests/test_qwen_model.py +++ b/models/demos/qwen/tests/test_qwen_model.py @@ -158,7 +158,7 @@ def test_qwen_model_inference(mesh_device, weights, layers, use_program_cache, r torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Run TT model diff --git a/models/demos/qwen/tests/test_qwen_perf.py b/models/demos/qwen/tests/test_qwen_perf.py index b1bfd92c77e..1c6a09446e7 100644 --- a/models/demos/qwen/tests/test_qwen_perf.py +++ b/models/demos/qwen/tests/test_qwen_perf.py @@ -151,7 +151,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos encoded_prompts_tensor[:, 0].unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0 ), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, ) @@ -167,7 +167,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos current_pos = ttnn.from_torch( torch.tensor([generation_start_pos] * batch), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.int32, ) diff --git a/models/demos/qwen/tests/test_qwen_rms_norm.py b/models/demos/qwen/tests/test_qwen_rms_norm.py index e5e482e7e04..a7c64249f05 100644 --- a/models/demos/qwen/tests/test_qwen_rms_norm.py +++ b/models/demos/qwen/tests/test_qwen_rms_norm.py @@ -74,7 +74,7 @@ def test_qwen_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, en device=mesh_device, dtype=dtype, 
layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), memory_config=ttnn.L1_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/qwen/tt/lm_head.py b/models/demos/qwen/tt/lm_head.py index 84bfb0043a1..9348c137290 100644 --- a/models/demos/qwen/tt/lm_head.py +++ b/models/demos/qwen/tt/lm_head.py @@ -63,7 +63,7 @@ def __init__( ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, diff --git a/models/demos/qwen/tt/model_config.py b/models/demos/qwen/tt/model_config.py index 8b58ce59475..102e537685d 100644 --- a/models/demos/qwen/tt/model_config.py +++ b/models/demos/qwen/tt/model_config.py @@ -503,9 +503,9 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): x: (batch, seq, dim) """ mesh_mapper = ( - ttnn.ReplicateTensorToMesh(self.mesh_device) + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated - else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) + else ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1) ) if len(x.shape) == 3: @@ -559,9 +559,9 @@ def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False): x_1BSH = x_bsh.unsqueeze(0) mesh_mapper = ( - ttnn.ReplicateTensorToMesh(self.mesh_device) + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated - else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) + else ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1) ) # input goes to DRAM diff --git a/models/demos/qwen/tt/qwen_attention.py b/models/demos/qwen/tt/qwen_attention.py index 6ef253cf8a4..d3e56f1f921 100644 --- a/models/demos/qwen/tt/qwen_attention.py +++ b/models/demos/qwen/tt/qwen_attention.py @@ -25,7 +25,7 @@ def fall_back_rope(xq, xk, rot_mats, mesh_device): xq = ttnn.from_torch( xq, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -33,7 +33,7 @@ def fall_back_rope(xq, xk, rot_mats, mesh_device): xk = ttnn.from_torch( xk, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -135,7 +135,7 @@ def __init__( dim=-1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=wqkv_mem_config, layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -152,7 +152,7 @@ def __init__( dim=-1, ).unsqueeze(0), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_BIAS_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_B_LAYOUT_TILE"], @@ -174,7 +174,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), 
cache_file_name=cache_name("wo_width_sharded"), ) self.wo = ttnn.to_device(wo_ttnn, self.mesh_device) @@ -190,7 +190,7 @@ def __init__( -1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=wo_mem_config, dtype=self.dtype, layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -236,7 +236,7 @@ def __init__( ttnn.as_tensor( k_or_v, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), layout=self.model_config["ATTN_W_LAYOUT_TILE"], dtype=self.dtype, cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}" diff --git a/models/demos/qwen/tt/qwen_common.py b/models/demos/qwen/tt/qwen_common.py index b6649cce918..d48307f18e5 100644 --- a/models/demos/qwen/tt/qwen_common.py +++ b/models/demos/qwen/tt/qwen_common.py @@ -115,14 +115,14 @@ def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) sin_gathereds = ttnn.from_torch( sin_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -169,13 +169,13 @@ def get_single_rot_mat( device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ) diff --git a/models/demos/qwen/tt/qwen_embedding.py b/models/demos/qwen/tt/qwen_embedding.py index 9cefdf8af90..ad9a0f10a67 100644 --- a/models/demos/qwen/tt/qwen_embedding.py +++ b/models/demos/qwen/tt/qwen_embedding.py @@ -28,7 +28,7 @@ def __init__( torch_weight, dtype=dtype, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=args.get_model_config()["EMB_WEIGHTS_MEMCFG"], cache_file_name=cache_name, diff --git a/models/demos/qwen/tt/qwen_mlp.py b/models/demos/qwen/tt/qwen_mlp.py index ad500853920..ca9976166e0 100644 --- a/models/demos/qwen/tt/qwen_mlp.py +++ b/models/demos/qwen/tt/qwen_mlp.py @@ -38,7 +38,7 @@ def __init__( torch_weight(name_dict[name[:2]]), # Grab only the wX part of the name dtype=type, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim), layout=ttnn.TILE_LAYOUT, memory_config=w2_mem_config if "w2" in name else w1_w3_mem_config, cache_file_name=cache_name(name), diff --git a/models/demos/t3000/falcon40b/demo/demo.py b/models/demos/t3000/falcon40b/demo/demo.py index 3e53c1b0a8e..d7aa80dea38 100644 --- 
a/models/demos/t3000/falcon40b/demo/demo.py +++ b/models/demos/t3000/falcon40b/demo/demo.py @@ -130,7 +130,7 @@ def initialize_and_fill_kv_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -138,7 +138,7 @@ def initialize_and_fill_kv_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) kv_cache += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py index a3bb1c8c386..55bff77d105 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -90,7 +90,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["ATTN_INPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attention_mask_memconfig = model_config["ATTN_MASK_MEMCFG"] @@ -105,7 +105,7 @@ def run_test_FalconAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), preprocess=lambda x: (x * (-1e5)).expand(1, 1, -1, -1), ) @@ -124,7 +124,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( @@ -133,7 +133,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) @@ -161,7 +161,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["LN_ATTN_OUTPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) @@ -185,7 +185,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, configuration.num_attention_heads, -1, -1), ) @@ -200,7 +200,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + 
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -208,7 +208,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py index a853b74a827..f79ce843b72 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -101,7 +101,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_kv_cache_host, @@ -109,7 +109,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -141,7 +141,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -149,7 +149,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py index ef66249a132..18373c8a191 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -93,7 +93,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) attention_mask_memconfig = model_config["ATTN_MASK_MEMCFG"] @@ -108,7 +108,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x * -1e5).expand(-1, mesh_device.get_num_devices(), -1, 
-1), ) @@ -121,7 +121,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -129,7 +129,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) @@ -167,7 +167,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) @@ -192,7 +192,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, configuration.num_attention_heads, -1, -1), ) @@ -207,7 +207,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( @@ -216,7 +216,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py b/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py index 1dd2eacd664..877a3143170 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py @@ -85,7 +85,7 @@ def run_test_FalconMLP_inference( device=mesh_device, memory_config=model_config["LN_MLP_OUTPUT_MEMCFG"], layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_FalconMLP_model(tt_mlp_input, llm_mode) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_model.py b/models/demos/t3000/falcon40b/tests/test_falcon_model.py index 3696d037bbb..50048cf0fdc 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_model.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_model.py @@ -6,7 +6,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -95,7 +95,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_kv_cache_host, @@ -103,7 +103,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, 
device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -136,7 +136,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -144,7 +144,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py index 2f023c7eb04..82c070d701e 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import FalconForCausalLM, FalconConfig from models.demos.t3000.falcon40b.tt.falcon_causallm import TtFalconCausalLM from models.demos.t3000.falcon40b.tt.model_config import get_model_config, model_config_entries @@ -68,7 +68,7 @@ def run_test_falcon_prefill_end_to_end_determinism( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -76,7 +76,7 @@ def run_test_falcon_prefill_end_to_end_determinism( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index b3adfd184c9..f6afd108e5e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout @@ -46,7 +46,7 @@ def generate_cos_sin_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["COS_CACHED_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=cos_cached_path, ) @@ -58,7 +58,7 @@ def generate_cos_sin_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["SIN_CACHED_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=sin_cached_path, ) @@ -165,7 +165,7 @@ def __init__( 
layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["FUSED_QKV_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), cache_file_name=query_key_value_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -178,7 +178,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["SELFOUT_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), cache_file_name=selfout_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -219,7 +219,7 @@ def initialize_kvcache(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DRAM_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=kv_cache_path, ) @@ -229,7 +229,7 @@ def initialize_kvcache(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DRAM_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=kv_cache_path, ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_causallm.py b/models/demos/t3000/falcon40b/tt/falcon_causallm.py index 9f971d2e988..ffb8b1a1e60 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_causallm.py +++ b/models/demos/t3000/falcon40b/tt/falcon_causallm.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_model import TtFalconModelShared from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation @@ -50,7 +50,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=self.model_config["LM_HEAD_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), cache_file_name=lm_head_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_decoder.py b/models/demos/t3000/falcon40b/tt/falcon_decoder.py index d78b69c4aed..a17262a9043 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tt/falcon_decoder.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_attention import TtFalconAttention from models.demos.t3000.falcon40b.tt.falcon_mlp import TtFalconMLP @@ -80,7 +80,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_MLP_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_mlp_weights_path, preprocess=pad_ln_params, ) @@ -93,7 +93,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_MLP_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_mlp_bias_path, preprocess=pad_ln_params, ) @@ -111,7 +111,7 @@ def pad_ln_params(x): 
layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_ATTN_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_attn_weights_path, preprocess=pad_ln_params, ) @@ -124,7 +124,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_ATTN_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_attn_bias_path, preprocess=pad_ln_params, ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_embeddings.py b/models/demos/t3000/falcon40b/tt/falcon_embeddings.py index 8135a41b1a5..d0a1e88e0c7 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_embeddings.py +++ b/models/demos/t3000/falcon40b/tt/falcon_embeddings.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper class TtFalconEmbeddings(torch.nn.Module): @@ -25,7 +25,7 @@ def __init__(self, mesh_device, state_dict, cache_path, model_config): device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, cache_file_name=cache_path / base_name, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), preprocess=lambda x: x.reshape(1, 1, *x.shape), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py index ba75a8b4a95..b1c90745bfe 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py @@ -8,7 +8,7 @@ from typing import List from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper class TtFalconMLP: @@ -43,7 +43,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DENSE_H_TO_4H_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=tt_cache_path / dense_h_to_4h_str, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -54,7 +54,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DENSE_4H_TO_H_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=tt_cache_path / f"{dense_4h_to_h_str}_height_fractured", preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -84,7 +84,7 @@ def _allocate_output_mlp_tensors(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DEFAULT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def __call__(self, x: List[ttnn.Tensor], llm_mode: str) -> List[ttnn.Tensor]: diff --git a/models/demos/t3000/falcon40b/tt/falcon_model.py b/models/demos/t3000/falcon40b/tt/falcon_model.py index 1c2f7b12574..ac601a2650e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_model.py +++ b/models/demos/t3000/falcon40b/tt/falcon_model.py @@ -9,7 +9,7 @@ import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import 
replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.t3000.falcon40b.tt.falcon_embeddings import TtFalconEmbeddings from models.demos.t3000.falcon40b.tt.falcon_attention import generate_cos_sin_cache @@ -107,7 +107,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=self.model_config["LN_F_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_weights_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -118,7 +118,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=self.model_config["LN_F_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_bias_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -138,7 +138,7 @@ def create_attn_mask(self, max_seq_len): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x * -1e5), ) @@ -181,7 +181,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Generate input and attention_mask --------------------------------------------- @@ -230,7 +230,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DEFAULT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(1, 1, -1, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/model_utils.py b/models/demos/t3000/falcon40b/tt/model_utils.py index 25ba146554f..e3635da4699 100644 --- a/models/demos/t3000/falcon40b/tt/model_utils.py +++ b/models/demos/t3000/falcon40b/tt/model_utils.py @@ -433,7 +433,7 @@ def generate_layernorm_persistent_tensors(seq_len, slice_size, ln_output_tensors layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) if name in ln_output_tensors_dict and ln_output_tensors_dict[name] is not None: ln_output_tensors_dict[name].update({seq_len: output_tensor}) diff --git a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py index 02a6684d838..c08837dfd65 100644 --- a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py +++ b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py @@ -13,7 +13,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from transformers.generation.utils import top_k_top_p_filtering @@ -243,7 +243,7 @@ def run_decode( static_page_table, 
dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(model.mesh_device), ) page_table_tt = ttnn.to_device(page_table_tt, model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) diff --git a/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py b/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py index 22ba67ece5d..48dece25332 100644 --- a/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_generation import ( diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py index 72bd9b7091f..e351736c14b 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_attention_optimized import TtLlamaAttention_optimized @@ -130,7 +130,7 @@ def tt_llama_attention_prepare_inputs( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), device=llama_attention_model.mesh_device, ) xs = ttnn.to_device(xs, llama_attention_model.mesh_device) @@ -149,7 +149,7 @@ def tt_llama_attention_prepare_inputs( cache_file_name=cache_name(f"cos_gathered_prefill_{start_pos}_to_{start_pos + seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -158,7 +158,7 @@ def tt_llama_attention_prepare_inputs( cache_file_name=cache_name(f"sin_gathered_prefill_{start_pos}_to_{start_pos + seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, llama_attention_model.mesh_device) @@ -181,7 +181,7 @@ def tt_llama_attention_prepare_inputs( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), device=llama_attention_model.mesh_device, ) xs = ttnn.to_device(xs, llama_attention_model.mesh_device) @@ -194,7 +194,7 @@ def tt_llama_attention_prepare_inputs( layout=ttnn.ROW_MAJOR_LAYOUT, device=llama_attention_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) rot_mats = 
rope_setup.get_rot_mats(cache_idxs) @@ -263,7 +263,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) transformation_mats = ttnn.to_device(transformation_mats, t3k_mesh_device) transformation_mats = {"prefill": transformation_mats} @@ -284,7 +284,7 @@ def run_test_LlamaAttention_inference( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) page_table_tt = ttnn.to_device(page_table_tt, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) @@ -352,7 +352,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill prefill_page_table = page_table[0:1, :] @@ -362,7 +362,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) chunk_tt_input = tt_input[:, chunk_start:chunk_end] diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py index f57969b7c7f..f1e738f011a 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_decoder_optimized import TtLlamaDecoder_optimized @@ -119,7 +119,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(llama_decoder_model.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(llama_decoder_model.mesh_device, dim=3), device=llama_decoder_model.mesh_device, ) xs = ttnn.to_device(xs, llama_decoder_model.mesh_device) @@ -141,7 +141,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -150,7 +150,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, llama_decoder_model.mesh_device) sin_gathereds = ttnn.to_device(sin_gathereds, 
llama_decoder_model.mesh_device) @@ -171,7 +171,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(llama_decoder_model.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(llama_decoder_model.mesh_device, dim=3), device=llama_decoder_model.mesh_device, ) xs = ttnn.to_device(xs, llama_decoder_model.mesh_device) @@ -184,7 +184,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop layout=ttnn.ROW_MAJOR_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) rot_mats = rope_setup.get_rot_mats(cache_idxs) @@ -248,7 +248,7 @@ def run_test_LlamaDecoder_inference( layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) transformation_mats = ttnn.to_device(transformation_mats, t3k_mesh_device) transformation_mats = {"prefill": transformation_mats} diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index babfe3b3657..d6ca6ab4bfe 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -6,7 +6,7 @@ import torch from torch import nn import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import scipy diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py b/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py index fcb0956fb4a..cc9fa417b24 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_mlp_optimized import TtLlamaMLP_optimized @@ -42,7 +42,7 @@ def tt_llama_mlp_prepare_inputs(llama_mlp_model, x, mode): layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=llama_mlp_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_mlp_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_mlp_model.mesh_device), ) if mode == "decode": diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_model.py b/models/demos/t3000/llama2_70b/tests/test_llama_model.py index ef41fbe6d89..100dce9c12e 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_model.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_model.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper import os import scipy diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py index f43779eafdf..dd11900e29d 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py +++ 
b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py @@ -6,7 +6,7 @@ import math import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.model_utils import matmul_2d_config_from_tensor_shapes @@ -110,7 +110,7 @@ def init_kv_cache(self): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.kv_dtype, @@ -179,7 +179,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / wqkv_cache_str, ) self.qkv = ttnn.to_device(qkv_ttnn, self.mesh_device) @@ -190,7 +190,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / wo_str, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 63c8aad8233..14474e541ba 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -28,42 +28,10 @@ UNIT_TEST_START_POS = 0 UNIT_TEST_GENERATION_LENGTH = 20 from ttnn import ( - TensorToMesh, MeshToTensor, ) -class ShardTensor2dMesh(TensorToMesh): - def __init__(self, mesh_device, dims, cluster_shape): - super().__init__(mesh_device) - self.dims = dims - self.cluster_shape = cluster_shape - - def map(self, tensor: torch.tensor): - # Returns list of tensors to map to row-major ordering of chips in cluster - tensors_grid_y = None - if self.dims[1] == None: - tensors_grid_y = [tensor.clone() for _ in range(self.cluster_shape[1])] - else: - tensors_grid_y = torch.chunk(tensor, self.cluster_shape[1], dim=self.dims[1]) - - tensors_grid_all = None - if self.dims[0] == None: - tensors_grid_all = [t.clone() for t in tensors_grid_y for _ in range(self.cluster_shape[0])] - else: - tensors_grid_all = [ - tt for t in tensors_grid_y for tt in torch.chunk(t, self.cluster_shape[0], dim=self.dims[0]) - ] - - return list(tensors_grid_all) - - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.dims[0] if self.dims[0] else self.dims[1]}", - } - - class ConcatMesh2DToTensor(MeshToTensor): def __init__(self, mesh_device, dims, cluster_shape): self.dims = dims diff --git a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py index a1bd6b1565e..3e38d525dac 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py @@ -6,7 +6,7 @@ from typing import List import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.tt.llama_attention_optimized import TtLlamaAttention_optimized from models.demos.t3000.llama2_70b.tt.llama_mlp_optimized import TtLlamaMLP_optimized @@ -106,7 +106,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, 
memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / attn_norm_str, ) self.attn_norm = ttnn.to_device(attn_norm_ttnn, self.mesh_device) @@ -117,7 +117,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / attn_norm_sharded_str, ) self.attn_norm_sharded = ttnn.to_device(attn_norm_sharded_ttnn, self.mesh_device) @@ -128,7 +128,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / ffn_norm_str, ) self.ffn_norm = ttnn.to_device(ffn_norm_ttnn, self.mesh_device) @@ -139,7 +139,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) self.ffn_norm_sharded = ttnn.to_device(ffn_norm_sharded_ttnn, self.mesh_device) diff --git a/models/demos/t3000/llama2_70b/tt/llama_embedding.py b/models/demos/t3000/llama2_70b/tt/llama_embedding.py index 177cfa7e293..8211edb21c0 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_embedding.py +++ b/models/demos/t3000/llama2_70b/tt/llama_embedding.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper class TtLlamaEmbedding: @@ -44,7 +44,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), cache_file_name=cache_path / base_name, ) self.emb_weights = ttnn.to_device(embd_weights_ttn, mesh_device) diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py index 0aee8f7bf77..a1087e9d645 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_generation.py +++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py @@ -5,7 +5,7 @@ import math import torch import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from loguru import logger diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index aa0d5ae2a24..4356e5dfdaa 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -6,7 +6,7 @@ from typing import List import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.utility_functions import nearest_32 from models.demos.t3000.falcon40b.tt.model_utils import matmul_2d_config @@ -89,7 +89,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w3_mem_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), 
cache_file_name=self.cache_path / w1_dram_shard_str, ) @@ -105,7 +105,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w2_memory_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / w2_dram_shard_str, ) @@ -115,7 +115,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w3_mem_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / w3_dram_shard_str, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py index 32bce8227ec..29b59aa5490 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32, profiler @@ -66,7 +66,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats_prefill = ttnn.to_device(transformation_mats_prefill, mesh_device) @@ -139,7 +139,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / lm_head_str, ) self.lm_head = ttnn.to_device(padded_lm_head_ttnn, self.mesh_device) @@ -150,7 +150,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / norm_str, ) self.norm = ttnn.to_device(norm_ttnn, self.mesh_device) @@ -161,7 +161,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / norm_sharded_str, ) self.norm_sharded = ttnn.to_device(norm_sharded_ttnn, self.mesh_device) @@ -210,7 +210,7 @@ def prepare_inputs( inp_ids, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if mode == "prefill": @@ -235,7 +235,7 @@ def prepare_inputs( cache_file_name=cache_name(f"cos_gathered_prefill_{start_pos}_{start_pos+seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -244,7 +244,7 @@ def prepare_inputs( cache_file_name=cache_name(f"sin_gathered_prefill_{start_pos}_{start_pos+seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + 
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, self.mesh_device) sin_gathereds = ttnn.to_device(sin_gathereds, self.mesh_device) @@ -261,7 +261,7 @@ def prepare_inputs( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if chunk_page_table is not None: chunk_page_table = ttnn.as_tensor( @@ -270,7 +270,7 @@ def prepare_inputs( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return (xs, start_pos, rot_mats, rot_idxs_tt, cache_idxs_tt, page_table, chunk_page_table) @@ -288,7 +288,7 @@ def prepare_inputs( cache_idxs, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_mats = None # Created in prepare_device_inputs @@ -303,7 +303,7 @@ def prepare_inputs( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return (xs, start_pos, rot_mats, rot_idxs_tt, cache_idxs_tt, page_table) diff --git a/models/demos/t3000/llama2_70b/tt/llama_rope.py b/models/demos/t3000/llama2_70b/tt/llama_rope.py index e7f4baeb4fd..02d16ea0271 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_rope.py +++ b/models/demos/t3000/llama2_70b/tt/llama_rope.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.lightweightmodule import LightweightModule from models.demos.t3000.llama2_70b.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin from loguru import logger @@ -48,14 +48,14 @@ def __init__( device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) self.sin_matrix = ttnn.from_torch( sin_matrix, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) # Generate the transformation matrix @@ -74,7 +74,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=trans_mat_mem_config, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) def get_trans_mats(self): @@ -93,7 +93,7 @@ def get_rot_idxs(self, position_idxs): position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) return rot_idxs diff --git a/models/demos/t3000/mixtral8x7b/demo/demo.py b/models/demos/t3000/mixtral8x7b/demo/demo.py index be02adcf491..afcf5b15d92 100644 --- a/models/demos/t3000/mixtral8x7b/demo/demo.py +++ b/models/demos/t3000/mixtral8x7b/demo/demo.py @@ -9,7 +9,7 @@ from time import time import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import
replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( load_inputs, preprocess_inputs, diff --git a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py index 408b223e3cf..2de3717c87b 100644 --- a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py +++ b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py @@ -9,7 +9,7 @@ from time import time import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( load_inputs, preprocess_inputs_prefill, @@ -178,7 +178,7 @@ def run_mixtral_demo(user_input, batch_size, mesh_device, instruct_mode, test_pr layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) profiler.end("prepare_rot_mat_for_prefill") diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py index 957be57c7de..a2c04f6c58f 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_attention import TtMixtralAttention from models.demos.t3000.mixtral8x7b.tt.mixtral_common import prepare_inputs_ttnn, get_single_rot_mat from models.demos.t3000.mixtral8x7b.reference.model import Attention, precompute_freqs_cis diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py index d4e50a5f5cb..2f78e8254f5 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_attention import TtMixtralAttention from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn_prefill, @@ -59,7 +59,7 @@ def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_s layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Load ttnn model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py index 36b035b536e..416e81d59a1 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py @@ -12,7 +12,7 @@ from models.demos.t3000.mixtral8x7b.reference.model import TransformerBlock, precompute_freqs_cis from models.demos.t3000.mixtral8x7b.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor
@pytest.mark.parametrize( diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py index dc4b84ba4ef..8ff473a67d5 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py @@ -17,7 +17,7 @@ from models.demos.t3000.mixtral8x7b.reference.model import TransformerBlock, precompute_freqs_cis, RMSNorm from models.demos.t3000.mixtral8x7b.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @@ -57,7 +57,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Initialize TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py index 932a60af16f..79f96b157ca 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py @@ -6,7 +6,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.reference.model import FeedForward, RMSNorm @@ -62,7 +62,7 @@ def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py index 7e952a57d98..55412b4f9c2 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py @@ -7,7 +7,7 @@ import pytest import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.reference.model import FeedForward, RMSNorm @@ -73,7 +73,7 @@ def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds, dtype=ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_input = ttnn.to_device(tt_input, t3k_mesh_device) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py index afb36a0a7f6..f0fe2b92fd4 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py @@ -9,7 +9,7 @@ from sklearn.metrics import top_k_accuracy_score import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper,
ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import prepare_inputs_ttnn from models.demos.t3000.mixtral8x7b.tt.mixtral_model import TtTransformer diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py index 876392eecc8..672a2995152 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn_prefill, @@ -89,7 +89,7 @@ def test_mixtral_model_inference_CI(t3k_mesh_device, use_program_cache, reset_se layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Load TTNN model tt_model = TtTransformer( diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py index 10a1e2e0bc9..60a683f534c 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.tt.mixtral_moe import TtMoeLayer @@ -84,7 +84,7 @@ def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py index 5e8df333fd7..a7d788ac8de 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.tt.mixtral_moe import TtMoeLayer @@ -91,7 +91,7 @@ def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py index 1fa29fac602..81c8b6558c7 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py @@ -6,7 +6,7 @@ import pytest import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.demos.t3000.mixtral8x7b.tt.mixtral_common import (
preprocess_inputs_prefill, @@ -327,7 +327,7 @@ def run_inference_prefill(tt_model, model_args, prefill_seqlen, mesh_device, pt_ layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) profiler.end("prefill_prepare_rot_matrices") diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py index 1418f44c19d..459e1a4ffad 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py @@ -12,7 +12,7 @@ import numpy as np import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn, get_single_rot_mat, diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py index df46b58b1d0..601fb7e41af 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.rmsnorm import RMSNorm as TtRMSNorm from models.demos.t3000.mixtral8x7b.reference.model import RMSNorm as RefRMSNorm @@ -50,7 +50,7 @@ def test_mixtral_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_se device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input, mode="decode") diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 1b27ad4a3f3..c768c940384 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.lightweightmodule import LightweightModule @@ -74,7 +74,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -90,7 +90,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -118,7 +118,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), +
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), dtype=ttnn.bfloat8_b, layout=self.model_config["ATTN_W_LAYOUT_TILE"], memory_config=self.model_config["ATTN_CACHE_WEIGHTS_MEMCFG"], @@ -135,7 +135,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): self.reduce_mask = ttnn.from_torch( reduce_mask_torch, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, ) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py index 8a061d72b2e..dfdd8179bbf 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 import json import math @@ -55,7 +55,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -65,7 +65,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -183,7 +183,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) return xs_1SBH @@ -224,7 +224,7 @@ def cache_attention(mesh_device, state_dict, model_args, current_rot_mat, rot_ma layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_attn = TtMixtralAttention( @@ -295,13 +295,13 @@ def get_single_rot_mat(dhead, mesh_device, start_pos=0, theta: float = 1000000.0 device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) @@ -330,13 +330,13 @@ def get_single_rot_mat_multi_pos(dhead, mesh_device, start_pos_ids, theta: float device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0).repeat(1, len(start_pos_ids), 1, 1), # 1,1,head_dim,head_dim device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) @@ -376,14 +376,14 @@ def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len): cos_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), +
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) sin_gathereds = ttnn.from_torch( sin_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) @@ -421,7 +421,7 @@ def prepare_inputs_ttnn_prefill(x_bsh, mesh_device, num_tokens=None): dtype=attn_mask_dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # input goes to L1 @@ -431,7 +431,7 @@ def prepare_inputs_ttnn_prefill(x_bsh, mesh_device, num_tokens=None): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) return xs_1BSH, attn_mask, attn_mask_torch diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py index c1272b8e62a..496ebd7346c 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.common.lightweightmodule import LightweightModule @@ -36,7 +36,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtypes): torch_weight(name), dtype=dtypes[name], device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), layout=self.model_config["MLP_W_LAYOUT_TILE"], memory_config=self.model_config["MLP_WEIGHTS_MEMCFG"], cache_file_name=cache_name(name), diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py index 093b3c8f7b4..b9288d09fb7 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py @@ -7,7 +7,7 @@ from models.common.rmsnorm import RMSNorm from models.common.lightweightmodule import LightweightModule from models.demos.t3000.mixtral8x7b.tt.mixtral_common import get_single_rot_mat_multi_pos, get_single_rot_mat_torch -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper import torch @@ -55,7 +55,7 @@ def __init__(self, mesh_device, state_dict, args, dtype, layers, start_pos_ids, dtype=dtype, memory_config=self.model_config["OUTPUT_WEIGHTS_MEMCFG"], cache_file_name=output_cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.compute_kernel = self.args.get_compute_kernel_config() @@ -86,7 +86,7 @@ def forward( device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: rot_mats = self.current_rot_mat diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index c71feb93bf5..8aaeeffee58 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.common.lightweightmodule import LightweightModule @@
-44,7 +44,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): memory_config=self.model_config["GATE_WEIGHTS_MEMCFG"], cache_file_name=cache_name, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) self.tile_size = 32 @@ -58,7 +58,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.top8_mask_11B_64 = ttnn.sum(self.top8_mask_11B_64, dim=2) @@ -69,7 +69,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.top2_mask_11BB = ttnn.sum(self.top2_mask_11BB, dim=2) @@ -81,7 +81,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, inputs, mode="decode"): diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index a2ea1b7c792..e455197f1cb 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -1,14 +1,17 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + # SPDX-License-Identifier: Apache-2.0 + import pytest from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper import gc + from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env @@ -33,9 +36,9 @@ check_kv_cache, num_to_corerange, ConcatMesh2DToTensor, - ShardTensor2dMesh, ) + from models.utility_functions import skip_for_grayskull @@ -91,6 +94,7 @@ def forward(self, x, start_pos, freqs_cis, mask): freqs_cis: ? mask: ?
+ return: (batch, seq, hidden_dim) """ result = self.attention( @@ -126,8 +130,8 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ layout=ttnn.TILE_LAYOUT, memory_config=ACT_MEMCFG, device=llama_attention_model.mesh_device, - mesh_mapper=ShardTensor2dMesh( - llama_attention_model.mesh_device, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) ), ) @@ -161,7 +165,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ layout=ttnn.TILE_LAYOUT, memory_config=ROT_MAT_MEMCFG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) attn_masks = None @@ -179,8 +183,8 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ShardTensor2dMesh( - llama_attention_model.mesh_device, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) ), ) @@ -198,7 +202,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -207,7 +211,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -220,7 +224,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_ dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, ) @@ -276,7 +280,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ) tt_LlamaAttention_model = TtLlamaAttention_galaxy( diff --git a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py index 1c48eb04d89..29dddca1698 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, 
shard_tensor_to_2d_mesh_mapper from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy @@ -33,7 +33,6 @@ check_kv_cache, num_to_corerange, ConcatMesh2DToTensor, - ShardTensor2dMesh, ) import gc @@ -129,8 +128,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ACT_MEMCFG, - mesh_mapper=ShardTensor2dMesh( - llama_decoder_model.mesh_device, dims=(3, None), cluster_shape=llama_decoder_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_decoder_model.mesh_device, mesh_shape=llama_decoder_model.cluster_shape, dims=(None, 3) ), ) @@ -159,7 +158,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, memory_config=ROT_MAT_MEMCFG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) attn_masks = None @@ -173,8 +172,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh( - llama_decoder_model.mesh_device, dims=(3, None), cluster_shape=llama_decoder_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_decoder_model.mesh_device, mesh_shape=llama_decoder_model.cluster_shape, dims=(None, 3) ), ) @@ -196,7 +195,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -205,7 +204,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -218,7 +217,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, ) @@ -273,7 +272,7 @@ def run_test_LlamaDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ) tt_LlamaDecoder_model = TtLlamaDecoder_galaxy( diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index 62abd01a8ec..6346cd604eb 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -6,9 +6,8 @@ import math import torch
import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper, replicate_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.tt.llama_common import ( - ShardTensor2dMesh, ConcatMesh2DToTensor, ) from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -91,7 +90,7 @@ def get_slice_mat(self): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) def get_user_selection_mat(self): @@ -104,7 +103,7 @@ def get_user_selection_mat(self): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) def init_kv_cache(self): @@ -133,7 +132,7 @@ def init_kv_cache(self): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b, @@ -206,7 +205,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / wqkv_cache_str, ) @@ -216,7 +215,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, 2), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(2, 3)), cache_file_name=self.cache_path / wo_cache_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py index 5c6e1c64ef2..2404016e361 100644 --- a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py @@ -7,13 +7,11 @@ from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy -from models.demos.t3000.llama2_70b.tt.llama_common import ( - ShardTensor2dMesh, -) from models.demos.tg.llama3_70b.tt.llama_common import ( tt_sharded_distributed_rmsnorm, tt_distributed_rmsnorm, ) +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaDecoder_galaxy: @@ -112,7 +110,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / attn_norm_sharded_str, ) @@ -122,7 +120,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py
b/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py index d76abe350f2..46d49eee42f 100644 --- a/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import ttnn -from models.demos.t3000.llama2_70b.tt.llama_common import ShardTensor2dMesh +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaEmbedding_galaxy: @@ -28,7 +28,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, None), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(None, 3)), cache_file_name=cache_path / embedding_cache_name, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index c876713ce9f..068ad25c44d 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -4,12 +4,13 @@ from typing import List import ttnn -from models.demos.t3000.llama2_70b.tt.llama_common import ShardTensor2dMesh, ConcatMesh2DToTensor +from models.demos.t3000.llama2_70b.tt.llama_common import ConcatMesh2DToTensor from models.utility_functions import nearest_32 from models.demos.tg.llama3_70b.tt.llama_common import tt_all_reduce, tt_composite_sharded_all_reduce from models.demos.t3000.falcon40b.tt.model_utils import ( matmul_2d_config_from_tensor_shapes as get_matmul_2d_config_from_tensor_shapes, ) +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaMLP_galaxy: @@ -79,7 +80,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.w1_mem_config, # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / w1_cache_str, ) @@ -90,7 +91,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.mlp_config["W1_MEM_CONFIG"](self.mesh_device, self.cluster_shape), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / w3_cache_str, ) @@ -101,7 +102,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.mlp_config["W2_MEM_CONFIG"](self.mesh_device), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, 2), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(2, 3)), cache_file_name=self.cache_path / w2_cache_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py index b309872779b..d5abeb42724 100644 --- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import 
replicate_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy from models.demos.tg.llama3_70b.tt.llama_embedding_galaxy import TtLlamaEmbedding_galaxy from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -17,7 +17,6 @@ get_rot_transformation_mat, num_to_corerange, gather_cos_sin, - ShardTensor2dMesh, ) from models.demos.tg.llama3_70b.tt.llama_common import ( tt_all_reduce, @@ -25,6 +24,7 @@ tt_sharded_distributed_rmsnorm, tt_distributed_rmsnorm, ) +from ttnn import shard_tensor_to_2d_mesh_mapper def is_power_of_two(n): @@ -74,7 +74,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info("Creating Layers") @@ -142,7 +142,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / lm_head_cache_str, ) @@ -152,7 +152,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(None, 2)), cache_file_name=self.cache_path / norm_sharded_cache_str, ) @@ -173,7 +173,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) xs = self.tt_embd(x) @@ -226,7 +226,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, device=self.mesh_device, cache_file_name=cache_name(f"rot_mat_decode_galaxy_{start_pos}"), memory_config=ROT_MAT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) attn_masks = None @@ -247,7 +247,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, # cache_file_name=cache_name(f"cos_gathered_prefill_galaxy_{start_pos}"), device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -256,7 +256,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, # cache_file_name=cache_name(f"sin_gathered_prefill_galaxy_{start_pos}"), device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -269,7 +269,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, 
device=self.mesh_device, ) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index 98322a8f0c6..f55c4d3e784 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -21,7 +21,7 @@ import transformers from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -104,7 +104,7 @@ def test_falcon_attention( seq_len, configuration.hidden_size, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) position_ids = create_position_ids(llm_mode, kv_cache_len) attention_mask, tt_attention_mask = create_attention_mask( @@ -116,7 +116,7 @@ def test_falcon_attention( configuration.num_attention_heads, kv_cache_len, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) layer_past, tt_layer_past = create_kv_cache( llm_mode, @@ -125,7 +125,7 @@ def test_falcon_attention( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) pytorch_out, pytorch_layer_present = torch_model( @@ -144,7 +144,7 @@ def test_falcon_attention( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ) tt_FalconAttention_model = TtFalconAttention( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py index 1de4f9a058c..267259df1ab 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py @@ -22,7 +22,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -111,7 +111,7 @@ def test_falcon_causal_lm( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -127,7 +127,7 @@ def test_falcon_causal_lm( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) @@ -153,7 +153,7 @@ def convert_to_ttnn(model, name): model_config, tt_cache_path=get_tt_cache_path(f"{model_version}"), device=mesh_device, - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) @@ -327,7 +327,7 @@ def test_t3k_falcon_causal_lm_with_trace( kv_cache_len, configuration, t3k_mesh_device, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0), +
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -343,7 +343,7 @@ def test_t3k_falcon_causal_lm_with_trace( kv_cache_len, configuration, t3k_mesh_device, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) @@ -369,7 +369,7 @@ def convert_to_ttnn(model, name): model_config, tt_cache_path=get_tt_cache_path(f"{model_version}"), device=t3k_mesh_device, - weights_mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) @@ -393,7 +393,7 @@ def convert_to_ttnn(model, name): torch.full(scalar_shape, layer.self_attn.scalar), device=t3k_mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # TODO: Generate embeddings and attention_mask on device tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py index 40676143a5b..bb247fd88a4 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py @@ -21,7 +21,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -102,7 +102,7 @@ def test_falcon_decoder( seq_len, configuration.hidden_size, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) position_ids = create_position_ids(llm_mode, kv_cache_len) attention_mask, tt_attention_mask = create_attention_mask( @@ -114,7 +114,7 @@ def test_falcon_decoder( configuration.num_attention_heads, kv_cache_len, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) layer_past, tt_layer_past = create_kv_cache( llm_mode, @@ -123,7 +123,7 @@ def test_falcon_decoder( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) pytorch_out, pytorch_layer_present = torch_model( @@ -142,7 +142,7 @@ def test_falcon_decoder( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ) tt_FalconDecoder_model = TtFalconDecoderLayer( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index c118f9a9b15..af8e7a6beda 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -11,7 +11,7 @@ from models.demos.ttnn_falcon7b.tt.common import create_custom_preprocessor, strip_state_dict_prefix from ttnn.model_preprocessing import preprocess_model_parameters from tests.ttnn.utils_for_testing import
assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import transformers from loguru import logger @@ -88,7 +88,7 @@ def test_falcon_mlp( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ), ) @@ -98,7 +98,7 @@ def test_falcon_mlp( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) ttnn_output = ttnn_model(ttnn_input) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py index 31c4d04816a..ebdf747aa5d 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py @@ -23,7 +23,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -114,7 +114,7 @@ def test_falcon_model( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -130,7 +130,7 @@ def test_falcon_model( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) @@ -157,7 +157,7 @@ def convert_to_ttnn(model, name): tt_cache_path=get_tt_cache_path(f"{model_version}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) diff --git a/models/demos/ttnn_falcon7b/tt/falcon_model.py b/models/demos/ttnn_falcon7b/tt/falcon_model.py index f27f1122947..08ad282b6fe 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_model.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_model.py @@ -10,7 +10,7 @@ from models.demos.ttnn_falcon7b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.ttnn_falcon7b.tt.common import create_attention_mask -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor class TtFalconModelShared: @@ -58,7 +58,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token mesh_mapper = None else: shard_dim = 2 if llm_mode == "decode" else 0 - mesh_mapper = ShardTensorToMesh(self.device, dim=shard_dim) + mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(self.device, dim=shard_dim) # Generate input and attention_mask --------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py index 2866840ad8d..842a3916666 100644 --- a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py +++
b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py @@ -260,8 +260,8 @@ def __init__( def get_mesh_mappers(self, device): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: - inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) - weights_mesh_mapper = None # ttnn.ReplicateTensorToMesh(device) causes unnecessary replication/takes more time on the first pass + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(device, dim=0) + weights_mesh_mapper = None # ttnn.replicate_tensor_to_mesh_mapper(device) causes unnecessary replication/takes more time on the first pass output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) else: inputs_mesh_mapper = None diff --git a/models/demos/wormhole/bert_tiny/demo/demo.py b/models/demos/wormhole/bert_tiny/demo/demo.py index fe403df2338..0d4834d3cd7 100644 --- a/models/demos/wormhole/bert_tiny/demo/demo.py +++ b/models/demos/wormhole/bert_tiny/demo/demo.py @@ -70,10 +70,10 @@ def run_bert_question_and_answering_inference( profiler.start(f"preprocessing_parameter") mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, @@ -189,9 +189,9 @@ def run_bert_question_and_answering_inference_squad_v2( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, diff --git a/models/demos/wormhole/bert_tiny/tests/test_performance.py b/models/demos/wormhole/bert_tiny/tests/test_performance.py index bcc438ea198..92aa50ebc60 100644 --- a/models/demos/wormhole/bert_tiny/tests/test_performance.py +++ b/models/demos/wormhole/bert_tiny/tests/test_performance.py @@ -52,9 +52,9 @@ def test_perf_bert_tiny( torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) torch_attention_mask = torch.zeros(1, sequence_size) - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, @@ -73,7 +73,7 @@ def test_perf_bert_tiny( ttnn_attention_mask = ttnn.from_torch( torch_attention_mask, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) durations = [] diff --git a/models/demos/wormhole/distilbert/demo/demo.py b/models/demos/wormhole/distilbert/demo/demo.py index dfd89c18939..71b72c9e295 100644 --- 
a/models/demos/wormhole/distilbert/demo/demo.py +++ b/models/demos/wormhole/distilbert/demo/demo.py @@ -50,13 +50,13 @@ def run_distilbert_question_and_answering_inference( HF_model.eval() tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) profiler.start(f"preprocessing_parameter") - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, @@ -191,11 +191,11 @@ def run_distilbert_question_and_answering_inference_squad_v2( tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py index a3fad4aa54c..855562661a7 100644 --- a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py +++ b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py @@ -67,11 +67,11 @@ def test_performance_distilbert_for_qa( ) tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) profiler.start(f"preprocessing_parameter") - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/models/experimental/functional_unet/tests/test_unet_bottleneck.py b/models/experimental/functional_unet/tests/test_unet_bottleneck.py index c78de65acaf..80983c87af5 100644 --- a/models/experimental/functional_unet/tests/test_unet_bottleneck.py +++ b/models/experimental/functional_unet/tests/test_unet_bottleneck.py @@ -52,8 +52,8 @@ def test_unet_bottleneck_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git 
a/models/experimental/functional_unet/tests/test_unet_downblock.py b/models/experimental/functional_unet/tests/test_unet_downblock.py index 1ea2633b2ad..6e8c4017dc1 100644 --- a/models/experimental/functional_unet/tests/test_unet_downblock.py +++ b/models/experimental/functional_unet/tests/test_unet_downblock.py @@ -81,8 +81,8 @@ def test_unet_downblock_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_multi_device.py b/models/experimental/functional_unet/tests/test_unet_multi_device.py index f611ce3999c..f886a24c8ff 100644 --- a/models/experimental/functional_unet/tests/test_unet_multi_device.py +++ b/models/experimental/functional_unet/tests/test_unet_multi_device.py @@ -28,8 +28,8 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache, if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_trace.py b/models/experimental/functional_unet/tests/test_unet_trace.py index 17211134106..0e484be93f4 100644 --- a/models/experimental/functional_unet/tests/test_unet_trace.py +++ b/models/experimental/functional_unet/tests/test_unet_trace.py @@ -231,8 +231,8 @@ def test_unet_trace_2cq_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) @@ -489,8 +489,8 @@ def test_unet_trace_2cq_same_io_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_upblock.py 
b/models/experimental/functional_unet/tests/test_unet_upblock.py index 9c623d1c840..dda519db358 100644 --- a/models/experimental/functional_unet/tests/test_unet_upblock.py +++ b/models/experimental/functional_unet/tests/test_unet_upblock.py @@ -98,8 +98,8 @@ def test_unet_upblock_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/grok/demo/demo.py b/models/experimental/grok/demo/demo.py index 1a8a477506d..b41b689e237 100644 --- a/models/experimental/grok/demo/demo.py +++ b/models/experimental/grok/demo/demo.py @@ -17,7 +17,7 @@ os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_common import ( prepare_inputs_ttnn, prepare_rotation_mat_ttnn, @@ -85,7 +85,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -95,7 +95,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] diff --git a/models/experimental/grok/tests/test_grok_decoder.py b/models/experimental/grok/tests/test_grok_decoder.py index aa3b1c6ce00..b2f7618215d 100644 --- a/models/experimental/grok/tests/test_grok_decoder.py +++ b/models/experimental/grok/tests/test_grok_decoder.py @@ -18,7 +18,7 @@ from models.experimental.grok.reference.model import DecoderLayer from models.experimental.grok.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.timeout(500 * 8) diff --git a/models/experimental/grok/tests/test_grok_mlp.py b/models/experimental/grok/tests/test_grok_mlp.py index d5a154ce8fe..adc730df7fc 100644 --- a/models/experimental/grok/tests/test_grok_mlp.py +++ b/models/experimental/grok/tests/test_grok_mlp.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_mlp import TtGrokMLP from models.experimental.grok.reference.model import MoeMLP, RMSNorm @@ -70,7 +70,7 @@ def test_grok_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input)
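The same pattern recurs throughout this change: the `ShardTensorToMesh`/`ReplicateTensorToMesh` classes are swapped for the `shard_tensor_to_mesh_mapper`/`replicate_tensor_to_mesh_mapper` factory functions, while the result is still passed through the `mesh_mapper=` keyword of `ttnn.from_torch` (and composers through `mesh_composer=` of `ttnn.to_torch`). A minimal sketch of the new usage, outside the diff, assuming a 2x4 mesh and illustrative tensor shapes:

```python
# Sketch only, not part of the diff. Assumes a 2x4 mesh (8 devices) and
# illustrative tensor shapes.
import torch
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Shard the activations across the batch dimension, one slice per device.
torch_activations = torch.randn(8, 1, 32, 256)
activations = ttnn.from_torch(
    torch_activations,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)

# Replicate the weights so every device holds a full copy.
torch_weights = torch.randn(1, 1, 256, 256)
weights = ttnn.from_torch(
    torch_weights,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

# Concatenate the per-device shards back into a single host tensor.
gathered = ttnn.to_torch(activations, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
```

Note that the factory call alone is not enough: the mapper must still be supplied via the `mesh_mapper=` keyword, since a bare positional argument after keyword arguments is a syntax error in `ttnn.from_torch` calls.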
diff --git a/models/experimental/grok/tests/test_grok_moe.py b/models/experimental/grok/tests/test_grok_moe.py index ee6c77e6553..c36d85528a3 100644 --- a/models/experimental/grok/tests/test_grok_moe.py +++ b/models/experimental/grok/tests/test_grok_moe.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_mlp import TtGrokMLP from models.experimental.grok.tt.grok_moe import TtMoeLayer @@ -86,7 +86,7 @@ def test_grok_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model tt_out = tt_model(tt_decode_input) diff --git a/models/experimental/grok/tests/test_grok_rms_norm.py b/models/experimental/grok/tests/test_grok_rms_norm.py index 5f220b9eb2b..e080ec5bebe 100644 --- a/models/experimental/grok/tests/test_grok_rms_norm.py +++ b/models/experimental/grok/tests/test_grok_rms_norm.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_rms_norm import TtRMSNorm, TtRMSNormSharded from models.experimental.grok.reference.model import RMSNorm @@ -55,7 +55,7 @@ def test_grok_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) @@ -104,7 +104,7 @@ def test_grok_rms_norm_sharded_inference(t3k_mesh_device, use_program_cache, res device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/experimental/grok/tt/grok_attention.py b/models/experimental/grok/tt/grok_attention.py index 794c6daa784..7b5d8158d6f 100644 --- a/models/experimental/grok/tt/grok_attention.py +++ b/models/experimental/grok/tt/grok_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_common import LightweightModule @@ -75,7 +75,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -91,7 +91,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"],
layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -119,7 +119,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), dtype=ttnn.bfloat8_b, layout=self.model_config["ATTN_W_LAYOUT_TILE"], memory_config=self.model_config["ATTN_CACHE_WEIGHTS_MEMCFG"], diff --git a/models/experimental/grok/tt/grok_common.py b/models/experimental/grok/tt/grok_common.py index 08181844cfa..b6e05a0a84f 100644 --- a/models/experimental/grok/tt/grok_common.py +++ b/models/experimental/grok/tt/grok_common.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 @@ -78,7 +78,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Attention mask @@ -94,7 +94,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, mesh_device): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( @@ -121,7 +121,7 @@ def prepare_rotation_mat_ttnn(head_dim, max_seq_len, mesh_device): device=mesh_device, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for rot_mat_i in rot_mat ] @@ -163,7 +163,7 @@ def cache_attention(mesh_device, state_dict, model_args, rot_emb_matrix_list, se layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_attn = TtGrokAttention( @@ -185,7 +185,7 @@ def cache_attention(mesh_device, state_dict, model_args, rot_emb_matrix_list, se dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( diff --git a/models/experimental/grok/tt/grok_mlp.py b/models/experimental/grok/tt/grok_mlp.py index db3cea55d8a..9771e835fa0 100644 --- a/models/experimental/grok/tt/grok_mlp.py +++ b/models/experimental/grok/tt/grok_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule @@ -36,7 +36,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtypes): torch_weight(name), dtype=dtypes[name], device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), layout=self.model_config["MLP_W_LAYOUT_TILE"], memory_config=self.model_config["MLP_WEIGHTS_MEMCFG"], cache_file_name=cache_name(name), diff --git a/models/experimental/grok/tt/grok_model.py b/models/experimental/grok/tt/grok_model.py index 98e7c18b0b5..8a97ebf400b 100644 --- a/models/experimental/grok/tt/grok_model.py +++
b/models/experimental/grok/tt/grok_model.py @@ -61,7 +61,7 @@ def __init__( dtype=ttnn.bfloat16, memory_config=self.model_config["OUTPUT_WEIGHTS_MEMCFG"], cache_file_name=output_cache_name, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), ) self.compute_kernel = self.args.get_compute_kernel_output_config() diff --git a/models/experimental/grok/tt/grok_moe.py b/models/experimental/grok/tt/grok_moe.py index 82526c3292a..ede5b375701 100644 --- a/models/experimental/grok/tt/grok_moe.py +++ b/models/experimental/grok/tt/grok_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule from models.experimental.grok.scripts.tlog import tlog, tlog_mesh_device @@ -34,7 +34,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): memory_config=self.model_config["GATE_WEIGHTS_MEMCFG"], cache_file_name=cache_name, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.num_devices = 8 @@ -48,14 +48,14 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.expert_mask_11BB = ttnn.from_torch( torch.cat([torch.full((1, 1, 32, 32), fill_value=i + 1) for i in range(8)], dim=3), dtype=ttnn.uint16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) top8_mask = torch.full((1, 1, 32, 64), fill_value=torch.finfo(torch.float).min) top8_mask[:, :, :, 1:9] = 0.0 @@ -64,7 +64,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) top2_mask = torch.full((1, 1, 32, 32), fill_value=0.0) @@ -74,7 +74,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.softmax_compute_config = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, fp32_dest_acc_en=True, packer_l1_acc=True diff --git a/models/experimental/grok/tt/grok_rms_norm.py b/models/experimental/grok/tt/grok_rms_norm.py index b337ab81381..08166210307 100644 --- a/models/experimental/grok/tt/grok_rms_norm.py +++ b/models/experimental/grok/tt/grok_rms_norm.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule @@ -43,7 +43,7 @@ def __init__( layout=self.model_config["NORM_W_LAYOUT_TILE"], memory_config=self.model_config["NORM_WEIGHTS_MEMCFG"], cache_file_name=cache_name, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: @@ -88,7 +88,7 @@ def __init__(
layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["NORM_WEIGHTS_MEMCFG"], cache_file_name=cache_name, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, x: ttnn.Tensor, out_sharded=False) -> ttnn.Tensor: diff --git a/tech_reports/CNNs/cnn_optimizations.md b/tech_reports/CNNs/cnn_optimizations.md index 5afd41adb32..3d3dfebd36d 100644 --- a/tech_reports/CNNs/cnn_optimizations.md +++ b/tech_reports/CNNs/cnn_optimizations.md @@ -196,8 +196,8 @@ Combining these two features should For more details on tracing and multi-CQs, c Throughput can be improved if multiple chips are availible by replicating the CNN across each chip. For our UNet model, we replicate across the outermost dimension: ```python -inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) # Shard input tensor on dimension 0 across each device -weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) # Replicate weights across all devices +inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) # Shard input tensor on dimension 0 across each device +weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) # Replicate weights across all devices output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) # Map multi-device tensor back to single host tensor ``` diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md index 0342e432399..bf49b9d5326 100644 --- a/tech_reports/LLMs/llms.md +++ b/tech_reports/LLMs/llms.md @@ -913,7 +913,7 @@ for i, split_size in enumerate(split_sizes): ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, @@ -1206,7 +1206,7 @@ mesh_tensor_sharded = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Convert to ttnn.Tensor, tilize and move onto mesh_device (2x4 devices) by replication @@ -1215,7 +1215,7 @@ mesh_tensor_replicated = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ``` diff --git a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md index 862921f5d33..28de54fa44f 100644 --- a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md +++ b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md @@ -140,7 +140,7 @@ torch_tensor[..., 32:64] = 2.0 # Convert to ttnn.Tensor; MeshTensor holds buffers to two shards in host-memory mesh_tensor = ttnn.from_torch( torch_tensor, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), layout=ttnn.TILE_LAYOUT, ) ``` @@ -306,7 +306,7 @@ mesh_tensor = ttnn.from_torch( torch_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Execute All-Gather on the tensor; `num_links=1` specifies the number of ethernet links to use @@ -338,7 +338,7 @@ mesh_tensor =
ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Execute Line All-Gather on the tensor @@ -452,7 +452,7 @@ torch_output = model.forward(torch_hidden_states) mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(y=1, x=4)) # Shard input activations on batch dimension to devices in the mesh -with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): +with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -461,7 +461,7 @@ with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): ) # Replicate model parameters to devices in the mesh -with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): +with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, @@ -539,7 +539,7 @@ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4)) # Initialize input activations on all devices in the mesh # Alternatively, we can shard the input activations on the height dimension and # subsequently invoke all-gather on the height dimension to form a complete tensor per device. -with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): +with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -548,7 +548,7 @@ with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): ) # Shard model parameters on width dimension to devices in the mesh -with ttnn.distribute(ttnn.ShardTensorToMesh(t3k_mesh_device, dim=-1)): +with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=-1)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( initialize_model=lambda: model, device=t3k_mesh_device, diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py index b30cd0f9f1e..0440aa17d64 100644 --- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py +++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py @@ -12,7 +12,7 @@ from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper # Override the default timeout in seconds for hang detection. 
TIMEOUT = 30 @@ -104,7 +104,7 @@ def run( input_tensor = torch.rand(input_shape).bfloat16() ttnn_tensor = ttnn.from_torch( - input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim) + input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim) ) input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_mesh_device) diff --git a/tests/ttnn/distributed/test_data_parallel_example.py b/tests/ttnn/distributed/test_data_parallel_example.py index fb5f59568c0..cd34afa2572 100644 --- a/tests/ttnn/distributed/test_data_parallel_example.py +++ b/tests/ttnn/distributed/test_data_parallel_example.py @@ -37,7 +37,7 @@ def test_data_parallel_falcon_mlp(mesh_device): torch_output = model.forward(torch_hidden_states) # Shard input activations on batch dimension to devices in the mesh - with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)): + with ttnn.distribute(mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -46,7 +46,7 @@ def test_data_parallel_falcon_mlp(mesh_device): ) # Replicate model parameters to devices in the mesh - with ttnn.distribute(mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/distributed/test_data_parallel_example_TG.py b/tests/ttnn/distributed/test_data_parallel_example_TG.py index 66b8bcacb5b..35e6a6f699e 100644 --- a/tests/ttnn/distributed/test_data_parallel_example_TG.py +++ b/tests/ttnn/distributed/test_data_parallel_example_TG.py @@ -39,7 +39,7 @@ def test_data_parallel_falcon_mlp(mesh_device): torch_output = model.forward(torch_hidden_states) # Shard input activations on batch dimension to devices in the mesh - with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): + with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -48,7 +48,7 @@ def test_data_parallel_falcon_mlp(mesh_device): ) # Replicate model parameters to devices in the mesh - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py new file mode 100644 index 00000000000..7ac3a22b677 --- /dev/null +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +from models.utility_functions import nearest_32 + + +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.ReplicateTensorToMesh(mesh_device) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + replicated_tensors = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mapper, + device=mesh_device, + ) + + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensors[0]), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_shard_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mapper, + device=mesh_device, + ) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=mapper, + device=mesh_device, + ) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_concat_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mapper, + 
device=mesh_device, + ) + + composer = ttnn.CppConcatMeshToTensor(dim=3) + + concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, concat_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=mapper, + device=mesh_device, + ) + + composer = ttnn.CppConcat2dMeshToTensor(mesh_device, dims=concat_dim) + + concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, concat_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + to_repl = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_shard_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + + out_tensor = ttnn.aggregate_as_tensor(shards) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_to_tensor(mesh_device, dtype): 
+ torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_slice_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + + sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) + + shards = ttnn.get_device_tensors(sharded_tensor) + + out_tensor = ttnn.aggregate_tensor(shards, composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + + sharded_tensor = ttnn.aggregate_as_tensor(shards) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + 
+ # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim) + + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index 82b4381c4aa..4e31e2a10c6 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -11,9 +11,9 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc from ttnn import ( - ShardTensorToMesh, + shard_tensor_to_mesh_mapper, ShardTensor2dMesh, - ReplicateTensorToMesh, + replicate_tensor_to_mesh_mapper, ConcatMeshToTensor, ConcatMesh2dToTensor, MeshToTensor, @@ -39,14 +39,14 @@ def test_galaxy_matmul_1d_fracture(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) weights = ttnn.from_torch( weights_pt, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) gt = act_pt @ weights_pt @@ -362,7 +362,7 @@ def test_galaxy_eltwise_add(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=LN_OUTPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attn_output = ttnn.from_torch( @@ -371,7 +371,7 @@ def test_galaxy_eltwise_add(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=LN_OUTPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) gt = residual_pt + attn_output_pt @@ -420,7 +420,7 @@ def test_galaxy_attn_matmul(M, N, head_dim, num_heads, mesh_shape, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) weights = ttnn.from_torch( @@ -536,7 +536,7 @@ def test_galaxy_nlp_create_heads_decode( layout=ttnn.TILE_LAYOUT, memory_config=CREATE_HEAD_INPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # tt operation @@ -636,7 +636,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROTARY_INPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) key_layer = ttnn.from_torch( @@ -645,7 +645,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROTARY_INPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), )
rot_mats = ttnn.from_torch( @@ -654,7 +654,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROT_MAT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) compute_kernel_rotary = ttnn.WormholeComputeKernelConfig( @@ -725,7 +725,7 @@ def test_fill_cache( dtype=cache_dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(num_users): x = torch.randn(input_shape).bfloat16().float() @@ -753,7 +753,7 @@ def test_fill_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=input_mem_config, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) cachett = ttnn.fill_cache(cachett, xt, i) @@ -794,7 +794,7 @@ def test_update_cache_decode( dtype=cache_dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = torch.randn(input_shape).bfloat16().float() @@ -828,7 +828,7 @@ def test_update_cache_decode( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=input_mem_config, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) cachett = ttnn.update_cache(cachett, xt, cache_idx, batch_offset=batch_offset) @@ -924,7 +924,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_V = ttnn.from_torch( @@ -933,7 +933,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) start_idx = s // 2 scale = d**-0.5 @@ -965,7 +965,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_back = ttnn.transformer.scaled_dot_product_attention_decode( @@ -1064,7 +1064,7 @@ def test_galaxy_nlp_concat_heads_decode( layout=ttnn.TILE_LAYOUT, memory_config=CONCAT_HEADS_INPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) concat_head_output = ttnn.experimental.nlp_concat_heads_decode( @@ -1151,7 +1151,7 @@ def test_galaxy_layernorm(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, memory_config=LN_OUTPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) norm_weights_tt = ttnn.from_torch( @@ -1159,7 +1159,7 @@ def test_galaxy_layernorm(M, N, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) norm_output = ttnn.rms_norm( diff --git a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py index 87c38fc5780..1faa8724328 100644 --- a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py +++
b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py @@ -53,7 +53,7 @@ def test_tensor_parallel_falcon_mlp(): # Initialize input activations on all devices in the mesh # Alternatively, we can shard the input activations on the height dimension and # subsequently invoke all-gather on the height dimension to form a complete tensor per device. - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -62,7 +62,7 @@ def test_tensor_parallel_falcon_mlp(): ) # Shard model parameters on width dimension to devices in the mesh - with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=-1)): + with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py index d309befa0b2..34d9d72e7fc 100644 --- a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py +++ b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py @@ -33,10 +33,10 @@ def test_bert_attention_inference( config = hugging_face_reference_model.config mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_attention_model, device=mesh_device, @@ -90,10 +90,10 @@ def test_bert_intermediate_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_intermediate_model, device=mesh_device, @@ -137,10 +137,10 @@ def test_bert_output_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_output_model, device=mesh_device, @@ -194,10 +194,10 @@ def test_bert_layer_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = 
ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_layer_model, device=mesh_device, @@ -244,10 +244,10 @@ def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, @@ -277,7 +277,7 @@ def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num ttnn_attention_mask = ttnn.from_torch( torch_attention_mask, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) diff --git a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py index 5d2dd6284bd..ecf85b2ec18 100644 --- a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py +++ b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py @@ -28,14 +28,14 @@ def test_distilbert_for_question_answering(mesh_device, model_name, batch_size, tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) if ttnn.GetNumAvailableDevices() == 2: batch_size = batch_size * 2 - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py index 2a42df95821..226288832fd 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py @@ -151,7 +151,7 @@ def run_all_gather_impl( dtype=input_dtype, layout=layout, tile=ttnn.Tile(tile), - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim), device=mesh_device, ) if trace_mode: diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py index 4357e6996d9..0c76cb164cc 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py @@ -9,7 +9,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 import itertools -from ttnn import 
ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from tests.ttnn.unit_tests.operations.ccl.test_all_gather import run_all_gather_sharded diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py index 74f6af3b4e5..7ba28c6d07e 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py @@ -6,7 +6,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0 from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case @@ -73,7 +73,7 @@ def run_all_gather_matmul_on_t3000_impl( layout=layout, device=t3k_mesh_device, memory_config=mem_config_weights, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim), tile=ttnn.Tile(tile), ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py index 503c33121a2..e30c48dd5af 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py @@ -12,7 +12,7 @@ is_unsupported_case, run_all_gather_on_t3000_impl, ) -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper # Enumerate the post-commit cases explicitly @@ -186,7 +186,7 @@ def run_line_all_gather_instances( input_tensor = torch.rand(input_shape).bfloat16() ttnn_tensor = ttnn.from_torch( - input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim) + input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim) ) input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_mesh_device) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py b/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py index d50ab25bce3..e8dd087d82d 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py @@ -9,7 +9,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case_t3k -from ttnn.distributed.distributed import ShardTensorToMesh +from ttnn.distributed.distributed import shard_tensor_to_mesh_mapper def sharded_impl( @@ -73,7 +73,7 @@ def sharded_impl( device=device, dtype=input_dtype, layout=tensor_layout, - mesh_mapper=ShardTensorToMesh(mesh_device=device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device=device, dim=dim), tile=ttnn.Tile(tile), ) @@ -142,7 +142,6 @@ def run_normal( device=device, dtype=input_dtype, layout=layout, - mesh_mapper=ShardTensorToMesh(mesh_device=device, dim=dim), tile=ttnn.Tile(tile), ) for i in range(num_iters): diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py index c1170936dff..7548f88ca9e 100644 ---
a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py @@ -127,7 +127,7 @@ def run_reduce_scatter_test( torch_tensor, dtype=input_dtype, layout=layout, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim), device=mesh_device, ) # Run the op diff --git a/tests/ttnn/unit_tests/operations/prefetcher_common.py b/tests/ttnn/unit_tests/operations/prefetcher_common.py index bfc881c16dc..1d91ac26f3c 100644 --- a/tests/ttnn/unit_tests/operations/prefetcher_common.py +++ b/tests/ttnn/unit_tests/operations/prefetcher_common.py @@ -8,7 +8,7 @@ import math from loguru import logger -from ttnn import ReplicateTensorToMesh, ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor from models.common.lightweightmodule import LightweightModule from tests.ttnn.utils_for_testing import assert_with_pcc from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( @@ -228,7 +228,7 @@ def run_prefetcher_mm( mesh_composer = None if isinstance(device, ttnn._ttnn.multi_device.MeshDevice): cluster_shape = device.shape - mesh_mapper = ReplicateTensorToMesh(device) + mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(device) mesh_composer = ConcatMesh2dToTensor(device, dims=(0, 1), mesh_shape=cluster_shape) pt_tensors = [] diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index dbc28079e16..ae38f5d7b49 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -487,8 +487,8 @@ def test_conv_features_multi_device( shard_layout=shard_layout, output_layout=output_layout, has_bias=True, - input_mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0), - weight_mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + input_mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), + weight_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), output_mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0), groups=groups, ) diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py index 029da544301..0fb0572f80d 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py @@ -74,7 +74,7 @@ def test_tensor_preallocation_and_write_apis( input_tensor_a, dtype=in_dtype, layout=tensor_layout, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn.copy_host_to_device_tensor(tt_input_tensor_a, preallocated_tensor) readback_tensors = [ diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index 231fa015962..55307e8608c 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -11,7 +11,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor ####### @@ -110,7 +110,7 @@ def test_ttnn_to_multi_device_multiple_times(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 32 * 
mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) @@ -136,7 +136,7 @@ def test_ttnn_to_and_from_multi_device_shard(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -161,7 +161,7 @@ def test_multi_device_check_per_device_shard(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 64 * mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), layout=layout + torch_tensor, dtype=dtype, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), layout=layout ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -182,14 +182,14 @@ def test_multi_device_check_per_device_shard(mesh_device, layout, memory_config, @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) def test_multi_device_replicate(mesh_device, shape, layout, memory_config): - """Test ReplicateTensorToMesh to broadcast a tensor across multiple devices""" - from ttnn import ReplicateTensorToMesh + """Test replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" + from ttnn import replicate_tensor_to_mesh_mapper full_tensor = torch.rand(shape, dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( full_tensor, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), layout=layout, memory_config=memory_config, device=mesh_device, @@ -214,7 +214,7 @@ def test_ttnn_multi_device_all_gather(pcie_mesh_device): pytest.skip("Requires multiple devices to run") full_tensor = torch.rand((1, 1, 32, 32 * pcie_mesh_device.get_num_devices()), dtype=torch.bfloat16) - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) @@ -237,7 +237,7 @@ def test_multi_device_single_op_unary(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) ttnn_output_tensor = ttnn.gelu(ttnn_input_tensor) @@ -261,13 +261,13 @@ def test_multi_device_single_op_binary(mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), +
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) ttnn_output_tensor = ttnn.add(ttnn_input_a_tensor, ttnn_input_b_tensor) @@ -289,7 +289,7 @@ def test_multi_device_multi_op(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) ttnn_gelu_output = ttnn.gelu(ttnn_input_tensor) @@ -314,13 +314,13 @@ def test_multi_device_data_parallel_matmul_op(mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_output_tensor = ttnn_input_a_tensor @ ttnn_input_b_tensor @@ -349,7 +349,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): layout=layout, memory_config=memory_config, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) with tempfile.NamedTemporaryFile() as temp_file: @@ -360,7 +360,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_input_b_tensor = ttnn.as_tensor( @@ -370,7 +370,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_output_tensor = ttnn_input_a_tensor @ ttnn_input_b_tensor @@ -404,7 +404,7 @@ def test_multi_device_as_tensor_api_sharded_tensor(mesh_device, layout, memory_c device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) load_tensor = ttnn.as_tensor( input_tensor, dtype=dtype, layout=layout, device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) torch_loaded_tensor = ttnn.to_torch(load_tensor, mesh_composer=ConcatMeshToTensor(mesh_device, dim=0)) expected_pcc = 0.98 if dtype == ttnn.bfloat4_b else 0.99 @@ -436,7 +436,7 @@ def test_multi_device_permute(mesh_device, layout, memory_config, dtype): torch_golden = torch.permute(torch_tensor, (0, 1, 3, 2)) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor,
mesh_device, memory_config=memory_config) ttnn_permute = ttnn.permute(ttnn_tensor, (0, 1, 3, 2)) @@ -457,7 +457,7 @@ def test_max(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) gate_logits_1SB8 = ttnn.to_device(gate_logits_1SB8, mesh_device) weights_ex0_1SB1 = ttnn.max(gate_logits_1SB8, dim=3) @@ -478,7 +478,7 @@ def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device): for i in range(t3k_mesh_device.get_num_devices()): full_tensor[..., i * 32 : (i + 1) * 32] = i - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) @@ -499,14 +499,14 @@ def test_sharded_matmul(t3k_mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) keys_1BDP = ttnn.from_torch( torch.randn(1, 32, 128, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) q_heads_1B4D = ttnn.to_device(q_heads_1B4D, t3k_mesh_device) @@ -561,7 +561,7 @@ def test_4b_tensor(mesh_device): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tensor = ttnn.to_device(tensor, mesh_device) x = ttnn.from_torch( @@ -569,7 +569,7 @@ dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = ttnn.to_device(x, mesh_device) tensor = ttnn.matmul( @@ -588,7 +588,7 @@ def test_slicing(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tensor = ttnn.to_device(tensor, mesh_device) tensor = tensor[:, :, :, :1] @@ -601,7 +601,7 @@ def test_clone(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) results_11BH = ttnn.to_device(results_11BH, mesh_device) results_11BH = ttnn.clone(results_11BH, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG) @@ -617,7 +617,7 @@ def test_device_shard_to_torch(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) @@ -643,7 +643,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=memory_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=tmp_path / "cache_file", ) assert tensor.dtype == ttnn.float32 @@ -657,7 +657,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=memory_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=tmp_path / "cache_file", ) assert tensor.dtype == ttnn.float32 @@ -687,7 +687,7 @@ def model(submesh): for i in range(submesh.get_num_devices()): full_tensor[..., i * 32 : (i + 1) * 32] = i - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(submesh, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) @@ -724,7 +724,7 @@ def test_line_all_gather_after_reshape(mesh_device): def test_distribute_api(mesh_device): torch_hidden_states = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat8_b, diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 3b1e75f500d..8f9e8c5b5df 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -9,7 +9,7 @@ from loguru import logger from tests.ttnn.utils_for_testing import assert_with_pcc import transformers - +from ttnn import shard_tensor_to_mesh_mapper ####### # Multi-Device Tensor tests running in async mode @@ -21,7 +21,7 @@ @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_config, dtype): """Shard a tensor across devices, compose it back and verify loopback tensor is same as the original tensor""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT: pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT") @@ -31,7 +31,10 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co for i in range(100): torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -48,7 +51,7 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_config, dtype): """This test checks if the tensor is correctly sharded across devices""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT: pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT") @@ -63,7 +66,10 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype,
layout=layout, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -83,8 +89,8 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): - """Test ReplicateTensorToMesh to broadcast a tensor across multiple devices""" - from ttnn import ReplicateTensorToMesh + """Test replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" + from ttnn import replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) @@ -93,7 +99,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): ttnn_tensor = ttnn.from_torch( full_tensor, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), layout=layout, memory_config=memory_config, device=pcie_mesh_device, @@ -114,7 +120,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b]) def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_config, dtype): """Test multi chip layout conversions on worker threads""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor shard_dim = 3 pcie_mesh_device.enable_async(True) @@ -122,7 +128,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_ torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( torch_tensor, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=shard_dim), layout=layout, memory_config=memory_config, device=pcie_mesh_device, @@ -144,7 +150,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_ @pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 3, 1024, 1024)]) def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, shape): """Multidevice API test: Running tensor-parallel multi-device chain of eltwise ops""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor pcie_mesh_device.enable_async(True) if program_cache: @@ -164,7 +170,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), device=pcie_mesh_device, ) ttnn_output_tensor = ttnn.add( @@ -184,7 +190,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha @pytest.mark.parametrize("input_a_shape", [(4, 1, 512, 512), (16, 1, 512, 512)]) def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, input_a_shape): """Multidevice API: Running data-parallel chain of ops with matmul""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh + from ttnn import
shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) if program_cache: @@ -206,13 +212,13 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor = ttnn.from_device( ttnn.mish( @@ -249,7 +255,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config): layout=layout, device=pcie_mesh_device, memory_config=mem_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) tt_out_11BH = ttnn.argmax(tt_out_11BH, dim=-1) @@ -264,7 +270,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config): @pytest.mark.parametrize("pcie_mesh_device", [2], indirect=True) def test_multi_device_explicit_dealloc(pcie_mesh_device): """Multidevice API: Ensure that deallocating multi-device tensors works as expected""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper if pcie_mesh_device.get_num_devices() <= 1: pytest.skip("Requires multiple devices to run") @@ -278,13 +284,13 @@ def test_multi_device_explicit_dealloc(pcie_mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor_1 = ttnn_input_a_tensor @ ttnn_input_b_tensor ttnn_output_tensor_2 = ttnn.gelu(ttnn_output_tensor_1) @@ -315,7 +321,7 @@ def test_add_1D_tensor_and_scalar(pcie_mesh_device, scalar, size): torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) output_tensor = input_tensor + scalar output_tensors = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(output_tensor.cpu())] diff --git a/tests/ttnn/unit_tests/test_multi_device_events.py b/tests/ttnn/unit_tests/test_multi_device_events.py index b41c7cfaa3d..824282958e6 100644 --- a/tests/ttnn/unit_tests/test_multi_device_events.py +++ b/tests/ttnn/unit_tests/test_multi_device_events.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.parametrize("shape", [(1, 1, 512, 512)]) @@ -52,10 +52,10 @@ def run_op_chain(input_0, input_1, workload_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, 
mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors, using data-movement CQ diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index 284a75c0a60..9de288ff0ce 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -84,10 +84,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -240,13 +240,13 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device) + torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py index 5c24ab237aa..dab95c7d7ac 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -80,10 +80,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -218,13 +218,13 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(mesh_device) + torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py index d1354f329ea..e7bcd4be2a4 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -80,10 +80,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -217,13 +217,13 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(mesh_device) + torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git
a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index 40fd7c15052..8d1d87cfd35 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -546,7 +546,7 @@ def test_reshape_zero_element(input_shape, output_shape, layout, ttnn_reshape, u ) def test_reshape_replicated_tensor(mesh_device, input_shape, output_shape): torch_input_tensor = torch.randn(input_shape) - mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) tt_input_tensor = ttnn.from_torch( torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper, device=mesh_device ) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index 763a003fc7f..821faad79ab 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -53,7 +53,7 @@ def run_sub_devices(device, create_fabric_sub_device=False): def run_sub_devices_program(device, create_fabric_sub_device=False): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: - inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) num_devices = device.get_num_devices() else: diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 0f6685dc5c3..9de607d505c 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -7,8 +7,10 @@ #include #include +#include "tt-metalium/assert.hpp" #include "tt-metalium/mesh_coord.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_impl.hpp" #include "ttnn/tensor/tensor_utils.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" #include @@ -93,6 +95,23 @@ Tensor aggregate_as_tensor( } auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs}; return Tensor(std::move(storage), reference_shard.get_tensor_spec()); + } else if (storage_type == StorageType::BORROWED) { + std::vector specs; + std::vector host_owned_buffers; + for (const auto& shard : tensor_shards) { + auto buffer = std::get(shard.get_storage()).buffer; + specs.push_back(shard.get_tensor_spec()); + + auto visitor = tt::stl::overloaded{[&shard, &host_owned_buffers](const auto& buffer) -> OwnedBuffer { + using BorrowedBufferType = std::vector::value_type>; + + return owned_buffer::create(BorrowedBufferType(buffer.begin(), buffer.end())); + }}; + + host_owned_buffers.push_back(std::visit(visitor, buffer)); + } + auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs}; + return Tensor(std::move(storage), reference_shard.get_tensor_spec()); } else { std::vector ordered_device_ids; std::unordered_map specs; diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 9ad24cf4aee..5d44ac3a514 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,16 +3,21 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" -#include - #include - -#include +#include +#include +#include +#include #include "tt-metalium/mesh_coord.hpp" +#include "tt-metalium/assert.hpp" +#include "distributed_tensor.hpp" #include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include 
"ttnn/distributed/types.hpp" -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/types.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include +#include "ttnn/tensor/tensor_impl_wrapper.hpp" // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -24,7 +29,55 @@ namespace ttnn::distributed { namespace py = pybind11; +// Trampoline class to clear virtual method errors +struct ConcreteTensorToMesh : TensorToMesh { + using TensorToMesh::TensorToMesh; // Inherit constructors + + std::vector map(const Tensor& tensor) const override { + PYBIND11_OVERRIDE(std::vector, TensorToMesh, map, tensor); + } + + tt::tt_metal::DistributedTensorConfig config() const override { + PYBIND11_OVERRIDE(tt::tt_metal::DistributedTensorConfig, TensorToMesh, config); + } +}; + +// Trampoline class to clear virtual method errors +struct ConcreteMeshToTensor : MeshToTensor { + Tensor compose(const std::vector& tensors) const override { + PYBIND11_OVERRIDE(Tensor, MeshToTensor, compose, tensors); + } +}; + +// unused empty implementations to satisfy pybind's desire for unique objects +class ReplicateTensorToMesh : public TensorToMesh {}; +class ShardTensorToMesh : public TensorToMesh {}; +class ShardTensorTo2dMesh : public TensorToMesh {}; +class ConcatMeshToTensor : public MeshToTensor {}; +class Concat2dMeshToTensor : public MeshToTensor {}; + void py_module_types(py::module& module) { + py::class_>(module, "CppMeshToTensor"); + py::class_>(module, "TensorToMesh"); + + py::class_>( + module, "ReplicateTensorToMesh"); + py::class_>(module, "ShardTensorToMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>(module, "CppConcatMeshToTensor"); + py::class_>( + module, "CppConcat2dMeshToTensor"); + + py::class_(module, "ReplicateTensor"); + py::class_(module, "ShardTensor"); + py::class_(module, "ShardTensor2d"); + py::class_(module, "ShardMesh"); + py::class_(module, "AllGatherTensor"); + py::class_(module, "DistributedTensorConfig"); + + py::class_(module, "Shard2dConfig"); + py::class_(module, "Concat2dConfig"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -360,6 +413,98 @@ void py_module(py::module& module) { back to all SubDevice IDs. 
)doc"); + auto py_tensor_to_mesh = static_cast>>( + module.attr("TensorToMesh")); + py_tensor_to_mesh + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) + .def("map", &TensorToMesh::map) + .def("config", &TensorToMesh::config); + auto py_replicate_tensor_to_mesh = + static_cast>>(module.attr("ReplicateTensorToMesh")); + py_replicate_tensor_to_mesh + .def( + py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return replicate_tensor_to_mesh_mapper(mesh_device); + }), + py::arg("mesh_device")) + .def( + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); + auto py_shard_tensor_to_mesh = + static_cast>>(module.attr("ShardTensorToMesh")); + py_shard_tensor_to_mesh + .def( + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return shard_tensor_to_mesh_mapper(mesh_device, dim); + }), + py::arg("mesh_device"), + py::arg("dim")) + .def( + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); + auto py_shard_tensor_to_2d_mesh = + static_cast>>(module.attr("ShardTensorTo2dMesh")); + py_shard_tensor_to_2d_mesh + .def( + py::init( + [](MeshDevice& mesh_device, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { + int mesh_rows = std::get<0>(mesh_shape); + int mesh_cols = std::get<1>(mesh_shape); + + int config_rows = std::get<0>(dims); + int config_cols = std::get<1>(dims); + return shard_tensor_to_2d_mesh_mapper( + mesh_device, + MeshShape(mesh_rows, mesh_cols), + Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("dims")) + .def( + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); + auto py_mesh_to_tensor = static_cast>>( + module.attr("CppMeshToTensor")); + py_mesh_to_tensor + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) + .def("compose", &MeshToTensor::compose); + auto py_concat_mesh_to_tensor = + static_cast>>(module.attr("CppConcatMeshToTensor")); + py_concat_mesh_to_tensor + .def( + py::init([](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }), + py::arg("dim")) + .def( + "compose", + [](const MeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + py::arg("tensors")); + + auto py_concat_2d_mesh_to_tensor = + static_cast>>(module.attr("CppConcat2dMeshToTensor")); + py_concat_2d_mesh_to_tensor + .def( + py::init([](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = std::get<1>(dims); + return concat_2d_mesh_to_tensor_composer( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), + py::arg("mesh_device"), + py::arg("dims")) + .def( + "compose", + [](const MeshToTensor& self, const std::vector& tensors) -> Tensor { + return self.compose(tensors); + }, + py::arg("tensors")); + module.def( "open_mesh_device", &open_mesh_device, @@ -371,7 +516,6 @@ void py_module(py::module& module) { py::arg("offset"), py::arg("physical_device_ids"), py::arg("dispatch_core_config")); - module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( "get_device_tensor", @@ -382,15 +526,58 @@ void py_module(py::module& module) { R"doc( Get the tensor 
shard corresponding to the device_id. - Args: tensor (Tensor): The tensor to get the shard from. device_id (int): The device id to get the shard for. + Returns: + Tensor: The shard of the tensor corresponding to the device_id. + )doc"); + + auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); + py_replicate_tensor_config.def(py::init<>()) + .def(py::init(), py::arg("replication_factor") = 1) + .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) + .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { + return a.replication_factor == b.replication_factor; + }); + + auto py_shard_tensor_config = static_cast>(module.attr("ShardTensor")); + py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) + .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) + .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); + auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); + py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); + auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); + py_shard_tensor2d.def(py::init(), py::arg("mesh")) + .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) + .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); + auto py_allgather_config = + static_cast>(module.attr("AllGatherTensor")) + .def(py::init<>()) + .def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); + + auto py_shard2d_config = static_cast>(module.attr("Shard2dConfig")); + py_shard2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Shard2dConfig::row_dim) + .def_readwrite("col_dim", &Shard2dConfig::col_dim); + auto py_concat2d_config = static_cast>(module.attr("Concat2dConfig")); + py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Concat2dConfig::row_dim) + .def_readwrite("col_dim", &Concat2dConfig::col_dim); - Returns: - Tensor: The shard of the tensor corresponding to the device_id. - )doc"); + module.def( + "get_distributed_tensor_config", + &get_distributed_tensor_config, + py::arg("metadata"), + R"doc( + Returns a distributed_tensor_config object given a valid metadata object of the type + + { + "item": "field", + "item": "field", + } + )doc"); module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), @@ -400,7 +587,6 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. - Args: tensor (Tensor): The tensor to get the shard from. device (Device): The device to get the shard for. @@ -410,6 +596,122 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. 
)doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); + // TODO: Add rdocs + module.def( + "replicate_tensor_to_mesh_mapper", + [](MeshDevice& mesh_device) -> std::unique_ptr { + return replicate_tensor_to_mesh_mapper(mesh_device); + }, + py::arg("mesh_device")); + module.def( + "shard_tensor_to_mesh_mapper", + [](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return shard_tensor_to_mesh_mapper(mesh_device, dim); + }, + py::arg("mesh_device"), + py::arg("dim")); + module.def( + "shard_tensor_to_2d_mesh_mapper", + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const Shard2dConfig& config) -> std::unique_ptr { + return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }, + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")); + module.def( + "shard_tensor_to_2d_mesh_mapper", + [](MeshDevice& mesh_device, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { + return shard_tensor_to_2d_mesh_mapper( + mesh_device, + MeshShape(std::get<0>(mesh_shape), std::get<1>(mesh_shape)), + Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("dims"), + R"doc( + Create a ShardTensorTo2dMesh mapper with the given mesh device, mesh shape, and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the mapper for. + mesh_shape (MeshShape): The shape of the 2D mesh as (num_rows, num_cols). + dims (Tuple[int, int]): The dimensions to create the mapper for in (row, column) format. + + Returns: + TensorToMesh: The created ShardTensorTo2dMesh mapper. + )doc"); + module.def( + "concat_mesh_to_tensor_composer", + [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, + py::arg("dim")); + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, + const std::tuple dims) -> std::unique_ptr { + return concat_2d_mesh_to_tensor_composer( + mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("dims"), + R"doc( + Create a Concat2dMeshToTensor composer with the given mesh device and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the composer for. + dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. + mesh_shape (Tuple[int, int]): The shape of the 2D mesh as (num_rows, num_cols). + + Returns: + TensorToMesh: The created Concat2dMeshToTensor composer. 
+ )doc"); + module.def( + "distribute_tensor", + [](const Tensor& tensor, + const TensorToMesh& mapper, + std::optional> mesh_device = std::nullopt) -> Tensor { + + Tensor cpu_tensor; + if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); + } else { + cpu_tensor = from_device(tensor); + } + return distribute_tensor(cpu_tensor, mapper, mesh_device); + }, + py::arg("tensor"), + py::arg("mapper"), + py::arg("mesh_device") = py::none()); + module.def( + "aggregate_tensor", + [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { + Tensor cpu_tensor; + if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); + } else { + cpu_tensor = from_device(tensor); + } + return aggregate_tensor(cpu_tensor, composer); + }, + py::arg("tensor"), + py::arg("composer")); + module.def( + "aggregate_tensor", + [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { + Tensor aggregated_tensor = from_device(aggregate_as_tensor(tensors, AllGatherTensor{})); + Tensor cpu_tensor; + if (aggregated_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(aggregated_tensor, true); + } else { + cpu_tensor = from_device(aggregated_tensor); + } + return aggregate_tensor(cpu_tensor, composer); + }, + py::arg("tensor"), + py::arg("composer")); module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp index 93d26f3f2d6..25c384363bc 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp @@ -5,6 +5,7 @@ #pragma once #include "pybind11/pybind_fwd.hpp" #include +#include "pybind11/stl.h" namespace py = pybind11; diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index af3cf6d1fbf..79d3377d584 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -2,14 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor.hpp" -#include -#include "ttnn/distributed/distributed_tensor_config.hpp" -#include "ttnn/distributed/types.hpp" -#include "ttnn/tensor/xtensor/partition.hpp" +#include "tt-metalium/assert.hpp" namespace ttnn::distributed { namespace { @@ -166,7 +160,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( TT_FATAL( mesh_shape[0] <= mesh_device.shape()[0] && // mesh_shape[1] <= mesh_device.shape()[1], - "Device mesh shape does not match the provided mesh shape."); + "Device mesh shape {} does not match the provided mesh shape ({}, {}).", mesh_device.shape(), mesh_shape[0], mesh_shape[1]); return std::make_unique(mesh_shape[0], mesh_shape[1], config); } @@ -194,7 +188,7 @@ Tensor distribute_tensor( std::vector tensors = mapper.map(tensor); Tensor output = aggregate_as_tensor(tensors, mapper.config()); if (mesh_device.has_value()) { - return output.to_device(&(mesh_device->get())); + return output.to_device(&(mesh_device->get()), output.memory_config()); } return output; } diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 
7d49ca932f4..b8fd7e8003e 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -4,24 +4,43 @@ #pragma once +#include "tt-metalium/mesh_device.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/distributed/types.hpp" +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/distributed/types.hpp" +#include "ttnn/tensor/xtensor/partition.hpp" +#include +#include namespace ttnn::distributed { // Mapper interface that distributes a host tensor onto a multi-device configuration. +// The __attribute__((weak)) instructs pybind imports not to look for a symbol for these functions, as the linker won't +// create one. class TensorToMesh { public: virtual ~TensorToMesh() = default; - virtual std::vector map(const Tensor& tensor) const = 0; - virtual tt::tt_metal::DistributedTensorConfig config() const = 0; + virtual __attribute__((weak)) std::vector map(const Tensor& tensor) const = 0; + virtual __attribute__((weak)) tt::tt_metal::DistributedTensorConfig config() const = 0; }; // Composer interface that aggregates a multi-device tensor into a host tensor. class MeshToTensor { public: virtual ~MeshToTensor() = default; - virtual Tensor compose(const std::vector& tensors) const = 0; + virtual __attribute__((weak)) Tensor compose(const std::vector& tensors) const = 0; +}; + +struct Shard2dConfig { + std::optional row_dim; + std::optional col_dim; +}; + +struct Concat2dConfig { + int row_dim = -1; + int col_dim = -1; }; // Creates a mapper that replicates a tensor across all devices. @@ -32,10 +51,6 @@ std::unique_ptr shard_tensor_to_mesh_mapper(MeshDevice& mesh_devic // Creates a mapper that shards a tensor along two dimensions, which will be intepreted as rows and columns. // If either dimension is not specified, the tensor is replicated along that dimension. -struct Shard2dConfig { - std::optional row_dim; - std::optional col_dim; -}; std::unique_ptr shard_tensor_to_2d_mesh_mapper( MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config); @@ -43,10 +58,8 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( std::unique_ptr concat_mesh_to_tensor_composer(int dim); // Creates a composer that concatenates a tensor across two dimensions. -struct Concat2dConfig { - int row_dim = -1; - int col_dim = -1; -}; + + std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config); // Distributes a host tensor onto multi-device configuration according to the `mapper`. 
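For reviewers, here is a minimal usage sketch of the functional mapper/composer API that this change standardizes on. It assumes an already-open `mesh_device` (for example from `ttnn.open_mesh_device`, not shown); the shapes and dtypes are illustrative only and are not taken from any particular test above.

```python
import torch
import ttnn

# Assumption: `mesh_device` is an open ttnn.MeshDevice created elsewhere.
num_devices = mesh_device.get_num_devices()
torch_tensor = torch.rand((1, 1, 32, 32 * num_devices), dtype=torch.bfloat16)

# Shard the last dimension across the mesh
# (replaces the removed mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3)).
sharded = ttnn.from_torch(
    torch_tensor,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3),
)

# Replicate a tensor to every device
# (replaces the removed mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)).
replicated = ttnn.from_torch(
    torch.rand((1, 1, 32, 32), dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

# Reading back still goes through the Python-side composer kept in ttnn.distributed.
full_tensor = ttnn.to_torch(sharded, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3))
```

The 2D variants (`ttnn.shard_tensor_to_2d_mesh_mapper`, `ttnn.concat_2d_mesh_to_tensor_composer`) follow the same pattern, taking a `(rows, cols)` mesh shape and per-axis dims as described in the pybind docstrings above, and the new `ttnn.distribute_tensor` / `ttnn.aggregate_tensor` bindings accept these mapper and composer objects directly.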
diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index 6429d55226b..db13774bbb9 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -14,6 +14,7 @@ #include #include "cpp/ttnn/operations/experimental/reshape/view.hpp" #include "ttnn/operations/core/core.hpp" +#include "ttnn/tensor/types.hpp" #include "ttnn/types.hpp" namespace ttnn { @@ -104,6 +105,7 @@ Tensor to_layout_impl( TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize); } else if (layout == ttnn::TILE_LAYOUT) { + std::cout << "tilizing1" << std::endl; if (tensor.is_sharded()) { const auto tensor_tile = tensor.get_tensor_spec().tile(); uint32_t tile_height = tensor_tile.get_height(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp index e566e554d39..8e9f0c426c7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "tilize.hpp" +#include #include "device/tilize_op.hpp" #include "ttnn/common/queue_id.hpp" diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 30bab3457b6..8fc1e6e3de7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -763,6 +763,7 @@ Tensor Tensor::to_device(IDevice* target_device, const MemoryConfig& mem_config, Tensor Tensor::to_device(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config, QueueId cq_id) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); + return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); } diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 05e51fc4fba..5001117f885 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -194,7 +194,6 @@ Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, distri host_storage != nullptr) { distributed_config = host_storage->strategy; } - Tensor tensor_modified_layout = Tensor(workers.size(), distributed_config); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 838a8cddd3d..496f7d8c5f9 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -95,9 +95,31 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( + MeshDevice, + CppMeshToTensor, + TensorToMesh, + ReplicateTensorToMesh, + ShardTensorToMesh, + ShardTensorTo2dMesh, + CppConcatMeshToTensor, + CppConcat2dMeshToTensor, + ReplicateTensor, + ShardTensor, + ShardTensor2d, + ShardMesh, + AllGatherTensor, + DistributedTensorConfig, get_device_tensor, get_device_tensors, + get_distributed_tensor_config, aggregate_as_tensor, + replicate_tensor_to_mesh_mapper, + shard_tensor_to_mesh_mapper, + shard_tensor_to_2d_mesh_mapper, + concat_mesh_to_tensor_composer, + concat_2d_mesh_to_tensor_composer, + aggregate_tensor, + distribute_tensor, get_t3k_physical_device_ids_ring, ) diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index bc90ce3cf20..ab0b239154d 100644 --- 
a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -2,9 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 +# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days from .distributed import ( MeshDevice, DispatchCoreType, + MeshToTensor, + ConcatMeshToTensor, + ConcatMesh2dToTensor, open_mesh_device, close_mesh_device, get_num_pcie_devices, @@ -12,13 +16,7 @@ get_pcie_device_ids, get_device_ids, create_mesh_device, - TensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, - MeshToTensor, - ConcatMeshToTensor, + synchronize_devices, visualize_mesh_device, - ConcatMesh2dToTensor, distribute, ) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index db7e5f860e7..ec7480a4564 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -191,28 +191,31 @@ def create_mesh_device(*args, **kwargs): close_mesh_device(mesh_device) -class TensorToMesh: +def synchronize_devices( + devices: Union["ttnn.Device", "ttnn.MeshDevice"], + queue_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], +) -> None: """ - Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. - You can also "Bring your own TensorToMesh" based on your custom mapping. - """ - - def __init__(self, mesh_device): - self.mesh_device = mesh_device - - def map(self, tensor: "torch.Tensor"): - raise NotImplementedError("Subclasses must implement this method") + synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None, sub_device_ids: List[ttnn.SubDeviceId] = []) -> None: - def config(self): - raise NotImplementedError("Subclasses must implement this method") + Synchronize the devices with host by waiting for all operations to complete. + If queue_id is provided then only the operations associated with that queue_id are waited for, + otherwise operations for all command queues are waited on. + """ + if isinstance(devices, ttnn.Device): + ttnn._ttnn.device.synchronize_device(devices, queue_id, sub_device_ids) + else: + for device in devices.get_device_ids(): + ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) +# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days class MeshToTensor: """ Defines the inverse operation of TensorToMesh. Given a set of per-device ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines the mapping back to one or many torch.Tensor objects. - You can also "Bring your own MeshToTensor" based on your custom mapping. """ @@ -220,121 +223,9 @@ def compose(self, tensor: ttnn.Tensor): raise NotImplementedError("Subclasses must implement this method") -class ShardTensorToMesh(TensorToMesh): - def __init__(self, mesh_device, dim): - super().__init__(mesh_device) - self.shard_dim = dim - - def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: - import torch - - sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) - return list(sliced_tensors) - - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.shard_dim}", - } - - -class ShardTensor2dMesh(TensorToMesh): - """ - Shard a tensor across a 2D mesh of devices. - - This class implements a strategy for distributing a tensor across a 2D grid of devices, - allowing for efficient parallel processing in distributed computing environments. 
-    """
-
-    def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]):
-        """
-        Initialize the ShardTensor2dMesh.
-
-        Args:
-            mesh_device: The target device mesh for distributing the tensor.
-            mesh_shape: The shape of the 2D mesh as (rows, cols).
-            dims: The dimensions to shard along, specified as (row_dim, col_dim).
-
-        The `dims` tuple determines how the tensor is sharded across the 2D mesh:
-        - row_dim: The dimension to shard across mesh rows (or None for replication).
-        - col_dim: The dimension to shard across mesh columns (or None for replication).
-
-        Examples:
-        1. dims=(2, 3) for a tensor of shape (A, B, C, D):
-           - Shard along dimension 2 (C) across mesh rows
-           - Shard along dimension 3 (D) across mesh columns
-
-        2. dims=(None, 3):
-           - Replicate across mesh rows
-           - Shard along dimension 3 (D) across mesh columns
-
-        3. dims=(None, None):
-           - Fully replicate the tensor across all devices
-        """
-        super().__init__(mesh_device)
-        self.mesh_shape: Tuple[int, int] = mesh_shape
-        self.dims: Tuple[Optional[int], Optional[int]] = dims
-
-        mesh_device_rows, mesh_device_cols = self.mesh_device.shape
-        if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols:
-            raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.")
-
-    def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]:
-        """
-        Map the input tensor to a list of sharded tensors.
-
-        Args:
-            tensor: The input tensor to be sharded.
-
-        Returns:
-            A list of sharded tensors, one for each device in the mesh.
-
-        Raises:
-            ValueError: If the number of sharding dimensions is not 2.
-        """
-        import torch
-
-        if len(self.dims) != 2:
-            raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions")
-
-        rows, cols = self.mesh_shape
-        row_dim, col_dim = self.dims
-
-        # Shard along rows
-        row_tensors = (
-            [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim)
-        )
-
-        # Shard along columns
-        if col_dim is None:
-            return [t.clone() for t in row_tensors for _ in range(cols)]
-        tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)]
-
-        if len(tensor_shards) != rows * cols:
-            raise ValueError(
-                f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)."
-            )
-
-        return tensor_shards
-
-    def config(self) -> Dict[str, str]:
-        """
-        Provide the configuration of the sharding strategy.
-
-        Returns:
-            A dictionary containing the sharding strategy and dimensions.
-        """
-        return {
-            "strategy": "shard_2d",
-            "mesh_shape_y": str(self.mesh_shape[0]),
-            "mesh_shape_x": str(self.mesh_shape[1]),
-        }
-
-
 class ConcatMesh2dToTensor(MeshToTensor):
     """
     Concatenate tensors from a 2D mesh back into a single tensor.
-
     This class implements the inverse operation of ShardTensor2dMesh, combining sharded tensors
     from a 2D device mesh back into a single tensor.
     """
@@ -342,7 +233,6 @@ class ConcatMesh2dToTensor(MeshToTensor):
     def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]):
         """
         Initialize the ConcatMesh2dToTensor.
-
         Args:
             mesh_device: The source device mesh containing the sharded tensors.
             mesh_shape: The shape of the 2D mesh as (rows, cols).
@@ -353,7 +243,6 @@ def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: T
                 These dimensions correspond to the tensor dimensions, not the mesh dimensions.
                 For example, if the original tensor was 4D with shape (batch, channel, height, width),
                 and it was sharded across height and width, dims might be (-2, -1) or (2, 3).
-
         Raises:
             ValueError: If either dimension in 'dims' is None or if both dimensions are the same.
         """
@@ -366,13 +255,10 @@ def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: T
     def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor":
         """
         Compose the sharded tensors back into a single tensor.
-
         Args:
             tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices.
-
         Returns:
             A single torch.Tensor that combines all the sharded tensors from all devices.
-
         This method first concatenates the shards along the column dimension within each row,
         then concatenates the resulting tensors along the row dimension to form the final tensor.
         """
@@ -395,20 +281,6 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor":
         return torch.cat(row_concatenated, dim=row_dim)


-class ReplicateTensorToMesh(TensorToMesh):
-    def __init__(self, mesh_device: MeshDevice):
-        super().__init__(mesh_device)
-
-    def map(self, tensor: "torch.Tensor"):
-        return [tensor for i in range(self.mesh_device.get_num_devices())]
-
-    def config(self):
-        return {
-            "strategy": "replicate",
-            "replication_factor": str(self.mesh_device.get_num_devices()),
-        }
-
-
 class ConcatMeshToTensor(MeshToTensor):
     def __init__(self, mesh_device: MeshDevice, dim: int):
         self.concat_dim = dim
@@ -424,7 +296,7 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor":


 @contextlib.contextmanager
-def distribute(default: Union[TensorToMesh, MeshToTensor]):
+def distribute(default: Union[ttnn.TensorToMesh, ttnn.CppMeshToTensor, MeshToTensor]):
     """
     Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified
     mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice.
@@ -436,20 +308,20 @@ def distribute(default: Union[TensorToMesh, MeshToTensor]):
         used to map tensors to a mesh or compose tensors from a mesh.

     Example:
-        with distribute(ShardTensorToMesh(mesh_device, dim=3)):
+        with distribute(shard_tensor_to_mesh_mapper(mesh_device, dim=3)):
             # Code here will use the default mapper
             result = ttnn.from_torch(torch_tensor)

         is equivalent to:

-        result = ttnn.from_torch(torch_tensor, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3))
+        result = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3))
     """
     _original_to_torch = ttnn.to_torch
     _original_from_torch = ttnn.from_torch
     try:
-        if isinstance(default, TensorToMesh):
+        if isinstance(default, ttnn.TensorToMesh):
             ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default)
-        elif isinstance(default, MeshToTensor):
+        elif isinstance(default, MeshToTensor) or isinstance(default, ttnn.CppMeshToTensor):
             ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default)
         else:
             raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.")
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index 409480605bb..c4c6933078a 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -194,47 +194,53 @@ def from_torch(
         if memory_config.shard_spec.mode == ttnn.ShardMode.LOGICAL:
             return ttnn.Tensor(tensor, dtype, device, layout, memory_config, tile)

+    if memory_config is not None:
+        if device is None:
+            raise RuntimeError("ttnn.from_torch: device must be specified when memory_config is specified")
+
+    if pad_value is not None:
+        if layout != ttnn.TILE_LAYOUT:
+            raise RuntimeError("ttnn.from_torch: layout must be TILE_LAYOUT when pad_value is specified")
+
     logical_shape = None
     padded_shape = None
+
     if dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b:
         if layout != ttnn.TILE_LAYOUT:
             raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!")
-        # Tilize tensor
+        # Tilize tensor, TODO: this is non-performant when done on host
         tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value, mesh_mapper=None)
         logical_shape = tensor.shape
         padded_shape = tensor.padded_shape
         tensor = tensor.reshape(tensor.padded_shape)
-        tensor = ttnn.to_torch(tensor)
-
-    if memory_config is not None:
-        if device is None:
-            raise RuntimeError("ttnn.from_torch: device must be specified when memory_config is specified")
+    else:
+        tensor = ttnn.Tensor(tensor, dtype)

-    if pad_value is not None:
-        if layout != ttnn.TILE_LAYOUT:
-            raise RuntimeError("ttnn.from_torch: layout must be TILE_LAYOUT when pad_value is specified")
+    strategy = {}
+    tilize_input = []

     if mesh_mapper:
-        shards = mesh_mapper.map(tensor)
-        if tile is not None:
-            tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config(), tile)
-        else:
-            tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config())
-    else:
-        if tile is not None:
-            tensor = ttnn.Tensor(tensor, dtype, {}, tile)
-        else:
-            tensor = ttnn.Tensor(tensor, dtype)
+        tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device)
+    tilize_input = ttnn.to_torch(tensor)
+
+    # TODO: find cleaner way of tilizing
+    if tile is not None:
+        tensor = ttnn.Tensor(tilize_input, dtype, strategy, tile)

     if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b):
         if pad_value is not None:
             tensor = tensor.pad_to_tile(pad_value)
+        if ttnn.is_tensor_storage_on_device(tensor):
+            # TODO: support tilizing non bfloat/float types on device tensors making this conversion unnecessary
+            tensor = ttnn.from_device(tensor, cq_id=cq_id)
         tensor = ttnn.to_layout(tensor, layout, device=device)

     if device is not None:
         if memory_config is None:
             memory_config = ttnn.DRAM_MEMORY_CONFIG
-        tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id)
+        # Handle sharding case which would have already output to a multidevice
+        if not ttnn.is_tensor_storage_on_device(tensor):
+            tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id)

     if logical_shape is not None and logical_shape != tensor.shape and mesh_mapper is None:
         tensor = ttnn.reshape(tensor, logical_shape, padded_shape)
@@ -269,7 +275,7 @@ def to_torch(
     dtype: Optional[torch.dtype] = None,
     *,
     torch_rank: Optional[int] = None,
-    mesh_composer: Optional[ttnn.MeshToTensor] = None,
+    mesh_composer: Optional[Union[ttnn.MeshToTensor, ttnn.CppMeshToTensor]] = None,
     device: Optional[ttnn.Device] = None,
     cq_id: Optional[int] = ttnn.DefaultQueueId,
 ) -> "torch.Tensor":
@@ -302,7 +308,10 @@ def to_torch(
         tensor = ttnn.from_device(tensor, cq_id=cq_id)

     if mesh_composer:
-        return mesh_composer.compose(tensor)
+        if isinstance(mesh_composer, ttnn.MeshToTensor):
+            return mesh_composer.compose(tensor)
+        else:
+            return ttnn.aggregate_tensor(tensor, mesh_composer).to_torch()

     if tensor.storage_type() == ttnn.DEVICE_STORAGE_TYPE:
         raise RuntimeError("ttnn.Tensor cannot be on device when converting to torch.Tensor!")
@@ -326,7 +335,6 @@ def to_torch(
         if tensor.shape[0] != 1:
             raise RuntimeError("ttnn: Unable to squeeze to desired rank!")
         tensor = tensor.squeeze(0)
-
     torch_tensor = TorchTensor(tensor)

     if dtype is not None:
@@ -616,9 +624,7 @@ def from_torch_and_dump(
             ttnn._ttnn.tensor.dump_tensor(cache_file_name, tensor, distributed_config)
             return tensor

-        if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh):
-            storage_type = f"_multi_device" if mesh_mapper else ""
-        elif mesh_mapper:
+        if mesh_mapper:
             storage_type = f"_multi_device_{device.get_num_devices()}"
         else:
             storage_type = ""
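
Taken together, the Python-side changes above swap the class-based mappers for the functional creators now re-exported from ttnn._ttnn.multi_device. The following is a minimal, illustrative sketch of the intended call pattern, not part of the diff: the mesh shape, tensor shape, and shard dimension are assumptions chosen for the example, and ConcatMeshToTensor is used as the composer since it remains exported while the legacy classes are phased out.

import torch
import ttnn

# Assumption: a 1x2 mesh is available on this host; pick whatever topology exists.
mesh_device = ttnn.open_mesh_device(mesh_shape=ttnn.MeshShape(1, 2))

torch_input = torch.rand(1, 1, 32, 64, dtype=torch.bfloat16)

# shard_tensor_to_mesh_mapper replaces the deprecated ShardTensorToMesh class.
tt_input = ttnn.from_torch(
    torch_input,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3),
)

# The Python ConcatMeshToTensor composer is still exported and accepted by ttnn.to_torch,
# so the shards can be gathered back along the same dimension they were split on.
torch_output = ttnn.to_torch(tt_input, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3))
assert torch_output.shape == torch_input.shape

ttnn.close_mesh_device(mesh_device)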