From 954165331b554f96bceb3b1d8b1affb6e259c964 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 5 Feb 2025 04:06:20 -0800 Subject: [PATCH] #sdy Add support for all-to-all in `-sdy-reshard-to-collectives` and make further improvements. Improvements: 1. Introduce a CollectiveInserter class that holds and manipulates the current state of the transformation. 2. Simplify the handling of sub-axes by aligning them between the input and output shardings, so that we can treat them as full axes and use them as keys in hash maps. 3. Make all-slice and all-to-all insertion greedy, i.e., we slice axes if they are available and all-to-all axes if they are the suffix of the source dimension, regardless of whether they will end up in the right position in the target dimension. 4. All-slice an axis in a dimension other than its dimension in the output sharding, if we can then all-to-all the sliced axis, together with an axis that was already in that source dimension, to the same target dimension. A few notes: 1. Greedy all-slice and all-to-all are only possible if the target dimension has the capacity; we ignore this constraint in this CL and will address it in a follow-up. 2. This pass is still missing support for collective-permute, which is crucial for getting optimal collectives in cases where axes are reordered or replaced with others in the same dimension. Added some use cases to the test. 
PiperOrigin-RevId: 723451739 --- shardy/dialect/sdy/ir/attrs.td | 40 +- shardy/dialect/sdy/ir/dialect.cc | 77 +- shardy/dialect/sdy/ir/dialect_test.cc | 104 +-- .../export/reshard_to_collectives.cc | 769 ++++++++++++------ .../export/test/reshard_to_collectives.mlir | 264 ++++-- 5 files changed, 802 insertions(+), 452 deletions(-) diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td index 8c1dc729..c4ef24a2 100644 --- a/shardy/dialect/sdy/ir/attrs.td +++ b/shardy/dialect/sdy/ir/attrs.td @@ -362,25 +362,25 @@ def Sdy_AxisRef : AttrDef { // "a":(1)3, "a":(2)3 -> false bool canCoexist(AxisRefAttr other) const; - // Returns the largest prefix of this axis that overlaps with `other`, or - // `std::nullopt` if the prefix does not exist. + // If this axis or sub-axis overlaps with `other`, returns that overlapping + // axis or sub-axis, otherwise returns `std::nullopt`. // // If this axis and `other` can't coexist, returns `std::nullopt` (see // AxisRefAttr::canCoexist). // // For example: - // "a", "a" -> "a" - // "a":(2)2, "a" -> "a":(2)2 + // "a", "a":(2)2 -> "a":(2)2 // "a":(2)2, "a":(2)2 -> "a":(2)2 + // "a":(1)4, "a":(2)4 -> "a":(2)2 + // "a":(2)4, "a":(1)4 -> "a":(2)2 // "a":(1)4, "a":(1)2 -> "a":(1)2 - // "a":(2)8, "a":(1)4 -> "a":(2)2 - // "a", "b" -> std::nullopt - // "a":(2)2, "b" -> std::nullopt - // "a":(1)4, "a":(2)4 -> std::nullopt + // "a":(2)8, "a":(4)2 -> "a":(4)2 + // "a":(1)4, "a":(4)2 -> std::nullopt + // "a":(1)2, "a":(4)2 -> std::nullopt + // "a":(1)4, "b":(2)4 -> std::nullopt // "a":(1)2, "a":(1)3 -> std::nullopt // "a":(3)2, "a":(2)3 -> std::nullopt - std::optional getPrefixWithOverlap( - AxisRefAttr other, MeshAttr mesh) const; + std::optional getOverlap(AxisRefAttr other) const; // If there is no overlap between this axis and `other`, return this axis. 
// Otherwise, return the largest prefix of this axis by removing the @@ -443,26 +443,6 @@ def Sdy_AxisRef : AttrDef { // "a":(1)2, "a":(2)4 -> std::nullopt std::optional getGreatestCommonPrefix(AxisRefAttr other) const; - // Removes the common prefix of this axis and `other` from this axis. If the - // two axes do not have common prefix or `other` is greater or equal to this - // axis, return `std::nullopt`. - // - // If this axis and `other` can't coexist, returns `std::nullopt` (see - // AxisRefAttr::canCoexist). - // - // For example: - // "a", "a":(1)4 -> "a":(4)2 (size("a") == 8) - // "a":(1)4, "a":(1)2 -> "a":(2)2 - // "a":(2)8, "a":(2)4 -> "a":(8)2 - // "a", "b" -> std::nullopt - // "a", "a" -> std::nullopt - // "a":(1)4, "a" -> std::nullopt - // "a":(2)4, "a":(2)8 -> std::nullopt - // "a":(1)2, "a":(2)4 -> std::nullopt - // "a":(1)2, "a":(1)3 -> std::nullopt - std::optional removeCommonPrefix( - AxisRefAttr prefix, MeshAttr mesh) const; - // Returns whether this axis-ref can be merged with `other`, i.e., they are // consecutive sub-axes of the same full axis and this sub-axis is major to // `other`. diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc index 01ec5019..a912576e 100644 --- a/shardy/dialect/sdy/ir/dialect.cc +++ b/shardy/dialect/sdy/ir/dialect.cc @@ -436,6 +436,22 @@ bool AxisRefAttr::overlaps(AxisRefAttr other) const { otherSubAxisInfo.getPreSize() < thisSubAxisInfo.getNextPreSize(); } +namespace { + +bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize, + int64_t minNextPreSize, int64_t maxNextPreSize) { + if (minNextPreSize > maxPreSize) { + // Sub-axes overlap, check if overlapping and non-overlapping parts are + // valid. + return minNextPreSize % maxPreSize == 0 && maxPreSize % minPreSize == 0 && + maxNextPreSize % minNextPreSize == 0; + } + // Sub-axes don't overlap, check if the gap is valid. 
+ return maxPreSize % minNextPreSize == 0; +} + +} // namespace + bool AxisRefAttr::canCoexist(AxisRefAttr other) const { if (getName() != other.getName()) { return true; @@ -457,31 +473,46 @@ bool AxisRefAttr::canCoexist(AxisRefAttr other) const { auto [minNextPreSize, maxNextPreSize] = std::minmax(thisNextPreSize, otherNextPreSize); - if (minNextPreSize > maxPreSize) { - // Sub-axes overlap, check if overlapping and non-overlapping parts are - // valid. - return minNextPreSize % maxPreSize == 0 && maxPreSize % minPreSize == 0 && - maxNextPreSize % minNextPreSize == 0; - } - // Sub-axes don't overlap, check if the gap is valid. - return maxPreSize % minNextPreSize == 0; + return canSubAxesCoexist(minPreSize, maxPreSize, minNextPreSize, + maxNextPreSize); } -std::optional AxisRefAttr::getPrefixWithOverlap( - AxisRefAttr other, MeshAttr mesh) const { - int64_t thisPreSize = getSubAxisPreSize(); - if (!canCoexist(other) || !overlaps(other) || - other.getSubAxisPreSize() > thisPreSize) { +std::optional AxisRefAttr::getOverlap(AxisRefAttr other) const { + if (other.getName() != getName()) { return std::nullopt; } - if (other.contains(*this)) { + + SubAxisInfoAttr thisSubAxisInfo = getSubAxisInfo(); + SubAxisInfoAttr otherSubAxisInfo = other.getSubAxisInfo(); + + if (!thisSubAxisInfo) { + // This is a full axis. + return other; + } + + if (!otherSubAxisInfo) { + // Other is a full axis. 
return *this; } - int64_t thisNextPreSize = getNextPreSizeOrFullSize(mesh); - int64_t otherNextPreSize = other.getNextPreSizeOrFullSize(mesh); - return AxisRefAttr::get( - getContext(), getName(), thisPreSize, - std::min(thisNextPreSize, otherNextPreSize) / thisPreSize); + + int64_t thisPreSize = thisSubAxisInfo.getPreSize(); + int64_t otherPreSize = otherSubAxisInfo.getPreSize(); + int64_t thisNextPreSize = thisSubAxisInfo.getNextPreSize(); + int64_t otherNextPreSize = otherSubAxisInfo.getNextPreSize(); + + auto [minPreSize, maxPreSize] = std::minmax(thisPreSize, otherPreSize); + auto [minNextPreSize, maxNextPreSize] = + std::minmax(thisNextPreSize, otherNextPreSize); + + if (minNextPreSize <= maxPreSize || + !canSubAxesCoexist(minPreSize, maxPreSize, minNextPreSize, + maxNextPreSize)) { + // No overlap or can't co-exist. + return std::nullopt; + } + + return AxisRefAttr::get(getContext(), getName(), maxPreSize, + minNextPreSize / maxPreSize); } std::optional AxisRefAttr::getPrefixWithoutOverlap( @@ -556,14 +587,6 @@ std::optional AxisRefAttr::getGreatestCommonPrefix( return std::nullopt; } -std::optional AxisRefAttr::removeCommonPrefix( - AxisRefAttr prefix, MeshAttr mesh) const { - if (!prefix.strictPrefixOf(*this)) { - return std::nullopt; - } - return getSuffixWithoutOverlap(prefix, mesh); -} - //===----------------------------------------------------------------------===// // DimensionShardingAttr //===----------------------------------------------------------------------===// diff --git a/shardy/dialect/sdy/ir/dialect_test.cc b/shardy/dialect/sdy/ir/dialect_test.cc index cde03344..ffbf8de2 100644 --- a/shardy/dialect/sdy/ir/dialect_test.cc +++ b/shardy/dialect/sdy/ir/dialect_test.cc @@ -252,57 +252,47 @@ TEST_F(DialectTest, AxisRefAttrCompare) { compare(createSubAxis("x", 1, 4), createSubAxis("x", 2, 2)); } -TEST_F(DialectTest, AxisRefAttrGetPrefixWithOverlap) { - auto mesh = MeshAttr::get(&context, {MeshAxisAttr::get(&context, "x", 16), - 
MeshAxisAttr::get(&context, "y", 4)}); - auto samePrefix = [&](AxisRefAttr a, AxisRefAttr b) { - AxisRefAttr smaller = std::min(a, b); - EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), smaller); - EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), smaller); +TEST_F(DialectTest, AxisRefAttrGetOverlap) { + auto contained = [](AxisRefAttr small, AxisRefAttr large) { + EXPECT_TRUE(large.contains(small)); + EXPECT_EQ(large.getOverlap(small), small); + EXPECT_EQ(small.getOverlap(large), small); }; - samePrefix(createAxis("x"), createAxis("x")); - samePrefix(createSubAxis("x", 2, 2), createSubAxis("x", 2, 2)); - samePrefix(createSubAxis("x", 1, 4), createSubAxis("x", 1, 2)); - - // "x":(2)4 and "x" - EXPECT_EQ( - createSubAxis("x", 2, 4).getPrefixWithOverlap(createAxis("x"), mesh), - createSubAxis("x", 2, 4)); - EXPECT_EQ( - createAxis("x").getPrefixWithOverlap(createSubAxis("x", 2, 4), mesh), - std::nullopt); - - // "x":(2)4 and "x":(1)4 - EXPECT_EQ(createSubAxis("x", 2, 4).getPrefixWithOverlap( - createSubAxis("x", 1, 4), mesh), - createSubAxis("x", 2, 2)); - EXPECT_EQ(createSubAxis("x", 1, 4).getPrefixWithOverlap( - createSubAxis("x", 2, 4), mesh), - std::nullopt); - - // "x"(4)2 and "x":(2)8 - EXPECT_EQ(createSubAxis("x", 4, 2).getPrefixWithOverlap( - createSubAxis("x", 2, 8), mesh), - createSubAxis("x", 4, 2)); - EXPECT_EQ(createSubAxis("x", 2, 8).getPrefixWithOverlap( - createSubAxis("x", 4, 2), mesh), - std::nullopt); + contained(createAxis("x"), createAxis("x")); + contained(createSubAxis("x", 1, 4), createAxis("x")); + contained(createSubAxis("x", 4, 8), createAxis("x")); + contained(createSubAxis("x", 2, 2), createSubAxis("x", 2, 2)); + contained(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4)); + contained(createSubAxis("x", 2, 2), createSubAxis("x", 1, 4)); + contained(createSubAxis("x", 2, 2), createSubAxis("x", 1, 8)); + + auto overlaps = [](AxisRefAttr a, AxisRefAttr b, AxisRefAttr expected) { + EXPECT_EQ(a.getOverlap(b), expected); + EXPECT_EQ(b.getOverlap(a), 
expected); + }; + overlaps(createSubAxis("x", 1, 4), createSubAxis("x", 2, 4), + createSubAxis("x", 2, 2)); + overlaps(createSubAxis("x", 4, 4), createSubAxis("x", 2, 4), + createSubAxis("x", 4, 2)); - auto checkNoOverlap = [&](AxisRefAttr a, AxisRefAttr b) { - EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), std::nullopt); - EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), std::nullopt); + auto checkNoOverlap = [](AxisRefAttr a, AxisRefAttr b) { + EXPECT_FALSE(a.overlaps(b)); + EXPECT_EQ(a.getOverlap(b), std::nullopt); + EXPECT_EQ(b.getOverlap(a), std::nullopt); }; checkNoOverlap(createAxis("x"), createAxis("y")); checkNoOverlap(createAxis("x"), createSubAxis("y", 1, 2)); checkNoOverlap(createSubAxis("x", 1, 4), createSubAxis("x", 4, 2)); checkNoOverlap(createSubAxis("x", 1, 2), createSubAxis("x", 4, 2)); - auto checkCannotCoexist = [&](AxisRefAttr a, AxisRefAttr b) { - EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), std::nullopt); - EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), std::nullopt); + auto checkCannotCoexist = [](AxisRefAttr a, AxisRefAttr b) { + EXPECT_FALSE(a.canCoexist(b)); + EXPECT_EQ(a.getOverlap(b), std::nullopt); + EXPECT_EQ(b.getOverlap(a), std::nullopt); }; - checkCannotCoexist(createSubAxis("x", 1, 2), createSubAxis("x", 1, 3)); - checkCannotCoexist(createSubAxis("x", 3, 2), createSubAxis("x", 2, 3)); + checkCannotCoexist(createSubAxis("x", 1, 2), createSubAxis("x", 3, 2)); + checkCannotCoexist(createSubAxis("x", 1, 3), createSubAxis("x", 2, 3)); + checkCannotCoexist(createSubAxis("x", 2, 3), createSubAxis("x", 3, 2)); } // The test cases are the same as DialectTest.AxisRefAttrOverlaps. 
@@ -453,36 +443,6 @@ TEST_F(DialectTest, AxisRefAttrGetGreatestCommonPrefix) { prefix(createSubAxis("x", 2, 4), createSubAxis("x", 2, 8)); } -TEST_F(DialectTest, AxisRefAttrRemoveCommonPrefix) { - auto mesh = MeshAttr::get(&context, {MeshAxisAttr::get(&context, "x", 16), - MeshAxisAttr::get(&context, "y", 4)}); - auto isNotPrefix = [&](AxisRefAttr a, AxisRefAttr b) { - EXPECT_EQ(a.removeCommonPrefix(b, mesh), std::nullopt); - EXPECT_EQ(b.removeCommonPrefix(a, mesh), std::nullopt); - }; - isNotPrefix(createAxis("x"), createAxis("y")); - isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("y", 1, 2)); - isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("x", 2, 4)); - isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("x", 1, 3)); - - auto equals = [&](AxisRefAttr a) { - EXPECT_EQ(a.removeCommonPrefix(a, mesh), std::nullopt); - }; - equals(createAxis("x")); - equals(createSubAxis("x", 2, 4)); - - auto prefix = [&](AxisRefAttr small, AxisRefAttr large, - AxisRefAttr expected) { - EXPECT_EQ(large.removeCommonPrefix(small, mesh), expected); - EXPECT_EQ(small.removeCommonPrefix(large, mesh), std::nullopt); - }; - prefix(createSubAxis("x", 1, 4), createAxis("x"), createSubAxis("x", 4, 4)); - prefix(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4), - createSubAxis("x", 2, 2)); - prefix(createSubAxis("x", 2, 4), createSubAxis("x", 2, 8), - createSubAxis("x", 8, 2)); -} - TEST_F(DialectTest, TensorShardingAttrCanShardOrReplicate) { TensorShardingAttr sharding = createTensorSharding( {createDimSharding({createAxis("x"), createSubAxis("z", 2, 2)}, diff --git a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc index 955f0930..19ac62eb 100644 --- a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc +++ b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc @@ -15,17 +15,20 @@ limitations under the License. 
#include #include +#include #include #include #include // IWYU pragma: keep #include -#include #include +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep +#include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep @@ -46,36 +49,13 @@ namespace { using OptionalAxisRef = std::optional; -// We use an std::list so we can pop from the front and the back and with an +using AxesPerDim = SmallVector>; + +// We use an std::list so we can pop from the front, back, and with a specific // iterator at constant time. -// TODO(tomnatan): Consider using AxisListRef instead of std::list once it can -// also replace the first axis in the list with a different sub-axis. using AxisList = std::list; -// We use an std::set so sub-axes are ordered by their pre-size and size, and -// we can use set::lower_bound to find the first overlapping axis (see -// getFirstOverlapping). -using AvailableAxes = std::set; - -// Removes the common prefix of both `first` and `second`. -void removeCommonPrefix(AxisList& first, AxisList& second, MeshAttr mesh) { - while (!first.empty() && !second.empty() && first.front() == second.front()) { - first.pop_front(); - second.pop_front(); - } - if (first.empty() || second.empty()) { - return; - } - if (OptionalAxisRef suffix = - first.front().removeCommonPrefix(second.front(), mesh)) { - first.front() = *suffix; - second.pop_front(); - } else if (OptionalAxisRef suffix = - second.front().removeCommonPrefix(first.front(), mesh)) { - second.front() = *suffix; - first.pop_front(); - } -} +using AxisRefToDimMap = llvm::SmallDenseMap; // Returns a vector of `InnerAxisList` per dimension from the given `sharding`. 
template @@ -88,304 +68,555 @@ SmallVector getAxesPerDim(TensorShardingAttr sharding) { return axesPerDim; } -AvailableAxes::const_iterator getPrevOrEnd(AvailableAxes::iterator it, - const AvailableAxes& availableAxes) { - return it == availableAxes.begin() ? availableAxes.end() : std::prev(it); -} - -// Returns an iterator to the first axis in `availableAxes` that overlaps with -// `axis`, or `availableAxes.end()` if there is no such axis. -AvailableAxes::iterator getFirstOverlapping( - AxisRefAttr axis, const AvailableAxes& availableAxes) { - if (availableAxes.empty()) { - return availableAxes.end(); +// Returns an iterator to the first axis in `orderedAxes` that overlaps with +// `axis`, or `orderedAxes.end()` if there is no such axis. +ArrayRef::iterator getFirstOverlapping( + AxisRefAttr axis, ArrayRef orderedAxes) { + if (orderedAxes.empty()) { + return orderedAxes.end(); } - auto afterIt = availableAxes.lower_bound(axis); - auto beforeIt = getPrevOrEnd(afterIt, availableAxes); + auto* afterIt = llvm::lower_bound(orderedAxes, axis); // If there is at least one overlapping axis, the first one is necessarily - // `afterIt` or `beforeIt`. + // `afterIt` or `beforeIt = std::prev(afterIt)`. // // Proof: - // Let `axis` be A and the first overlapping axis in `availableAxes` be B. + // Given the definition of `lower_bound`, we have `beforeIt < A <= afterIt`, + // where A is `axis`. // - // Note that there can't be two overlapping available axes. `lower_bound` - // returns the first available axis greater or equal to A. - // - // * If `B >= A`, then there can't be another available axis C such that - // `A <= C < B` since it would have to be overlapping with A and thus the - // first overlapping axis instead of B. Therefore, `lower_bound` will - // return B. - // * If `B < A`, then there can't be another available axis C such that - // `B < C < A` since B and C can't overlap. Therefore, `lower_bound` will - // return the axis after B, which doesn't overlap with A. 
- - if (beforeIt != availableAxes.end() && beforeIt->overlaps(axis)) { - return beforeIt; + // - For any entry B with `B < beforeIt < A`, B and `beforeIt` cannot overlap. + // Thus `beforeIt` isolates A and B such that they cannot overlap. + // - For any entry C with `A <= afterIt < C`, if A and C overlap, then A and + // `afterIt` must overlap as well. + + if (afterIt != orderedAxes.begin() && std::prev(afterIt)->overlaps(axis)) { + return std::prev(afterIt); } - if (afterIt != availableAxes.end() && afterIt->overlaps(axis)) { + if (afterIt != orderedAxes.end() && afterIt->overlaps(axis)) { return afterIt; } - return availableAxes.end(); + return orderedAxes.end(); } -// Removes `availableAxis` from `availableAxes` and adds the prefix and suffix -// of `availableAxis` that don't overlap with `overlap` back to `availableAxes`. -// -// We assume that `availableAxis` overlaps with `overlap`. -void removeOverlapFromAvailable(AxisRefAttr availableAxis, AxisRefAttr overlap, - AvailableAxes& availableAxes, MeshAttr mesh) { - availableAxes.erase(availableAxis); - if (OptionalAxisRef prefix = availableAxis.getPrefixWithoutOverlap(overlap)) { - availableAxes.insert(*prefix); +// Returns a map from `AxisRefAttr` to the dimension in `axesPerDim` that this +// axis appears. +AxisRefToDimMap getAxisRefToDimMap(ArrayRef axesPerDim) { + AxisRefToDimMap result; + for (auto [dim, axes] : llvm::enumerate(axesPerDim)) { + for (AxisRefAttr axis : axes) { + result.try_emplace(axis, dim); + } } - if (OptionalAxisRef suffix = - availableAxis.getSuffixWithoutOverlap(overlap, mesh)) { - availableAxes.insert(*suffix); + return result; +} + +SmallVector getOrderedAxes(ArrayRef axesPerDim) { + SmallVector result; + for (const AxisList& axes : axesPerDim) { + result.append(axes.begin(), axes.end()); } + llvm::sort(result); + return result; } -// Adds `axis` to `availableAxes` and merges it with sub-axes in -// `availableAxes` that can be merged with `axis`. 
-// -// We assume that `axis` doesn't overlap with any axis in `availableAxes`. -void addAvailableAxis(AxisRefAttr axis, AvailableAxes& availableAxes, - MeshAttr mesh) { - // `lower_bound` returns the first available axis greater or equal to `axis`, - // and we know `axis` doesn't overlap with any available axis. - auto afterIt = availableAxes.lower_bound(axis); - auto beforeIt = getPrevOrEnd(afterIt, availableAxes); - AxisRefAttr axisToAdd = axis; - // Try to merge `axisToAdd` with the first axis greater than it from the left. - if (afterIt != availableAxes.end() && axisToAdd.canMerge(*afterIt)) { - axisToAdd = axisToAdd.merge(*afterIt, mesh); - availableAxes.erase(afterIt); +// Remove the common prefix of `inAxesPerDim` and `outAxesPerDim`. +void removeCommonPrefix(SmallVector& inAxesPerDim, + SmallVector& outAxesPerDim) { + for (auto [inAxes, outAxes] : llvm::zip_equal(inAxesPerDim, outAxesPerDim)) { + while (!inAxes.empty() && !outAxes.empty() && + inAxes.front() == outAxes.front()) { + inAxes.pop_front(); + outAxes.pop_front(); + } } +} - // Try to merge `axisToAdd` with the last axis less than it from the right. - if (beforeIt != availableAxes.end() && beforeIt->canMerge(axisToAdd)) { - axisToAdd = beforeIt->merge(axisToAdd, mesh); - availableAxes.erase(beforeIt); +// In case an axis A in `axes` overlaps but isn't equal to an axis B in +// `orderedOtherAxes`, decomposes A into 1-3 sub-axes (overlap and +// non-overlapping prefix and suffix), and replaces A with the decomposed +// sub-axes that form it. +void alignSubAxesByDecomposition(AxisList& axes, + ArrayRef orderedOtherAxes, + MeshAttr mesh) { + auto axisIt = axes.begin(); + while (axisIt != axes.end()) { + AxisRefAttr axis = *axisIt; + auto* overlapIt = getFirstOverlapping(axis, orderedOtherAxes); + // There are two paths to complete the while loop below: + // 1. the while condition is not met from the start, in which case we need + // to advance `axisIt`. + // 2. 
we enter the while until the condition isn't met, in which case we + // only need to advance `axisIt` if it points to a created suffix. + bool axisAdvancedInWhile = false; + while (overlapIt != orderedOtherAxes.end() && overlapIt->canCoexist(axis) && + !overlapIt->contains(axis) && overlapIt->overlaps(axis)) { + axisIt = axes.erase(axisIt); + if (OptionalAxisRef prefix = axis.getPrefixWithoutOverlap(*overlapIt)) { + axes.insert(axisIt, *prefix); + } + axes.insert(axisIt, *axis.getOverlap(*overlapIt)); + if (OptionalAxisRef suffix = + axis.getSuffixWithoutOverlap(*overlapIt, mesh)) { + // If there is a suffix, that should be the next axis to process. + axisIt = axes.insert(axisIt, *suffix); + axis = *suffix; + ++overlapIt; + axisAdvancedInWhile = false; + } else { + // Otherwise, we're done with the current axis. + axisAdvancedInWhile = true; + break; + } + } + if (!axisAdvancedInWhile) { + ++axisIt; + } } - availableAxes.insert(axisToAdd); } -// If there is a prefix of `axis` that fully overlaps with an axis in -// `availableAxes`, returns that prefix and removes it from `availableAxes`. -// Otherwise, returns `std::nullopt` and leaves `availableAxes` unchanged. -std::optional takeAvailablePrefix(AxisRefAttr axis, - AvailableAxes& availableAxes, - MeshAttr mesh) { - // It's enough to check the first overlapping axis since any other overlapping - // axis would necessarily not fully overlap with a prefix of `axis`. - auto availableIt = getFirstOverlapping(axis, availableAxes); - if (availableIt == availableAxes.end()) { - return std::nullopt; +// In case two `AxisRefAttr` in `inAxesPerDim` and `outAxesPerDim` respectively +// overlap but aren't equal, decomposes them into up to three sub-axes (overlap +// and non-overlapping prefix and suffix), and replaces each original axis with +// the decomposed sub-axes that form it (see overload above). +// +// For example, "a":(1)8 and "a":(4)4 are decomposed into "a":(1)4, "a":(4)2, +// and "a":(8)2. 
Then "a":(1)8 is replaced with ["a":(1)4, "a":(4)2] and +// "a":(4)4 is replaced with ["a":(4)2, "a":(8)2]. +void alignSubAxesByDecomposition(SmallVector& inAxesPerDim, + SmallVector& outAxesPerDim, + MeshAttr mesh) { + SmallVector orderedInAxes = getOrderedAxes(inAxesPerDim); + SmallVector orderedOutAxes = getOrderedAxes(outAxesPerDim); + for (AxisList& inAxes : inAxesPerDim) { + alignSubAxesByDecomposition(inAxes, orderedOutAxes, mesh); } - AxisRefAttr availableAxis = *availableIt; - if (OptionalAxisRef result = axis.getPrefixWithOverlap(availableAxis, mesh)) { - removeOverlapFromAvailable(availableAxis, *result, availableAxes, mesh); - return result; + for (AxisList& outAxes : outAxesPerDim) { + alignSubAxesByDecomposition(outAxes, orderedInAxes, mesh); } - return std::nullopt; } -// Removes all axis refs in `axes` from `availableAxes`. +// Removes the axes in `axesToPop` from the back of `currentAxes`. // -// We assume for every axis ref in `axes` there is exactly one axis ref in -// `availableAxes` that contains it, and if they aren't equal, we remove the -// containing axis and add back the prefix and suffix that don't overlap, if -// exist. -void removeUnavailableAxes(ArrayRef axes, MeshAttr mesh, - AvailableAxes& availableAxes) { - for (AxisRefAttr axis : axes) { - removeOverlapFromAvailable(*getFirstOverlapping(axis, availableAxes), axis, - availableAxes, mesh); +// Note that `axesToPop` can have decomposed sub-axes of an axis in +// `currentAxes`, which is taken into account. +void popBackFromCurrentAxes(SmallVector& currentAxes, + const AxisList& axesToPop, + AxisList::iterator startIt) { + for (auto it = axesToPop.rbegin(); it != std::make_reverse_iterator(startIt); + ++it) { + if (auto prefix = currentAxes.back().getPrefixWithoutOverlap(*it)) { + currentAxes.back() = *prefix; + } else { + currentAxes.pop_back(); + } } } -// Returns all available axes or sub-axes in `mesh` that aren't used in -// `axesPerDim`. 
-AvailableAxes getAvailableAxes(ArrayRef> axesPerDim, - MeshAttr mesh) { - AvailableAxes unboundAxes; - for (MeshAxisAttr axis : mesh.getAxes()) { - unboundAxes.insert(AxisRefAttr::get(mesh.getContext(), axis.getName())); +struct AllToAllInfo { + SmallVector axes; + int64_t tgtDim; + + explicit AllToAllInfo(int64_t tgtDim) : tgtDim(tgtDim) {} +}; + +// A class that applies an algorithm to transform an input sharding into an +// output sharding via a sequence of collectives. +// +// The current sharding is initialized with the input sharding, and after each +// collective insertion, the current sharding is updated w.r.t the collective, +// until it matches the output sharding and we are done. +// +// We define the current state of the transformation as follows: +// +// - `inAxesPerDim` - the axes in the current sharding per dimension, such that +// the common prefix with the output sharding is removed. +// - `outAxesPerDim` - the axes in the output sharding per dimension, such that +// the common prefix with the current sharding is removed. +// - `currentAxesPerDim` - the axes in the current sharding, including the +// common prefix with the output sharding. +// +// These invariants are maintained throughout the algorithm, and specifically +// after each collective insertion. +// +// We also maintain `inAxisToDimMap` and `outAxisToDimMap`, which are used to +// find the dimension in `inAxesPerDim` and `outAxesPerDim` respectively where +// a given axis ref appears. `inAxisToDimMap` is updated when in axes are +// removed or moved to another dim, and `outAxisToDimMap` remains unchanged. +// +// Note that `inAxesPerDim` and `outAxesPerDim` represent the *diff* between the +// current and output sharding, i.e., when they are empty the shardings match +// exactly. The algorithm inserts collectives and updates the current state +// accordingly, until both `inAxesPerDim` and `outAxesPerDim` are empty. 
+class CollectiveInserter { + public: + CollectiveInserter(TensorShardingAttr inSharding, + TensorShardingAttr outSharding, MeshAttr mesh, + Value result, ConversionPatternRewriter& rewriter, + Location loc) + : rewriter(rewriter), + loc(loc), + mesh(mesh), + meshOrRef(inSharding.getMeshOrRef()), + result(result), + inAxesPerDim(getAxesPerDim(inSharding)), + outAxesPerDim(getAxesPerDim(outSharding)), + currentAxesPerDim(getAxesPerDim>(inSharding)), + collectiveAxesPerDim(inSharding.getRank()) { + // We align sub-axes between the input and output axes, so that we can treat + // sub-axes like full axes and assume any two sub-axes that overlap are also + // equal, which allows using them as keys in a hash map. + alignSubAxesByDecomposition(inAxesPerDim, outAxesPerDim, mesh); + // We remove the common prefix of `inAxesPerDim` and `outAxesPerDim`, since + // those axes stay exactly the same during the reshard. We are left with + // `inAxesPerDim` and `outAxesPerDim` that need to become empty, via a + // sequence of collectives. + removeCommonPrefix(inAxesPerDim, outAxesPerDim); + + inAxisToDimMap = getAxisRefToDimMap(inAxesPerDim); + outAxisToDimMap = getAxisRefToDimMap(outAxesPerDim); + } + + // Inserts a sequence of collectives to transform the input sharding into the + // output sharding, and returns the result of the final collective. + // + // If the input and output sharding are the same, returns the input value + // without inserting any collective. + Value insert() { + while (!isDone()) { + // 1. Try to insert an all-slice, that decreases the size of the tensor. + tryAllSlice(); + + // 2. Try to insert all-to-alls, that preserves the size of the tensor. + tryAllToAlls(); + + // 3. Try to insert an all-gather, that increases the size of the tensor. 
+ tryAllGather(); + } + + return result; } - for (ArrayRef axes : axesPerDim) { - removeUnavailableAxes(axes, mesh, unboundAxes); + + private: + // Returns true if the input sharding has been transformed into the output + // sharding, i.e., both `inAxesPerDim` and `outAxesPerDim` are empty. + bool isDone() const { + return llvm::all_of(inAxesPerDim, std::mem_fn(&AxisList::empty)) && + llvm::all_of(outAxesPerDim, std::mem_fn(&AxisList::empty)); } - return unboundAxes; -} -// Returns the axes to slice for a specific dimension. -// -// If `inAxes` is empty, the prefix of `outAxes` that is available (i.e., fully -// contained by axes in `availableAxes`) can be sliced. The slicing axes are -// removed from `outAxes` and `availableAxes`, and added to `currentAxes`. -SmallVector getSlicingAxes(const AxisList& inAxes, - AxisList& outAxes, - SmallVector& currentAxes, - AvailableAxes& availableAxes, - MeshAttr mesh) { - if (!inAxes.empty()) { - return {}; + MLIRContext* getContext() const { return rewriter.getContext(); } + + int64_t getRank() const { return inAxesPerDim.size(); } + + TensorShardingAttr getCurrentSharding() const { + return TensorShardingAttr::getClosed(getContext(), meshOrRef, + currentAxesPerDim); } - SmallVector slicingAxes; - while (!outAxes.empty()) { - AxisRefAttr outAxis = outAxes.front(); - std::optional availablePrefix = - takeAvailablePrefix(outAxis, availableAxes, mesh); - if (!availablePrefix) { - break; + + // If an all-gather can be performed on `dim`, returns the axes to gather for + // that dimension. + // + // We gather all axes in `gatheringAxes = inAxesPerDim[dim]`, and update the + // internal state as follows: + // + // - `inAxesPerDim[dim]` is cleared. + // - `gatheringAxes` are popped from the back of `currentAxesPerDim[dim]`. 
+ // + // For example: + // + // Input: `dim = 1` + // + // Initial state: + // - `inAxesPerDim = [[], ["x", "y"]]`, + // - `outAxesPerDim = [[], []]` + // - `currentAxesPerDim = [["w"], ["z", "x", "y"]]` + // + // Returns: `["x", "y"]`, and updates: + // - `inAxesPerDim = [[], []]`, + // - `outAxesPerDim = [[], []]` + // - `currentAxesPerDim = [["w"], ["z"]]` + SmallVector getGatheringAxes(int64_t dim) { + AxisList& inAxes = inAxesPerDim[dim]; + if (inAxes.empty()) { + return {}; } - slicingAxes.push_back(*availablePrefix); - addAxisOrMerge(currentAxes, *availablePrefix, mesh); - outAxes.pop_front(); - if (*availablePrefix != outAxis) { - // Safe to dereference since we know `availablePrefix` and `outAxis` have - // a common prefix and aren't equal. - outAxes.push_front( - *outAxis.getSuffixWithoutOverlap(*availablePrefix, mesh)); - break; + SmallVector& currentAxes = currentAxesPerDim[dim]; + SmallVector gatheringAxes; + gatheringAxes.reserve(inAxes.size()); + popBackFromCurrentAxes(currentAxes, inAxes, inAxes.begin()); + for (AxisRefAttr axis : inAxes) { + addAxisOrMerge(gatheringAxes, axis, mesh); + inAxisToDimMap.erase(axis); } + inAxes.clear(); + return gatheringAxes; } - return slicingAxes; -} -// Returns the axes to gather for a specific dimension. -// -// All axes in `inAxes` are gathered greedily. The gathering axes are removed -// from `availableAxes`, popped from the back of `currentAxes`, and `inAxes` is -// cleared. -SmallVector getGatheringAxes(AxisList& inAxes, - SmallVector& currentAxes, - AvailableAxes& availableAxes, - MeshAttr mesh) { - if (inAxes.empty()) { - return {}; + // Tries to insert an `sdy.all_gather`. 
+ void tryAllGather() { + bool hasGatheringAxes = false; + for (auto [dim, collectiveAxes] : llvm::enumerate(collectiveAxesPerDim)) { + SmallVector gatheringAxes = getGatheringAxes(dim); + if (!gatheringAxes.empty()) { + hasGatheringAxes = true; + } + collectiveAxes = AxisRefListAttr::get(getContext(), gatheringAxes); + } + if (hasGatheringAxes) { + result = rewriter.create(loc, result, collectiveAxesPerDim, + getCurrentSharding()); + } + } + + // TODO(b/392952931): currently we are greedily slicing and all-to-all-ing + // axes even if the destination dimension is too small to accommodate the + // extra axes. This would introduce padding which is sub-optimal, thus we + // should only do this if the dimension has enough space left, or slice as + // much as possible to fill the space. + + // If an all-slice can be performed, returns the axes to slice for each + // dimension. + // + // For each dimension d, each axis X in `outAxesPerDim[d]` that isn't present + // in `inAxisToDimMap` (i.e., available to slice) is sliced as follows: + // + // - If the last axis Y before X in `outAxesPerDim[d]` that isn't sliced holds + // `inAxisToDimMap[Y] == d`, or there isn't such an axis, then X is sliced + // on that dimension. + // - Otherwise, X is sliced on the mapped dimension (`inAxisToDimMap[Y]`), so + // we can later do an all-to-all on a smaller tensor to move both axes to + // the other dimension. + // + // Returns std::nullopt if there are no slicing axes in any dimension. + // + // The internal state is updated as follows for each dimension `d` and the + // slicing axes on that dimension (`slicingAxes`): + // + // - `slicingAxes` are appended to `inAxesPerDim[d]` and + // `currentAxesPerDim[d]`. + // - The common prefix between `inAxesPerDim[d]` and `outAxesPerDim[d]` is + // removed from both. 
+ // + // Note that this brings us closer to being done, i.e., having both + // `inAxesPerDim` and `outAxesPerDim` empty, because we take axes that are + // present in `outAxesPerDim` but not in `inAxesPerDim`, and either: + // + // - Remove them from `outAxesPerDim`, if they are where they need to be. + // - Add them to `inAxesPerDim` otherwise, which will allow us to perform an + // all-to-all or collective-permute on them to get them to the right place. + // + // For example: + // + // Initial state: + // - `inAxesPerDim = [[], ["y"], []]`, + // - `outAxesPerDim = [["x"], [], ["y", "z", "w"]]` + // - `currentAxesPerDim = [["u"], ["y"], []]` + // + // Returns: `[["x"], ["z", "w"], []]`, and updates: + // - `inAxesPerDim = [[], ["y", "z", "w"], []]`, + // - `outAxesPerDim = [[], [], ["y", "z", "w"]]` + // - `currentAxesPerDim = [["u", "x"], ["y", "z", "w"], []]` + std::optional getSlicingAxesPerDim() { + AxesPerDim slicingAxesPerDim(currentAxesPerDim.size()); + + bool hasSlicingAxes = false; + for (auto [outDim, outAxes] : llvm::enumerate(outAxesPerDim)) { + auto outIt = outAxes.begin(); + std::optional lastInDim; + while (outIt != outAxes.end()) { + AxisRefAttr outAxis = *outIt; + if (auto inAxisEntryIt = inAxisToDimMap.find(outAxis); + inAxisEntryIt != inAxisToDimMap.end()) { + // Out axis isn't available to slice. + lastInDim = inAxisEntryIt->second; + ++outIt; + continue; + } + // We should slice `outAxis` at `lastInDim` if present or `outDim` + // otherwise. + hasSlicingAxes = true; + int64_t slicingDim = lastInDim.value_or(outDim); + addAxisOrMerge(slicingAxesPerDim[slicingDim], outAxis, mesh); + addAxisOrMerge(currentAxesPerDim[slicingDim], outAxis, mesh); + AxisList& inAxes = inAxesPerDim[slicingDim]; + if (inAxes.empty() && outIt == outAxes.begin()) { + // Slicing axis is where it needs to be. 
+ outIt = outAxes.erase(outIt); + } else { + inAxisToDimMap.try_emplace(outAxis, slicingDim); + inAxes.push_back(outAxis); + ++outIt; + } + } + } + + return hasSlicingAxes ? std::make_optional(slicingAxesPerDim) + : std::nullopt; + } + + // Tries to insert an `sdy.all_slice`. + void tryAllSlice() { + if (std::optional slicingAxesPerDim = getSlicingAxesPerDim()) { + for (auto [collectiveAxes, slicingAxes] : + llvm::zip_equal(collectiveAxesPerDim, *slicingAxesPerDim)) { + collectiveAxes = AxisRefListAttr::get(getContext(), slicingAxes); + } + result = rewriter.create(loc, result, collectiveAxesPerDim, + getCurrentSharding()); + } } - SmallVector gatheringAxes = llvm::to_vector(inAxes); - currentAxes.pop_back_n(inAxes.size() - 1); - if (OptionalAxisRef prefix = - currentAxes.back().getPrefixWithoutOverlap(inAxes.front())) { - currentAxes.back() = *prefix; - } else { - currentAxes.pop_back(); + + // If an all-to-all can be performed for the given source dimension `srcDim`, + // returns the axes and target dimension of this all-to-all. + // + // The suffix of axes in `inAxesPerDim[srcDim]` that are mapped to the same + // dimension in `outAxisToDimMap` are all-to-all-ed with the mapped dimension + // as the target (tgtDim). + // + // The internal state is updated as follows for `allToAllAxes` and `tgtDim`: + // + // - `allToAllAxes` are popped from the back of `inAxesPerDim[srcDim]` and + // `currentAxesPerDim[srcDim]`. + // - `allToAllAxes` are appended to `inAxesPerDim[tgtDim]` and + // `currentAxesPerDim[tgtDim]`. + // - The common prefix between `inAxesPerDim[tgtDim]` and + // `outAxesPerDim[tgtDim]` is removed from both. + // + // Note that this brings us closer to being done, i.e., having both + // `inAxesPerDim` and `outAxesPerDim` empty, because we move axes from + // `inAxesPerDim[srcDim]` to either: + // + // - Where they need to be in `tgtDim`, in which case they are removed from + // `outAxesPerDim[tgtDim]`. 
+ // - `inAxesPerDim[tgtDim]` otherwise (they aren't yet where they need to be), + // which will allow us to perform a collective permute on them to get them + // to the right place. + // + // For example: + // + // Input: `srcDim = 1` + // + // Initial state: + // - `inAxesPerDim = [["w"], ["x", "y", "z"], []]`, + // - `outAxesPerDim = [["x"], [], ["y", "z"]]` + // - `currentAxesPerDim = [["w"], ["x", "y", "z"], []]` + // + // First call returns: `{axes = ["y", "z"], tgtDim = 2}`, and updates: + // - `inAxesPerDim = [["w"], ["x"], []]`, + // - `outAxesPerDim = [["x"], [], []]` + // - `currentAxesPerDim = [["w"], ["x"], ["y", "z"]]` + // + // Second call returns: `{axes = ["x"], tgtDim = 0}`, and updates: + // - `inAxesPerDim = [["w", "x"], [], []]`, + // - `outAxesPerDim = [["x"], [], []]` + // - `currentAxesPerDim = [["w", "x"], [], ["y", "z"]]` + std::optional getAllToAllInfo(int64_t srcDim) { + AxisList& srcInAxes = inAxesPerDim[srcDim]; + + auto axisRevIt = srcInAxes.rbegin(); + int64_t numAxes = 0; + std::optional optTgtDim; + for (; axisRevIt != srcInAxes.rend(); ++axisRevIt) { + auto outAxisEntryIt = outAxisToDimMap.find(*axisRevIt); + if (outAxisEntryIt == outAxisToDimMap.end()) { + break; + } + int64_t outAxisDim = outAxisEntryIt->second; + if (outAxisDim == srcDim || (optTgtDim && outAxisDim != *optTgtDim)) { + break; + } + optTgtDim = outAxisDim; + ++numAxes; + } + + if (!optTgtDim) { + // Can't do an all-to-all from `srcDim` to any dimension. 
+ return std::nullopt; + } + + auto startInAxisIt = axisRevIt.base(); + + AllToAllInfo result(*optTgtDim); + auto& [allToAllAxes, tgtDim] = result; + allToAllAxes.reserve(numAxes); + + SmallVector& srcCurrentAxes = currentAxesPerDim[srcDim]; + SmallVector& tgtCurrentAxes = currentAxesPerDim[tgtDim]; + + popBackFromCurrentAxes(srcCurrentAxes, srcInAxes, startInAxisIt); + + AxisList& tgtInAxes = inAxesPerDim[tgtDim]; + AxisList& tgtOutAxes = outAxesPerDim[tgtDim]; + auto srcInAxisIt = startInAxisIt; + while (srcInAxisIt != srcInAxes.end()) { + AxisRefAttr axis = *srcInAxisIt; + addAxisOrMerge(allToAllAxes, axis, mesh); + addAxisOrMerge(tgtCurrentAxes, axis, mesh); + srcInAxisIt = srcInAxes.erase(srcInAxisIt); + inAxisToDimMap.erase(axis); + if (tgtInAxes.empty() && tgtOutAxes.front() == axis) { + tgtOutAxes.pop_front(); + } else { + tgtInAxes.push_back(axis); + inAxisToDimMap.try_emplace(axis, tgtDim); + } + } + + return result; } - for (AxisRefAttr axis : inAxes) { - addAvailableAxis(axis, availableAxes, mesh); + // Tries to insert a sequence of `sdy.all_to_all`s. + void tryAllToAlls() { + bool allToAllCreated = false; + do { + allToAllCreated = false; + for (int64_t srcDim = 0; srcDim < getRank(); ++srcDim) { + if (auto info = getAllToAllInfo(srcDim)) { + result = + rewriter.create(loc, result, srcDim, info->tgtDim, + info->axes, getCurrentSharding()); + allToAllCreated = true; + } + } + } while (allToAllCreated); } - inAxes.clear(); - return gatheringAxes; -} + + ConversionPatternRewriter& rewriter; + Location loc; + MeshAttr mesh; + Attribute meshOrRef; + Value result; + SmallVector inAxesPerDim, outAxesPerDim; + AxesPerDim currentAxesPerDim; + SmallVector collectiveAxesPerDim; + AxisRefToDimMap inAxisToDimMap, outAxisToDimMap; +}; class ReshardPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; private: - // For the moment we only consider all_gather and all_slice. 
- // TODO(b/380226848): Add support for other collectives. LogicalResult matchAndRewrite( ReshardOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - TensorShardingAttr inputSharding = getSharding(adaptor.getInput()); - TensorShardingAttr outputSharding = adaptor.getSharding(); + TensorShardingAttr inSharding = getSharding(adaptor.getInput()); + TensorShardingAttr outSharding = adaptor.getSharding(); // Here it's safe to assume that shardings' meshes have a name. - if (inputSharding.getRank() != outputSharding.getRank() || - inputSharding.getMeshName() != outputSharding.getMeshName()) { + if (inSharding.getRank() != outSharding.getRank() || + inSharding.getMeshName() != outSharding.getMeshName()) { return rewriter.notifyMatchFailure( op, [](Diagnostic& diag) { diag << "Incompatible shardings"; }); } - int64_t rank = inputSharding.getRank(); // TODO(tomnatan): we should verify that the operand of ReshardOp has a // sharding. // TODO(tomnatan): use a SymbolTable. - MeshAttr mesh = inputSharding.getMesh(op); - SmallVector inAxesPerDim = getAxesPerDim(inputSharding); - SmallVector outAxesPerDim = - getAxesPerDim(outputSharding); - // We remove the common prefix of `inAxes` and `outAxes`, since those axes - // stay exactly the same during the reshard. We are left with `inAxes` that - // need to be transformed into `outAxes`, via a sequence of collectives. 
- for (auto [inAxes, outAxes] : - llvm::zip_equal(inAxesPerDim, outAxesPerDim)) { - removeCommonPrefix(inAxes, outAxes, mesh); - } - - auto hasRemainingAxes = [](const AxisList& axes) { return !axes.empty(); }; - bool hasRemainingInAxes = llvm::any_of(inAxesPerDim, hasRemainingAxes); - bool hasRemainingOutAxes = llvm::any_of(outAxesPerDim, hasRemainingAxes); - - if (!hasRemainingInAxes && !hasRemainingOutAxes) { - rewriter.replaceOp(op, adaptor.getInput()); - return success(); - } - - SmallVector> currentAxesPerDim = - getAxesPerDim>(inputSharding); - AvailableAxes availableAxes = getAvailableAxes(currentAxesPerDim, mesh); - - Value input = adaptor.getInput(); - MLIRContext* context = rewriter.getContext(); - - auto getCurrentSharding = [&]() { - return TensorShardingAttr::getClosed( - context, inputSharding.getMeshOrRef(), currentAxesPerDim); - }; - - SmallVector collectiveAxesPerDim(rank); - - // We aren't done until both `inAxesPerDim` and `outAxesPerDim` are - // empty. - // TODO(b/380226848): this is an initial implementation that only inserts - // all-gathers and all-slices, and greedily all-gathers axes after the first - // attempt to insert an all-slice. - while (hasRemainingInAxes || hasRemainingOutAxes) { - // 1. Try to insert an all-slice first, as it decreases the size of the - // tensor. - hasRemainingOutAxes = false; - bool hasSlicingAxes = false; - for (auto [inAxes, outAxes, currentAxes, collectiveAxes] : - llvm::zip_equal(inAxesPerDim, outAxesPerDim, currentAxesPerDim, - collectiveAxesPerDim)) { - SmallVector slicingAxes = - getSlicingAxes(inAxes, outAxes, currentAxes, availableAxes, mesh); - if (!slicingAxes.empty()) { - hasSlicingAxes = true; - } - if (!outAxes.empty()) { - hasRemainingOutAxes = true; - } - collectiveAxes = AxisRefListAttr::get(context, slicingAxes); - } - if (hasSlicingAxes) { - input = rewriter.create( - op.getLoc(), input, collectiveAxesPerDim, getCurrentSharding()); - } - - // 2. 
Try to insert an all-gather, that increases the size of the tensor. - hasRemainingInAxes = false; - bool hasGatheringAxes = false; - for (auto [inAxes, currentAxes, collectiveAxes] : llvm::zip_equal( - inAxesPerDim, currentAxesPerDim, collectiveAxesPerDim)) { - SmallVector gatheringAxes = - getGatheringAxes(inAxes, currentAxes, availableAxes, mesh); - if (!gatheringAxes.empty()) { - hasGatheringAxes = true; - } - collectiveAxes = AxisRefListAttr::get(context, gatheringAxes); - } - if (hasGatheringAxes) { - input = rewriter.create( - op.getLoc(), input, collectiveAxesPerDim, getCurrentSharding()); - } - } + CollectiveInserter collectiveInserter( + inSharding, outSharding, inSharding.getMesh(op), adaptor.getInput(), + rewriter, op.getLoc()); + rewriter.replaceOp(op, collectiveInserter.insert()); - rewriter.replaceOp(op, input); return success(); } }; @@ -397,7 +628,7 @@ struct ReshardToCollectivesPass LogicalResult initialize(MLIRContext* context) final { target = std::make_shared(*context); target->addIllegalOp(); - target->addLegalOp(); + target->addLegalOp(); RewritePatternSet patternsInternal(context); patternsInternal.add(context); diff --git a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir index 0f927278..1a877217 100644 --- a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir +++ b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir @@ -5,6 +5,8 @@ sdy.mesh @mesh2d = <["x"=2, "y"=2]> sdy.mesh @mesh2d_4x2 = <["x"=4, "y"=2]> sdy.mesh @mesh2d_2x8 = <["x"=2, "y"=8]> sdy.mesh @mesh3d = <["x"=2, "y"=2, "z"=2]> +sdy.mesh @mesh3d_4x2x4 = <["x"=4, "y"=2, "z"=4]> +sdy.mesh @mesh4d_z4 = <["x"=2, "y"=2, "z"=4, "w"=2]> sdy.mesh @mesh4d_w4 = <["x"=2, "y"=2, "z"=2, "w"=4]> sdy.mesh @mesh4d_w16 = <["x"=2, "y"=2, "z"=2, "w"=16]> @@ -15,72 +17,222 @@ func.func @redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.shardin return %0 : 
tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_gather_single_axis -func.func @reshard_to_all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> { +// CHECK-LABEL: func @all_gather_single_axis +func.func @all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_gather [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}]> %0 = sdy.reshard %arg0 <@mesh2d, [{"y"}, {}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_gather_multiple_axes -func.func @reshard_to_all_gather_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x", "y", "z"}, {}]>}) -> tensor<16x8xf32> { +// CHECK-LABEL: func @all_gather_multiple_axes +func.func @all_gather_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x", "y", "z"}, {}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_gather [{"y", "z"}, {}] %arg0 out_sharding=<@mesh3d, [{"x"}, {}]> %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_gather_multiple_dims -func.func @reshard_to_all_gather_multiple_dims(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"y", "z"}, {"x"}]>}) -> tensor<16x8xf32> { +// CHECK-LABEL: func @all_gather_multiple_dims +func.func @all_gather_multiple_dims(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"y", "z"}, {"x"}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_gather [{"z"}, {}] %arg0 out_sharding=<@mesh3d, [{"y"}, {"x"}]> %0 = sdy.reshard %arg0 <@mesh3d, [{"y"}, {"x"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_gather_with_subaxis -func.func @reshard_to_all_gather_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x8, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> { +// 
CHECK-LABEL: func @all_gather_with_subaxis +func.func @all_gather_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x8, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_gather [{"y":(4)2}, {}] %arg0 out_sharding=<@mesh2d_2x8, [{"y":(1)4}, {"x"}]> %0 = sdy.reshard %arg0 <@mesh2d_2x8, [{"y":(1)4}, {"x"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_slice_multiple_axes -func.func @reshard_to_all_slice_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{}, {}]>}) -> tensor<16x8xf32> { +// CHECK-LABEL: func @all_slice_multiple_axes +func.func @all_slice_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{}, {}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_slice [{"x"}, {"y", "z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]> %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_to_all_slice_minor_axis -func.func @reshard_to_all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> { +// CHECK-LABEL: func @all_slice_minor_axis +func.func @all_slice_minor_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: sdy.all_slice [{}, {"z"}] %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}]> %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } +// CHECK-LABEL: func @all_slice_with_subaxis +func.func @all_slice_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x":(1)2}, {"y"}]>}) -> tensor<16x8xf32> { + // CHECK-NEXT: sdy.all_slice [{"x":(2)2}, {"z":(1)2}] %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]> + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x"}, {"y", "z":(1)2}]> : tensor<16x8xf32> + return %0 : tensor<16x8xf32> +} + +// 
CHECK-LABEL: func @all_to_all_single_axis +func.func @all_to_all_single_axis(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {"y"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: sdy.all_to_all {"x"} 0->2 %arg0 out_sharding=<@mesh3d, [{}, {"y"}, {"x"}]> + %0 = sdy.reshard %arg0 <@mesh3d, [{}, {"y"}, {"x"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_multiple_axes +func.func @all_to_all_multiple_axes(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x"}, {}, {"y", "z"}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: sdy.all_to_all {"y", "z"} 2->1 %arg0 out_sharding=<@mesh3d, [{"x"}, {"y", "z"}, {}]> + %0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {"y", "z"}, {}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @two_all_to_alls_different_tgt_dims +func.func @two_all_to_alls_different_tgt_dims(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{}, {"y", "x"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"x"} 1->0 %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"y"} 1->2 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"x"}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_TO_ALL_1]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x"}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @two_all_to_alls_tgt_dim_not_empty +func.func @two_all_to_alls_tgt_dim_not_empty(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x"}, {"y", "z"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"z"} 1->0 %arg0 out_sharding=<@mesh3d_4x2x4, [{"x", "z"}, {"y"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"y"} 1->2 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"x", "z"}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_TO_ALL_1]] + %0 = 
sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "z"}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @slice_then_all_to_alls +func.func @slice_then_all_to_alls(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{}, {"y", "z"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x"}, {}, {}] %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"z"} 1->0 %[[ALL_SLICE]] out_sharding=<@mesh3d_4x2x4, [{"x", "z"}, {"y"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"y"} 1->2 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"x", "z"}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_TO_ALL_1]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "z"}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_subaxis_then_all_gather +func.func @all_to_all_subaxis_then_all_gather(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"x"}, {"z", "y"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"y"} 1->2 %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"z"}, {"y"}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"z":(2)2} 1->0 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"x", "z":(2)2}, {"z":(1)2}, {"y"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{}, {"z":(1)2}, {}] %[[ALL_TO_ALL_1]] out_sharding=<@mesh3d_4x2x4, [{"x", "z":(2)2}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_GATHER]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "z":(2)2}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_subaxis_and_full_axis_then_all_gather +func.func @all_to_all_subaxis_and_full_axis_then_all_gather(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh4d_z4, [{"x"}, {"z", "w", "y"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: 
%[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"y"} 1->2 %arg0 out_sharding=<@mesh4d_z4, [{"x"}, {"z", "w"}, {"y"}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"z":(2)2, "w"} 1->0 %[[ALL_TO_ALL_0]] out_sharding=<@mesh4d_z4, [{"x", "z":(2)2, "w"}, {"z":(1)2}, {"y"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{}, {"z":(1)2}, {}] %[[ALL_TO_ALL_1]] out_sharding=<@mesh4d_z4, [{"x", "z":(2)2, "w"}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_GATHER]] + %0 = sdy.reshard %arg0 <@mesh4d_z4, [{"x", "z":(2)2, "w"}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @slice_on_src_dim_then_all_to_all_subaxis +func.func @slice_on_src_dim_then_all_to_all_subaxis(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh4d_w4, [{}, {"w":(1)2}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"w":(2)2}, {}] %arg0 out_sharding=<@mesh4d_w4, [{}, {"w"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w"} 1->0 %[[ALL_SLICE]] out_sharding=<@mesh4d_w4, [{"w"}, {}, {}]> + // CHECK-NEXT: return %[[ALL_TO_ALL]] + %0 = sdy.reshard %arg0 <@mesh4d_w4, [{"w"}, {}, {}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @slice_on_src_dim_then_all_to_all_multiple_axes +func.func @slice_on_src_dim_then_all_to_all_multiple_axes(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh4d_w4, [{}, {"x"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"y", "z"}, {}] %arg0 out_sharding=<@mesh4d_w4, [{}, {"x", "y", "z"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"x", "y", "z"} 1->2 %[[ALL_SLICE]] out_sharding=<@mesh4d_w4, [{}, {}, {"x", "y", "z"}]> + // CHECK-NEXT: return %[[ALL_TO_ALL]] + %0 = sdy.reshard %arg0 <@mesh4d_w4, [{}, {}, {"x", "y", "z"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// TODO(b/380226848): the tests below require collective permute to do the right +// thing. 
At the moment, we do a redundant all-slice or all-to-all, just to +// all-gather the added axes and slice again in the right order. + +// CHECK-LABEL: func @all_to_all_axes_at_src_out_of_order +func.func @all_to_all_axes_at_src_out_of_order(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{"z"}, {"y", "x"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"y", "x"} 1->0 %arg0 out_sharding=<@mesh3d_4x2x4, [{"z", "y", "x"}, {}, {}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z", "y", "x"}, {}, {}] %[[ALL_TO_ALL]] out_sharding=<@mesh3d_4x2x4, [{}, {}, {}]> + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x", "y", "z"}, {}, {}] %[[ALL_GATHER]] out_sharding=<@mesh3d_4x2x4, [{"x", "y", "z"}, {}, {}]> + // CHECK-NEXT: return %[[ALL_SLICE]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "y", "z"}, {}, {}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_two_tgt_dims_src_out_of_order +func.func @all_to_all_two_tgt_dims_src_out_of_order(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{}, {"x", "z", "y"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"y"} 1->0 %arg0 out_sharding=<@mesh3d_4x2x4, [{"y"}, {"x", "z"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"z"} 1->2 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"y"}, {"x"}, {"z"}]> + // CHECK-NEXT: %[[ALL_TO_ALL_2:.*]] = sdy.all_to_all {"x"} 1->0 %[[ALL_TO_ALL_1]] out_sharding=<@mesh3d_4x2x4, [{"y", "x"}, {}, {"z"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"y", "x"}, {}, {}] %[[ALL_TO_ALL_2]] out_sharding=<@mesh3d_4x2x4, [{}, {}, {"z"}]> + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x", "y"}, {}, {}] %[[ALL_GATHER]] out_sharding=<@mesh3d_4x2x4, [{"x", "y"}, {}, {"z"}]> + // CHECK-NEXT: return %[[ALL_SLICE]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "y"}, {}, {"z"}]> : tensor<16x8x8xf32> + 
return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_two_tgt_dims_src_out_of_order_2 +func.func @all_to_all_two_tgt_dims_src_out_of_order_2(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d_4x2x4, [{}, {"y", "z", "x"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL_0:.*]] = sdy.all_to_all {"x"} 1->0 %arg0 out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y", "z"}, {}]> + // CHECK-NEXT: %[[ALL_TO_ALL_1:.*]] = sdy.all_to_all {"z"} 1->2 %[[ALL_TO_ALL_0]] out_sharding=<@mesh3d_4x2x4, [{"x"}, {"y"}, {"z"}]> + // CHECK-NEXT: %[[ALL_TO_ALL_2:.*]] = sdy.all_to_all {"y"} 1->0 %[[ALL_TO_ALL_1]] out_sharding=<@mesh3d_4x2x4, [{"x", "y"}, {}, {"z"}]> + // CHECK-NEXT: return %[[ALL_TO_ALL_2]] + %0 = sdy.reshard %arg0 <@mesh3d_4x2x4, [{"x", "y"}, {}, {"z"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @all_to_all_and_gather_src_dim_out_of_order +func.func @all_to_all_and_gather_src_dim_out_of_order(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh4d_z4, [{"x"}, {"y", "z", "w"}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"z":(2)2, "w"} 1->0 %arg0 out_sharding=<@mesh4d_z4, [{"x", "z":(2)2, "w"}, {"y", "z":(1)2}, {}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{}, {"y", "z":(1)2}, {}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_z4, [{"x", "z":(2)2, "w"}, {}, {}]> + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {}, {"y"}] %[[ALL_GATHER]] out_sharding=<@mesh4d_z4, [{"x", "z":(2)2, "w"}, {}, {"y"}]> + // CHECK-NEXT: return %[[ALL_SLICE]] + %0 = sdy.reshard %arg0 <@mesh4d_z4, [{"x", "z":(2)2, "w"}, {}, {"y"}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + +// CHECK-LABEL: func @slice_then_reorder_axes +func.func @slice_then_reorder_axes(%arg0 : tensor<16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {}, {}]>}) -> tensor<16x8x8xf32> { + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x"}, {}, {}] %arg0 out_sharding=<@mesh2d, 
[{"y", "x"}, {}, {}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"y", "x"}, {}, {}] %[[ALL_SLICE]] out_sharding=<@mesh2d, [{}, {}, {}]> + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x", "y"}, {}, {}] %[[ALL_GATHER]] out_sharding=<@mesh2d, [{"x", "y"}, {}, {}]> + // CHECK-NEXT: return %[[ALL_SLICE]] + %0 = sdy.reshard %arg0 <@mesh2d, [{"x", "y"}, {}, {}]> : tensor<16x8x8xf32> + return %0 : tensor<16x8x8xf32> +} + // CHECK-LABEL: func @major_axis_available_to_slice func.func @major_axis_available_to_slice(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w4, [{"y", "z", "w"}, {}]>}) -> tensor<16x8xf32> { - // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w"}, {"x"}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z", "w"}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w4, [{"y"}, {"x"}]> - // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w"}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> - // CHECK-NEXT: return %[[ALL_SLICE_1]] + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w"}, {"x"}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w"} 0->1 %[[ALL_SLICE]] out_sharding=<@mesh4d_w4, [{"y", "z"}, {"x", "w"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z"}, {}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> + // CHECK-NEXT: return %[[ALL_GATHER]] %0 = sdy.reshard %arg0 <@mesh4d_w4, [{"y"}, {"x", "w"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } // CHECK-LABEL: func @prefix_subaxis_available_to_slice func.func @prefix_subaxis_available_to_slice(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w4, [{"y", "z", "w":(2)2}, {}]>}) -> tensor<16x8xf32> { - // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x", "w":(1)2}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w":(2)2}, {"x", "w":(1)2}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = 
sdy.all_gather [{"z", "w":(2)2}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w":(1)2}]> - // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w":(2)2}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> - // CHECK-NEXT: return %[[ALL_SLICE_1]] + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"x", "w":(1)2}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w":(2)2}, {"x", "w":(1)2}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w":(2)2} 0->1 %[[ALL_SLICE]] out_sharding=<@mesh4d_w4, [{"y", "z"}, {"x", "w"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z"}, {}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> + // CHECK-NEXT: return %[[ALL_GATHER]] %0 = sdy.reshard %arg0 <@mesh4d_w4, [{"y"}, {"x", "w"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } +// NOTE: this test case should have the following result with collective permute: +// %0 = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {"x"}]> : tensor<16x8xf32> +// %1 = sdy.collective_permute %0 out_sharding=<@mesh4d_w16, [{"y", "w":(2)8}, {"x"}]> : tensor<16x8xf32> +// %2 = sdy.all_to_all {"w":(2)8} 0->1 %1 out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w":(2)8}]> : tensor<16x8xf32> +// return %2 : tensor<16x8xf32> + // CHECK-LABEL: func @prefix_subaxis_available_to_slice_2 func.func @prefix_subaxis_available_to_slice_2(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {}]>}) -> tensor<16x8xf32> { - // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x", "w":(2)2}] %arg0 out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {"x", "w":(2)2}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(4)2, "z", "w":(1)2}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w":(2)2}]> + // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{"w":(8)2}, {"x", "w":(2)2}] %arg0 out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", 
"w":(1)2, "w":(8)2}, {"x", "w":(2)2}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w":(8)2} 0->1 %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {"x", "w":(2)2, "w":(8)2}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(4)2, "z", "w":(1)2}, {"w":(8)2}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w":(2)2}]> // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w":(4)4}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w":(2)8}]> // CHECK-NEXT: return %[[ALL_SLICE_1]] %0 = sdy.reshard %arg0 <@mesh4d_w16, [{"y"}, {"x", "w":(2)8}]> : tensor<16x8xf32> @@ -90,7 +242,8 @@ func.func @prefix_subaxis_available_to_slice_2(%arg0: tensor<16x8xf32> {sdy.shar // CHECK-LABEL: func @split_full_axis_not_available_to_slice func.func @split_full_axis_not_available_to_slice(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w4, [{"y", "w":(1)2, "z", "w":(2)2}, {}]>}) -> tensor<16x8xf32> { // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "w":(1)2, "z", "w":(2)2}, {"x"}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(1)2, "z", "w":(2)2}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w4, [{"y"}, {"x"}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w":(2)2} 0->1 %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w4, [{"y", "w":(1)2, "z"}, {"x", "w":(2)2}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(1)2, "z"}, {"w":(2)2}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w4, [{"y"}, {"x"}]> // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w"}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> // CHECK-NEXT: return %[[ALL_SLICE_1]] %0 = sdy.reshard %arg0 <@mesh4d_w4, [{"y"}, {"x", "w"}]> : tensor<16x8xf32> @@ -99,47 +252,50 @@ func.func @split_full_axis_not_available_to_slice(%arg0: tensor<16x8xf32> {sdy.s // CHECK-LABEL: func @prefix_subaxis_not_available_to_slice func.func 
@prefix_subaxis_not_available_to_slice(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w4, [{"y", "z", "w":(1)2}, {}]>}) -> tensor<16x8xf32> { - // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w":(1)2}, {"x"}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z", "w":(1)2}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w4, [{"y"}, {"x"}]> - // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w"}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> - // CHECK-NEXT: return %[[ALL_SLICE_1]] + // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"w":(2)2}, {"x"}] %arg0 out_sharding=<@mesh4d_w4, [{"y", "z", "w"}, {"x"}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] = sdy.all_to_all {"w"} 0->1 %[[ALL_SLICE]] out_sharding=<@mesh4d_w4, [{"y", "z"}, {"x", "w"}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"z"}, {}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w4, [{"y"}, {"x", "w"}]> + // CHECK-NEXT: return %[[ALL_GATHER]] %0 = sdy.reshard %arg0 <@mesh4d_w4, [{"y"}, {"x", "w"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } // CHECK-LABEL: func @prefix_and_suffix_subaxes_not_available_to_slice func.func @prefix_and_suffix_subaxes_not_available_to_slice(%arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {}]>}) -> tensor<16x8xf32> { - // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{}, {"x"}] %arg0 out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)2}, {"x"}]> - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(4)2, "z", "w":(1)2}, {}] %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w16, [{"y"}, {"x"}]> - // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w"}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w"}]> + // CHECK-NEXT: %[[ALL_SLICE_0:.*]] = sdy.all_slice [{"w":(2)2, "w":(8)2}, {"x"}] %arg0 out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z", "w":(1)4, "w":(8)2}, {"x"}]> + // CHECK-NEXT: %[[ALL_TO_ALL:.*]] 
= sdy.all_to_all {"w":(1)4, "w":(8)2} 0->1 %[[ALL_SLICE_0]] out_sharding=<@mesh4d_w16, [{"y", "w":(4)2, "z"}, {"x", "w":(1)4, "w":(8)2}]> + // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"w":(4)2, "z"}, {"w":(8)2}] %[[ALL_TO_ALL]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w":(1)4}]> + // CHECK-NEXT: %[[ALL_SLICE_1:.*]] = sdy.all_slice [{}, {"w":(4)4}] %[[ALL_GATHER]] out_sharding=<@mesh4d_w16, [{"y"}, {"x", "w"}]> // CHECK-NEXT: return %[[ALL_SLICE_1]] %0 = sdy.reshard %arg0 <@mesh4d_w16, [{"y"}, {"x", "w"}]> : tensor<16x8xf32> return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func @reshard_with_non_divisible_subaxes_same_pre_size -func.func @reshard_with_non_divisible_subaxes_same_pre_size(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(1)2}, {}]>}) -> tensor<6x2xf32> { - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(1)2}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> - // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x":(1)3}, {}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{"x":(1)3}, {}]> - // CHECK-NEXT: return %[[ALL_SLICE]] - %0 = sdy.reshard %arg0 <@mesh1d_6, [{"x":(1)3}, {}]> : tensor<6x2xf32> - return %0 : tensor<6x2xf32> -} +// TODO(b/391138813): Add proper support for axes that can't co-exist -// CHECK-LABEL: func @reshard_with_non_divisible_overlapping_subaxes -func.func @reshard_with_non_divisible_overlapping_subaxes(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(2)3}, {}]>}) -> tensor<6x2xf32> { - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(2)3}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> - // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x":(1)3}, {}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{"x":(1)3}, {}]> - // CHECK-NEXT: return %[[ALL_SLICE]] - %0 = sdy.reshard %arg0 <@mesh1d_6, [{"x":(1)3}, {}]> : tensor<6x2xf32> - return %0 : tensor<6x2xf32> -} +// LABEL: func @reshard_with_non_divisible_subaxes_same_pre_size +// func.func 
@reshard_with_non_divisible_subaxes_same_pre_size(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(1)2}, {}]>}) -> tensor<6x2xf32> { +// NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(1)2}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> +// NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x":(1)3}, {}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{"x":(1)3}, {}]> +// NEXT: return %[[ALL_SLICE]] +// %0 = sdy.reshard %arg0 <@mesh1d_6, [{"x":(1)3}, {}]> : tensor<6x2xf32> +// return %0 : tensor<6x2xf32> +// } -// CHECK-LABEL: func @reshard_with_non_divisible_overlapping_diff_dim -func.func @reshard_with_non_divisible_overlapping_diff_dim(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(2)3}, {}]>}) -> tensor<6x2xf32> { - // CHECK-NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(2)3}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> - // CHECK-NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"x":(1)3}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{}, {"x":(1)3}]> - // CHECK-NEXT: return %[[ALL_SLICE]] - %0 = sdy.reshard %arg0 <@mesh1d_6, [{}, {"x":(1)3}]> : tensor<6x2xf32> - return %0 : tensor<6x2xf32> -} +// LABEL: func @reshard_with_non_divisible_overlapping_subaxes +// func.func @reshard_with_non_divisible_overlapping_subaxes(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(2)3}, {}]>}) -> tensor<6x2xf32> { +// NEXT: %[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(2)3}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> +// NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{"x":(1)3}, {}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{"x":(1)3}, {}]> +// NEXT: return %[[ALL_SLICE]] +// %0 = sdy.reshard %arg0 <@mesh1d_6, [{"x":(1)3}, {}]> : tensor<6x2xf32> +// return %0 : tensor<6x2xf32> +// } + +// LABEL: func @reshard_with_non_divisible_overlapping_diff_dim +// func.func @reshard_with_non_divisible_overlapping_diff_dim(%arg0 : tensor<6x2xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x":(2)3}, {}]>}) -> tensor<6x2xf32> { +// NEXT: 
%[[ALL_GATHER:.*]] = sdy.all_gather [{"x":(2)3}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}]> +// NEXT: %[[ALL_SLICE:.*]] = sdy.all_slice [{}, {"x":(1)3}] %[[ALL_GATHER]] out_sharding=<@mesh1d_6, [{}, {"x":(1)3}]> +// NEXT: return %[[ALL_SLICE]] +// %0 = sdy.reshard %arg0 <@mesh1d_6, [{}, {"x":(1)3}]> : tensor<6x2xf32> +// return %0 : tensor<6x2xf32> +// }