
Missing sharding specs when annotating sharding over views #8662

Open · rpsilva-aws opened this issue Feb 1, 2025 · 6 comments
Labels: bug (Something isn't working), SPMD / Distributed
rpsilva-aws (Collaborator) commented Feb 1, 2025

🐛 Bug

The HLO instruction for the custom sharding call is missing the sharding specs, leading to has_sharding failures on XLA:

RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: HloOptimization: error condition !(status.ok()): 13RET_CHECK failure (external/xla/xla/service/sharding_propagation.cc:1464) instruction->has_sharding() Sharding instruction must have a sharding attribute

This issue was earlier identified in #8427, but with manual sharding. @JackCaoG did some investigation, but we haven't fully root-caused it yet. The issue can be minimally reproduced with mark_sharding as well, and we observe the same problem when adding any custom sharding prior to the input layer normalization for Llama 3.

To Reproduce

  1. Similar underlying behavior to the embedding:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable SPMD mode before annotating shardings

device_ids = list(range(32))
mesh = xs.Mesh(device_ids, (1, 1, 32), ('data', 'other', 'model'))
device = xm.xla_device()

indices = torch.zeros((1, 4096), dtype=torch.int64).to(device)  # p0.1 shape
weight = torch.randn((128256, 4096), dtype=torch.float32).to(device)  # p1.3 shape
xs.mark_sharding(weight, mesh, ("model", None))

r0 = torch.index_select(weight, 0, indices.view(-1)).view(1, 4096, 4096)  # or reshape

xs.mark_sharding(r0, mesh, ("data", "model", None))

r0 = r0.view(1, 4096, 4096)  # or reshape

print(r0)

This lowers to the following HLO, where %custom-call.7 is missing its sharding attribute:
HloModule IrToHlo.11, entry_computation_layout={(s64[1,4096]{1,0}, f32[128256,4096]{1,0})->(f32[1,4096,4096]{2,1,0})}

ENTRY %IrToHlo.11 (p0.1: s64[1,4096], p1.3: f32[128256,4096]) -> (f32[1,4096,4096]) {
  %p1.3 = f32[128256,4096]{1,0} parameter(1), sharding={devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %p0.1 = s64[1,4096]{1,0} parameter(0), sharding={replicated}
  %reshape.2 = s64[4096]{0} reshape(s64[1,4096]{1,0} %p0.1)
  %convert.4 = u32[4096]{0} convert(s64[4096]{0} %reshape.2)
  %gather.5 = f32[4096,4096]{1,0} gather(f32[128256,4096]{1,0} %p1.3, u32[4096]{0} %convert.4), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,4096}
  %reshape.6 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %gather.5)
  %custom-call.7 = f32[1,4096,4096]{2,1,0} custom-call(f32[1,4096,4096]{2,1,0} %reshape.6), custom_call_target="Sharding"
  %reshape.8 = f32[4096,4096]{1,0} reshape(f32[1,4096,4096]{2,1,0} %custom-call.7)
  %reshape.9 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %reshape.8)
  ROOT %tuple.10 = (f32[1,4096,4096]{2,1,0}) tuple(f32[1,4096,4096]{2,1,0} %reshape.9)
}
  2. The manual sharding repro from #8427 (flash_attention: support also cross attention).

Expected behavior

We expect the appropriate sharding spec to be present on the custom sharding call; for (1) above, it should include:

sharding={devices=[1,32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.6
rpsilva-aws (Collaborator, Author) commented:

cc: @miladm, similar to #8427.

JackCaoG (Collaborator) commented Feb 1, 2025

I thought @yaochengji already fixed this.

rpsilva-aws (Collaborator, Author) commented Feb 1, 2025

I pointed to #8427 as possibly being the same underlying issue, but that one occurred without functionalization. The failure also happens with smaller cases, though I don't see why it should cause an issue on XLA (let me know if this is expected, since this minimal repro may not capture a real use case). Ultimately, we're trying to reconcile different sharding specs before and after the reshape/view, and we see this with Llama 3 when trying to add custom sharding at different points (e.g. sequence parallelism):

HloModule IrToHlo.7, entry_computation_layout={(f32[128256,4096]{1,0})->(f32[1,128256,4096]{2,1,0})}

ENTRY %IrToHlo.7 (p0.1: f32[128256,4096]) -> (f32[1,128256,4096]) {
  %p0.1 = f32[128256,4096]{1,0} parameter(0), sharding={devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2 = f32[1,128256,4096]{2,1,0} reshape(f32[128256,4096]{1,0} %p0.1)
  %custom-call.3 = f32[1,128256,4096]{2,1,0} custom-call(f32[1,128256,4096]{2,1,0} %reshape.2), custom_call_target="Sharding"
  %reshape.4 = f32[128256,4096]{1,0} reshape(f32[1,128256,4096]{2,1,0} %custom-call.3)
  %reshape.5 = f32[1,128256,4096]{2,1,0} reshape(f32[128256,4096]{1,0} %reshape.4)
  ROOT %tuple.6 = (f32[1,128256,4096]{2,1,0}) tuple(f32[1,128256,4096]{2,1,0} %reshape.5)
}

followed by a reshape on a subsequent custom sharding:

    xs.mark_sharding(weight, mesh, ("model", None))
    r0 = weight.reshape(1, -1, 4096)
    xs.mark_sharding(r0, mesh, ("data", "model", None))
    r0 = r0.reshape(1, -1, 4096)
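
For completeness, a self-contained version of this snippet (a sketch; it assumes the same 32-device ('data', 'other', 'model') mesh as the first repro) would be:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable SPMD mode before annotating shardings

mesh = xs.Mesh(list(range(32)), (1, 1, 32), ('data', 'other', 'model'))
device = xm.xla_device()

weight = torch.randn((128256, 4096), dtype=torch.float32).to(device)
xs.mark_sharding(weight, mesh, ("model", None))

r0 = weight.reshape(1, -1, 4096)
xs.mark_sharding(r0, mesh, ("data", "model", None))
r0 = r0.reshape(1, -1, 4096)

print(r0)  # compiling this graph yields the IrToHlo.7 module shown above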

miladm added the bug (Something isn't working) and SPMD / Distributed labels on Feb 3, 2025
miladm (Collaborator) commented Feb 3, 2025

Thanks for submitting this issue. #8427 addressed an issue that was only visible without functionalization. A few clarifying questions while @yaochengji gets a chance to look into this issue:

  • I didn't take a close look into the mesh + sharding spec yet. Does this config work on the equivalent JAX code? (A rough JAX sketch follows this list for reference.)
  • Does your implementation use XLA_DISABLE_FUNCTIONALIZATION=0?
  • We have an implementation of Llama3 on tpu-recipe with SPMD support. I wonder if you've had a chance to look into it.
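
For reference (not from this thread), one possible JAX counterpart of the reshape-plus-sharding pattern above might look like the sketch below; the helper name reshard_via_reshape and the 32-device setup (e.g. via XLA_FLAGS="--xla_force_host_platform_device_count=32" on CPU) are assumptions:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 32 devices are available.
mesh = Mesh(np.array(jax.devices()[:32]).reshape(1, 1, 32), ('data', 'other', 'model'))

@jax.jit
def reshard_via_reshape(weight):
    r0 = weight.reshape(1, -1, 4096)
    # Counterpart of the second mark_sharding call in the repro.
    r0 = jax.lax.with_sharding_constraint(
        r0, NamedSharding(mesh, P('data', 'model', None)))
    return r0.reshape(1, -1, 4096)

weight = jax.device_put(
    jnp.zeros((128256, 4096), dtype=jnp.float32),
    NamedSharding(mesh, P('model', None)))
print(reshard_via_reshape(weight).sharding)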

cc @lsy323 @yaochengji for viz

miladm assigned ysiraichi and unassigned yaochengji on Feb 3, 2025
rpsilva-aws (Collaborator, Author) commented Feb 3, 2025

Thanks Milad.

> I didn't take a close look into the mesh + sharding spec yet. Does this config work on the equivalent JAX code?

I will try to reproduce in the meantime.

> Does your implementation use XLA_DISABLE_FUNCTIONALIZATION=0?

Yes.

> We have an implementation of Llama3 on tpu-recipe with SPMD support. I wonder if you've had a chance to look into it.

We have a working Llama3 with SPMD, but this issue arises when we need to reconcile different sharding specifications across views/reshapes. For instance, if we use sequence parallelism to shard the sequence dimension at any point (e.g. prior to the first layer norm), it becomes an issue whenever the hidden tensors need to be reshaped later on.
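
A minimal sketch of that pattern (hypothetical shapes, not the actual Llama 3 code; it reuses the same 32-device mesh as the repro above):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
mesh = xs.Mesh(list(range(32)), (1, 1, 32), ('data', 'other', 'model'))
device = xm.xla_device()

hidden = torch.randn((1, 4096, 4096), dtype=torch.float32).to(device)  # (batch, seq, hidden)
# Sequence parallelism: shard the sequence dimension before the first layer norm.
xs.mark_sharding(hidden, mesh, ("data", "model", None))
normed = torch.nn.functional.layer_norm(hidden, (4096,))
# Any subsequent view/reshape of the sharded activations hits the same failure mode.
normed = normed.view(-1, 4096)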

rpsilva-aws (Collaborator, Author) commented:

It actually only reproduces with the Neuron PJRT backend, not with TPU/CPU. It could be due to an outdated XLA version.

Hence, I'll take this item and see whether any backwards-compatible change is needed here, or, ideally, whether we can properly sort it out outside of torch-xla.

rpsilva-aws assigned themselves and unassigned ysiraichi on Feb 6, 2025