#15061: Expose replicate and 1D shard mappers #18720

Draft · 76 commits to merge into base: main

Changes from all commits

Commits (76)
1001e95
expose classes to python
jjiangTT Feb 7, 2025
c80d173
one type error left
jjiangTT Feb 8, 2025
8f59edc
move class definitions from from distributed_tensor.cpp to.hpp so the…
jjiangTT Feb 10, 2025
ff90ba9
fix naming errors, add tests, add imports - TODO, fix weird aliasing …
jjiangTT Feb 10, 2025
f4cb249
fix mesh device conflict, add aggregate/distribute and config pybinds…
jjiangTT Feb 14, 2025
8fc6e5f
add aggregate/distribute imports to init
jjiangTT Feb 14, 2025
f45b660
add configs to pybind
jjiangTT Feb 14, 2025
a5759f5
change test cases to use distribute/aggregate
jjiangTT Feb 14, 2025
cc14dd2
fix test mappers, convert to cpu_tensor
jjiangTT Feb 14, 2025
bd1b931
clean up imports, fix test cases and change them to use mapper/compos…
jjiangTT Feb 18, 2025
3f26cb2
remove python implementations
jjiangTT Feb 18, 2025
cb23266
fix rebase
jjiangTT Feb 18, 2025
b059826
clean up deprecated imports
jjiangTT Feb 18, 2025
24dead9
add shard2dconfig, concat2dconfig methods and map/compose constructors
jjiangTT Feb 19, 2025
381de5d
Replace none types, expose configs, fix tuple errors
jjiangTT Feb 19, 2025
54dd2d4
overload for concatmeshtotensor with meshdevice
jjiangTT Feb 19, 2025
d0678b3
remove extraneous comments
jjiangTT Feb 20, 2025
c2b9bc7
fix deviceconcat errors
jjiangTT Feb 20, 2025
a452991
add back distributed.py for now, clean up class overloads
jjiangTT Feb 20, 2025
24b703c
remove unused import
jjiangTT Feb 20, 2025
795e2b1
rearrange from_torch.py, start migrating cpp classes and testing inte…
jjiangTT Feb 20, 2025
5c160a9
expose classes to python
jjiangTT Feb 7, 2025
5db7735
one type error left
jjiangTT Feb 8, 2025
a21afeb
move class definitions from from distributed_tensor.cpp to.hpp so the…
jjiangTT Feb 10, 2025
935d2e5
fix naming errors, add tests, add imports - TODO, fix weird aliasing …
jjiangTT Feb 10, 2025
d02a0fb
fix mesh device conflict, add aggregate/distribute and config pybinds…
jjiangTT Feb 14, 2025
9afad34
add aggregate/distribute imports to init
jjiangTT Feb 14, 2025
1071396
add configs to pybind
jjiangTT Feb 14, 2025
7f54f90
change test cases to use distribute/aggregate
jjiangTT Feb 14, 2025
e9d21c5
fix test mappers, convert to cpu_tensor
jjiangTT Feb 14, 2025
670de83
clean up imports, fix test cases and change them to use mapper/compos…
jjiangTT Feb 18, 2025
1d1ff5a
remove python implementations
jjiangTT Feb 18, 2025
0b679db
fix rebase
jjiangTT Feb 18, 2025
c8feeae
add shard2dconfig, concat2dconfig methods and map/compose constructors
jjiangTT Feb 19, 2025
1d53fb9
Replace none types, expose configs, fix tuple errors
jjiangTT Feb 19, 2025
5a696a3
overload for concatmeshtotensor with meshdevice
jjiangTT Feb 19, 2025
bcf4508
remove extraneous comments
jjiangTT Feb 20, 2025
4c89683
fix deviceconcat errors
jjiangTT Feb 20, 2025
a6d2016
add back distributed.py for now, clean up class overloads
jjiangTT Feb 20, 2025
58b8d46
remove unused import
jjiangTT Feb 20, 2025
1d208e0
rearrange from_torch.py, start migrating cpp classes and testing inte…
jjiangTT Feb 20, 2025
b910b6d
interim work for supporting mappers
jjiangTT Feb 21, 2025
b116604
start trying to fix rebase errors
jjiangTT Feb 25, 2025
28cdd3b
fix rebase errors
jjiangTT Feb 25, 2025
058891e
fix last rebase errors, re-add borrowed support for aggregate_tensor,…
jjiangTT Feb 25, 2025
5dbc31e
add temporary debugging, re-add copyright header, add memoryconfig fo…
jjiangTT Feb 25, 2025
3257aaa
fix spec error
jjiangTT Feb 26, 2025
45c56e6
debugging prints for tilize, add switch back and move all classes bac…
jjiangTT Feb 26, 2025
06eacac
fix from_torch device, typing errors
jjiangTT Feb 27, 2025
7c83bcb
remove debug prints
jjiangTT Feb 28, 2025
e2c189f
reformat tilize, fix golden comparisons in testing, add direct_concat…
jjiangTT Feb 28, 2025
d3614d7
fix uint errors
jjiangTT Mar 4, 2025
cac1121
fix out of bounds error
jjiangTT Mar 4, 2025
befdcc5
actual fix with correct copy paste
jjiangTT Mar 4, 2025
a5c6847
make the switch, satisfy linker without dummy virtual function defini…
jjiangTT Mar 5, 2025
f4c994f
remove replicate distinction
jjiangTT Mar 5, 2025
08739b6
remove tensortomesh from distributed.py imports
jjiangTT Mar 5, 2025
a707c0b
remove duplicate meshtotensor imports
jjiangTT Mar 5, 2025
4ca9981
fix syntax error for shard
jjiangTT Mar 5, 2025
317f873
fix test syntax error
jjiangTT Mar 5, 2025
04dfba6
improved shape error message
jjiangTT Mar 5, 2025
c16ed91
syntax fix
jjiangTT Mar 5, 2025
b548526
rationalize composer check and method signature
jjiangTT Mar 5, 2025
2bc7f46
fix composer path
jjiangTT Mar 5, 2025
0ee4c22
fix memoryconfig error
jjiangTT Mar 5, 2025
954ef6c
cleanup
jjiangTT Mar 5, 2025
0f8d038
add back distributed.py since it has uses
jjiangTT Mar 6, 2025
e416e17
change llama_common based tests over
jjiangTT Mar 6, 2025
6820794
switch replicate
jjiangTT Mar 6, 2025
5244781
switch shardtensortomesh
jjiangTT Mar 6, 2025
29a10de
unsaved sharding switch
jjiangTT Mar 6, 2025
3ca0bbd
fix replacement errors
jjiangTT Mar 6, 2025
8f52467
fix more replace errors
jjiangTT Mar 6, 2025
cf9400f
fix replace errors x3
jjiangTT Mar 6, 2025
1a408dd
manual pre-commit
jjiangTT Mar 6, 2025
3746a28
add back distributed to imports, fix it
jjiangTT Mar 6, 2025
2 changes: 1 addition & 1 deletion models/common/rmsnorm.py
@@ -80,7 +80,7 @@ def __init__(
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=weight_memory_config,
cache_file_name=cache_name,
mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None,
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None,
)

if self.is_distributed:
4 changes: 2 additions & 2 deletions models/common/tests/test_rmsnorm.py
@@ -12,7 +12,7 @@
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

import ttnn
from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor
from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor

from models.common.rmsnorm import RMSNorm as TtRMSNorm
from models.utility_functions import (
@@ -130,7 +130,7 @@ def test_rmsnorm_multidevice(t3k_mesh_device, is_sharded, use_program_cache, res
device=t3k_mesh_device,
dtype=dtype,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
mesh_mapper=replicate_tensor_to_mesh_mapper(t3k_mesh_device),
)

tt_output = tt_model(tt_input)
4 changes: 2 additions & 2 deletions models/demos/falcon7b_common/tests/test_falcon_mlp.py
@@ -6,7 +6,7 @@
import torch
from loguru import logger
import ttnn
from ttnn import ShardTensorToMesh
from ttnn import shard_tensor_to_mesh_mapper
from models.demos.falcon7b_common.tt.falcon_mlp import TtFalconMLPDecode, TtFalconMLPPrefill
from models.demos.falcon7b_common.tt.model_config import get_model_config
from models.demos.falcon7b_common.tests.test_utils import load_hf_model, tt_from_torch, get_num_devices
@@ -79,7 +79,7 @@ def run_test_FalconMLP_inference(
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)

tt_out = tt_FalconMLP_model(tt_mlp_input)
24 changes: 12 additions & 12 deletions models/demos/falcon7b_common/tests/test_utils.py
@@ -4,7 +4,7 @@

import torch
import ttnn
from ttnn import ShardTensorToMesh, ReplicateTensorToMesh
from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper
from transformers import FalconForCausalLM
from models.utility_functions import tt_tensors_to_torch_tensors

@@ -20,14 +20,14 @@ def initialize_kv_cache(configuration, num_layers, batch_size, max_seq_len, mesh
dtype=ttnn.bfloat16,
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
tt_v_cache = tt_from_torch(
v_cache,
dtype=ttnn.bfloat16,
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
kv_cache += ((tt_k_cache, tt_v_cache),)
return kv_cache
@@ -106,7 +106,7 @@ def get_rand_falcon_inputs(
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)

if model_config["PREFILL_OPTIMIZED_MODE"] and seq_len in [2048, 128, 1024]:
@@ -121,7 +121,7 @@
dtype=ttnn.bfloat4_b,
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)
for attn_mask in attn_masks
]
@@ -131,7 +131,7 @@
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)

# Generate kvcache for each layer
@@ -145,14 +145,14 @@
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)
tt_v_cache = tt_from_torch(
tt_v_cache.unsqueeze(1),
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)
tt_layer_past += ((tt_k_cache, tt_v_cache),)

@@ -169,7 +169,7 @@ def get_rand_falcon_inputs(
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=2),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=2),
)

attention_mask_bool = torch.zeros(global_batch, 1, q_len, kv_len, dtype=bool)
@@ -200,7 +200,7 @@ def get_rand_falcon_inputs(
device=mesh_device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=model_config["ATTN_MASK_MEMCFG"],
mesh_mapper=ShardTensorToMesh(mesh_device, dim=device_shard_dim),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=device_shard_dim),
)
if not model_config["l1_sharded"]:
# Tilize attn masks
@@ -227,14 +227,14 @@
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)
tt_v_cache = tt_from_torch(
tt_v_cache.unsqueeze(1),
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ShardTensorToMesh(mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0),
)
tt_layer_past += ((tt_k_cache, tt_v_cache),)

8 changes: 4 additions & 4 deletions models/demos/falcon7b_common/tt/falcon_attention.py
@@ -9,7 +9,7 @@

from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid
import ttnn
from ttnn import ReplicateTensorToMesh
from ttnn import replicate_tensor_to_mesh_mapper

from models.utility_functions import (
nearest_32,
@@ -155,7 +155,7 @@ def __init__(
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

# optimized version can utilize single float value for softmax
@@ -175,7 +175,7 @@ def __init__(
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=self.model_config["ATTN_OPTIMIZED_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
self.model_config["ATTN_OUTPUT_TENSORS"][seq_len] = tt_tensors

@@ -553,7 +553,7 @@ def __init__(
dtype=model_config["DEFAULT_DTYPE"],
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

def forward(
4 changes: 2 additions & 2 deletions models/demos/falcon7b_common/tt/falcon_causallm.py
@@ -6,7 +6,7 @@

import torch
import ttnn
from ttnn import ReplicateTensorToMesh
from ttnn import replicate_tensor_to_mesh_mapper
from models.demos.falcon7b_common.tt.falcon_lm_head import falcon_lm_head_matmul_2d
from models.demos.falcon7b_common.tt.falcon_model import TtFalconModelShared
from models.demos.falcon7b_common.tt.model_utils import (
@@ -123,7 +123,7 @@ def __init__(
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=self.model_config["LM_HEAD_MM_INPUT_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)

self.lm_head_weights = get_weights_cached(
8 changes: 4 additions & 4 deletions models/demos/falcon7b_common/tt/falcon_mlp.py
@@ -4,7 +4,7 @@

import torch
import ttnn
from ttnn import ReplicateTensorToMesh
from ttnn import replicate_tensor_to_mesh_mapper
from models.demos.falcon7b_common.tt.model_utils import (
get_falcon_default_core_grid,
get_weights_cached,
@@ -176,7 +176,7 @@ def _load_mlp_padded_tensors(self):
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
mlp_padding_tensors[seq_len] = tt_padding
self.model_config["MLP_PREFILL_PADDING_TENSORS"] = mlp_padding_tensors
@@ -191,7 +191,7 @@ def _allocate_output_mlp_tensors(self):
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
self.model_config["MLP_OUTPUT_TENSORS"] = out_tt

@@ -344,7 +344,7 @@ def _load_mlp_padded_tensors(self):
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
self.model_config["MLP_DECODE_PADDING_TENSORS"] = tt_paddings

12 changes: 6 additions & 6 deletions models/demos/falcon7b_common/tt/falcon_model.py
@@ -7,7 +7,7 @@

import torch
import ttnn
from ttnn import ReplicateTensorToMesh, ShardTensorToMesh
from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper

from models.demos.falcon7b_common.tt.falcon_decoder import TtFalconDecoderLayer
from models.demos.falcon7b_common.tt.model_utils import get_weights_cached, layernorm
@@ -134,7 +134,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
device=self.mesh_device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=self.model_config["ATTN_MASK_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
for attention_mask_slice in attention_mask_
]
@@ -156,7 +156,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
device=self.mesh_device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=self.model_config["ATTN_MASK_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
# Repeat attn masks for all heads
tt_attention_mask = ttnn.repeat(
@@ -177,7 +177,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
layout=ttnn.ROW_MAJOR_LAYOUT,
device=self.mesh_device,
memory_config=self.model_config["INPUT_MEMCFG"],
mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0),
)
elif llm_mode == "decode":
assert batch_size % 32 == 0, "For decode, batch_size must be multiple of 32!"
@@ -210,7 +210,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
device=self.mesh_device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=self.model_config["ATTN_MASK_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device),
)
if not self.model_config["l1_sharded"]:
# Tilize attn masks
@@ -226,7 +226,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
layout=ttnn.ROW_MAJOR_LAYOUT,
device=self.mesh_device,
memory_config=self.model_config["INPUT_MEMCFG"],
mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1),
)
else:
raise NotImplementedError(f"Llm mode {llm_mode} is not supported! Must be one of prefill or decode.")
6 changes: 4 additions & 2 deletions models/demos/falcon7b_common/tt/model_utils.py
@@ -4,7 +4,7 @@

import torch
import ttnn
from ttnn import ReplicateTensorToMesh
from ttnn import replicate_tensor_to_mesh_mapper

from models.utility_functions import is_wormhole_b0

@@ -50,7 +50,9 @@ def preprocess_weights(weights_to_cache):
layout=tt_layout,
device=mesh_device,
memory_config=model_config[f"{weight_config_str}_MEMCFG"],
mesh_mapper=ReplicateTensorToMesh(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None,
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device)
if type(mesh_device) == ttnn.MeshDevice
else None,
cache_file_name=str(path),
preprocess=preprocess_weights,
)
@@ -15,7 +15,7 @@
##### TTNN imports #####
import ttnn
from ttnn import experimental as ttl
from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh
from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper
from models.utility_functions import skip_for_grayskull
from models.utility_functions import (
comp_pcc,
@@ -108,7 +108,7 @@ def test_llama_class_embedding_inference(
layout=layout,
device=mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
logger.info(f"TT Input tensor shape: {tt_input_tensor.shape}")

@@ -14,7 +14,7 @@
##### TTNN imports #####
import ttnn
from ttnn import experimental as ttl
from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh
from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper
from models.utility_functions import skip_for_grayskull
from models.utility_functions import (
comp_pcc,
@@ -106,7 +106,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
dtype=ttnn.bfloat16,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1),
mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1),
)
for _ in range(2)
]
@@ -170,15 +170,15 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
tt_full_text_mask = ttnn.from_torch(
full_text_mask_expand[b : b + 1],
device=mesh_device,
dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
tt_out = tt_model(
tt_tensor_x,
@@ -209,7 +209,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
tt_xattn_mask = ttnn.reshape(
tt_xattn_mask,
@@ -224,7 +224,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)
tt_full_text_mask = ttnn.reshape(
tt_full_text_mask,
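The tests above still gather per-device results back into a single torch tensor with the class-based `ConcatMeshToTensor` composer, which this PR leaves untouched. A short sketch of that read-back path, reusing `sharded`, `torch_weight`, and `mesh_device` from the earlier sketch; note that the `mesh_composer` keyword on `ttnn.to_torch` is assumed from existing ttnn usage and does not appear in this diff:

```python
import ttnn
from ttnn import ConcatMeshToTensor

# Concatenate the per-device shards back along dim 0 to recover the original shape.
torch_result = ttnn.to_torch(
    sharded,
    mesh_composer=ConcatMeshToTensor(mesh_device, dim=0),
)
assert torch_result.shape == torch_weight.shape
```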