Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter #2259

Merged
merged 50 commits into from
May 15, 2024

Conversation

tomnatan30
Copy link
Contributor

Add operand_batching_dims and start_indices_batching_dims attributes to stablehlo.gather. operand_batching_dims refers to the dimensions of the operand that are treated as batch. start_indices_batching_dims refers to the dimensions of the start_indices that are treated as batch. The corresponding dimension sizes must be equal. The semantics is equivalent to concatenating the outputs of the gather with each slices of operand and start_indices.

Similarly, add input_batching_dims and scatter_indices_batching_dims attributes to stablehlo.scatter. input_batching_dims refers to the dimensions of each tensor in inputs that are treated as batch. scatter_indices_batching_dims refers to the dimensions of the scatter_indices that are treated as batch.

See #2084 for more information

@tomnatan30 tomnatan30 changed the title [stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter [stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter Apr 26, 2024
@tomnatan30 tomnatan30 closed this Apr 26, 2024
@tomnatan30 tomnatan30 reopened this Apr 26, 2024
Copy link
Member

@ghpvnist ghpvnist left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the thorough implementation! This is round 1 of X of my reviews :)
I haven't reviewed ops_stablehlo.mlir and verify_scatter.mlir because my brain is fried after reviewing this, and I expect the test names to change with the constraint number bump so I'll leave the review until the remaining feedbacks are addressed.

docs/spec.md Show resolved Hide resolved
docs/spec.md Show resolved Hide resolved
stablehlo/dialect/StablehloOps.td Outdated Show resolved Hide resolved
stablehlo/tests/interpret/gather.mlir Show resolved Hide resolved
stablehlo/dialect/TypeInference.cpp Outdated Show resolved Hide resolved
stablehlo/dialect/TypeInference.cpp Outdated Show resolved Hide resolved
stablehlo/dialect/TypeInference.cpp Outdated Show resolved Hide resolved
stablehlo/dialect/TypeInference.cpp Outdated Show resolved Hide resolved
stablehlo/dialect/TypeInference.cpp Outdated Show resolved Hide resolved
stablehlo/reference/Ops.cpp Show resolved Hide resolved
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 16, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to google-deepmind/tf2jax that referenced this pull request Sep 21, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 21, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Sep 21, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 21, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to google-deepmind/tf2jax that referenced this pull request Sep 23, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 23, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 23, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to google-deepmind/tf2jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to google-deepmind/tf2jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to google-deepmind/tf2jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 24, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Sep 25, 2024
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 678649138
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants