Skip to content

Commit

Permalink
#sdy Add support for all-to-all in -sdy-reshard-to-collectives and …
Browse files Browse the repository at this point in the history
…make further improvements.

Improvements:
1. Introduces a CollectiveInserter class that holds and manipulates the current state of the transformation.
2. Simplifies the handling of sub-axes by aligning them between the input and output shardings, so that we can treat them as full axes and use them as keys in hash maps.
3. Makes all-slice and all-to-all insertion greedy, i.e., we slice axes if they are available and all-to-all axes if they are the suffix of the source dimension, regardless of whether they will end up in the right position in the target dimension.
4. Performs an all-slice in a dimension other than the one in the output sharding, if we can then all-to-all the sliced axis, together with an axis that was already in that source dimension, to the same target dimension.

A few notes:
1. Greedy all-slice and all-to-all is only possible if the target dimension has the capacity; we ignore this constraint in this CL and will address it in a follow-up.
2. This pass is still missing support for collective-permute, which is crucial for getting optimal collectives in cases where axes are reordered or replaced with others in the same dimension. Added some use cases to the test.

PiperOrigin-RevId: 723451739
  • Loading branch information
tomnatan30 authored and copybara-github committed Feb 5, 2025
1 parent 35f1ea9 commit 9541653
Show file tree
Hide file tree
Showing 5 changed files with 802 additions and 452 deletions.
40 changes: 10 additions & 30 deletions shardy/dialect/sdy/ir/attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -362,25 +362,25 @@ def Sdy_AxisRef : AttrDef<Sdy_Dialect, "AxisRef"> {
// "a":(1)3, "a":(2)3 -> false
bool canCoexist(AxisRefAttr other) const;

// Returns the largest prefix of this axis that overlaps with `other`, or
// `std::nullopt` if the prefix does not exist.
// If this axis or sub-axis overlaps with `other`, returns that overlapping
// axis or sub-axis, otherwise returns `std::nullopt`.
//
// If this axis and `other` can't coexist, returns `std::nullopt` (see
// AxisRefAttr::canCoexist).
//
// For example:
// "a", "a" -> "a"
// "a":(2)2, "a" -> "a":(2)2
// "a", "a":(2)2 -> "a":(2)2
// "a":(2)2, "a":(2)2 -> "a":(2)2
// "a":(1)4, "a":(2)4 -> "a":(2)2
// "a":(2)4, "a":(1)4 -> "a":(2)2
// "a":(1)4, "a":(1)2 -> "a":(1)2
// "a":(2)8, "a":(1)4 -> "a":(2)2
// "a", "b" -> std::nullopt
// "a":(2)2, "b" -> std::nullopt
// "a":(1)4, "a":(2)4 -> std::nullopt
// "a":(2)8, "a":(4)2 -> "a":(4)2
// "a":(1)4, "a":(4)2 -> std::nullopt
// "a":(1)2, "a":(4)2 -> std::nullopt
// "a":(1)4, "b":(2)4 -> std::nullopt
// "a":(1)2, "a":(1)3 -> std::nullopt
// "a":(3)2, "a":(2)3 -> std::nullopt
std::optional<AxisRefAttr> getPrefixWithOverlap(
AxisRefAttr other, MeshAttr mesh) const;
std::optional<AxisRefAttr> getOverlap(AxisRefAttr other) const;

// If there is no overlap between this axis and `other`, return this axis.
// Otherwise, return the largest prefix of this axis by removing the
Expand Down Expand Up @@ -443,26 +443,6 @@ def Sdy_AxisRef : AttrDef<Sdy_Dialect, "AxisRef"> {
// "a":(1)2, "a":(2)4 -> std::nullopt
std::optional<AxisRefAttr> getGreatestCommonPrefix(AxisRefAttr other) const;

// Removes the common prefix of this axis and `other` from this axis. If the
// two axes do not have common prefix or `other` is greater or equal to this
// axis, return `std::nullopt`.
//
// If this axis and `other` can't coexist, returns `std::nullopt` (see
// AxisRefAttr::canCoexist).
//
// For example:
// "a", "a":(1)4 -> "a":(4)2 (size("a") == 8)
// "a":(1)4, "a":(1)2 -> "a":(2)2
// "a":(2)8, "a":(2)4 -> "a":(8)2
// "a", "b" -> std::nullopt
// "a", "a" -> std::nullopt
// "a":(1)4, "a" -> std::nullopt
// "a":(2)4, "a":(2)8 -> std::nullopt
// "a":(1)2, "a":(2)4 -> std::nullopt
// "a":(1)2, "a":(1)3 -> std::nullopt
std::optional<AxisRefAttr> removeCommonPrefix(
AxisRefAttr prefix, MeshAttr mesh) const;

// Returns whether this axis-ref can be merged with `other`, i.e., they are
// consecutive sub-axes of the same full axis and this sub-axis is major to
// `other`.
Expand Down
77 changes: 50 additions & 27 deletions shardy/dialect/sdy/ir/dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,22 @@ bool AxisRefAttr::overlaps(AxisRefAttr other) const {
otherSubAxisInfo.getPreSize() < thisSubAxisInfo.getNextPreSize();
}

namespace {

// Decides whether two sub-axes of the same full axis can coexist, given the
// smaller/larger of their pre-sizes and the smaller/larger of their next
// pre-sizes (callers obtain each pair via `std::minmax`).
//
// When the sub-axes do not overlap, the gap between them must be valid, i.e.
// the larger pre-size must be a multiple of the smaller next pre-size. When
// they do overlap, the overlapping part and both non-overlapping parts must
// each be valid.
bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize,
                       int64_t minNextPreSize, int64_t maxNextPreSize) {
  const bool subAxesOverlap = minNextPreSize > maxPreSize;
  if (!subAxesOverlap) {
    // Disjoint sub-axes: only the gap between them needs to be valid.
    return maxPreSize % minNextPreSize == 0;
  }
  // Overlapping part.
  if (minNextPreSize % maxPreSize != 0) {
    return false;
  }
  // Non-overlapping part that precedes the overlap.
  if (maxPreSize % minPreSize != 0) {
    return false;
  }
  // Non-overlapping part that follows the overlap.
  return maxNextPreSize % minNextPreSize == 0;
}

} // namespace

bool AxisRefAttr::canCoexist(AxisRefAttr other) const {
if (getName() != other.getName()) {
return true;
Expand All @@ -457,31 +473,46 @@ bool AxisRefAttr::canCoexist(AxisRefAttr other) const {
auto [minNextPreSize, maxNextPreSize] =
std::minmax(thisNextPreSize, otherNextPreSize);

if (minNextPreSize > maxPreSize) {
// Sub-axes overlap, check if overlapping and non-overlapping parts are
// valid.
return minNextPreSize % maxPreSize == 0 && maxPreSize % minPreSize == 0 &&
maxNextPreSize % minNextPreSize == 0;
}
// Sub-axes don't overlap, check if the gap is valid.
return maxPreSize % minNextPreSize == 0;
return canSubAxesCoexist(minPreSize, maxPreSize, minNextPreSize,
maxNextPreSize);
}

std::optional<AxisRefAttr> AxisRefAttr::getPrefixWithOverlap(
AxisRefAttr other, MeshAttr mesh) const {
int64_t thisPreSize = getSubAxisPreSize();
if (!canCoexist(other) || !overlaps(other) ||
other.getSubAxisPreSize() > thisPreSize) {
std::optional<AxisRefAttr> AxisRefAttr::getOverlap(AxisRefAttr other) const {
if (other.getName() != getName()) {
return std::nullopt;
}
if (other.contains(*this)) {

SubAxisInfoAttr thisSubAxisInfo = getSubAxisInfo();
SubAxisInfoAttr otherSubAxisInfo = other.getSubAxisInfo();

if (!thisSubAxisInfo) {
// This is a full axis.
return other;
}

if (!otherSubAxisInfo) {
// Other is a full axis.
return *this;
}
int64_t thisNextPreSize = getNextPreSizeOrFullSize(mesh);
int64_t otherNextPreSize = other.getNextPreSizeOrFullSize(mesh);
return AxisRefAttr::get(
getContext(), getName(), thisPreSize,
std::min(thisNextPreSize, otherNextPreSize) / thisPreSize);

int64_t thisPreSize = thisSubAxisInfo.getPreSize();
int64_t otherPreSize = otherSubAxisInfo.getPreSize();
int64_t thisNextPreSize = thisSubAxisInfo.getNextPreSize();
int64_t otherNextPreSize = otherSubAxisInfo.getNextPreSize();

auto [minPreSize, maxPreSize] = std::minmax(thisPreSize, otherPreSize);
auto [minNextPreSize, maxNextPreSize] =
std::minmax(thisNextPreSize, otherNextPreSize);

if (minNextPreSize <= maxPreSize ||
!canSubAxesCoexist(minPreSize, maxPreSize, minNextPreSize,
maxNextPreSize)) {
// No overlap or can't co-exist.
return std::nullopt;
}

return AxisRefAttr::get(getContext(), getName(), maxPreSize,
minNextPreSize / maxPreSize);
}

std::optional<AxisRefAttr> AxisRefAttr::getPrefixWithoutOverlap(
Expand Down Expand Up @@ -556,14 +587,6 @@ std::optional<AxisRefAttr> AxisRefAttr::getGreatestCommonPrefix(
return std::nullopt;
}

// Strips `prefix` from the front of this axis and returns what remains.
//
// Yields `std::nullopt` unless `prefix` is a strict prefix of this axis;
// otherwise the result is whatever `getSuffixWithoutOverlap` produces for
// `prefix` on `mesh`.
std::optional<AxisRefAttr> AxisRefAttr::removeCommonPrefix(
    AxisRefAttr prefix, MeshAttr mesh) const {
  if (prefix.strictPrefixOf(*this)) {
    return getSuffixWithoutOverlap(prefix, mesh);
  }
  return std::nullopt;
}

//===----------------------------------------------------------------------===//
// DimensionShardingAttr
//===----------------------------------------------------------------------===//
Expand Down
104 changes: 32 additions & 72 deletions shardy/dialect/sdy/ir/dialect_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,57 +252,47 @@ TEST_F(DialectTest, AxisRefAttrCompare) {
compare(createSubAxis("x", 1, 4), createSubAxis("x", 2, 2));
}

TEST_F(DialectTest, AxisRefAttrGetPrefixWithOverlap) {
auto mesh = MeshAttr::get(&context, {MeshAxisAttr::get(&context, "x", 16),
MeshAxisAttr::get(&context, "y", 4)});
auto samePrefix = [&](AxisRefAttr a, AxisRefAttr b) {
AxisRefAttr smaller = std::min(a, b);
EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), smaller);
EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), smaller);
TEST_F(DialectTest, AxisRefAttrGetOverlap) {
auto contained = [](AxisRefAttr small, AxisRefAttr large) {
EXPECT_TRUE(large.contains(small));
EXPECT_EQ(large.getOverlap(small), small);
EXPECT_EQ(small.getOverlap(large), small);
};
samePrefix(createAxis("x"), createAxis("x"));
samePrefix(createSubAxis("x", 2, 2), createSubAxis("x", 2, 2));
samePrefix(createSubAxis("x", 1, 4), createSubAxis("x", 1, 2));

// "x":(2)4 and "x"
EXPECT_EQ(
createSubAxis("x", 2, 4).getPrefixWithOverlap(createAxis("x"), mesh),
createSubAxis("x", 2, 4));
EXPECT_EQ(
createAxis("x").getPrefixWithOverlap(createSubAxis("x", 2, 4), mesh),
std::nullopt);

// "x":(2)4 and "x":(1)4
EXPECT_EQ(createSubAxis("x", 2, 4).getPrefixWithOverlap(
createSubAxis("x", 1, 4), mesh),
createSubAxis("x", 2, 2));
EXPECT_EQ(createSubAxis("x", 1, 4).getPrefixWithOverlap(
createSubAxis("x", 2, 4), mesh),
std::nullopt);

// "x"(4)2 and "x":(2)8
EXPECT_EQ(createSubAxis("x", 4, 2).getPrefixWithOverlap(
createSubAxis("x", 2, 8), mesh),
createSubAxis("x", 4, 2));
EXPECT_EQ(createSubAxis("x", 2, 8).getPrefixWithOverlap(
createSubAxis("x", 4, 2), mesh),
std::nullopt);
contained(createAxis("x"), createAxis("x"));
contained(createSubAxis("x", 1, 4), createAxis("x"));
contained(createSubAxis("x", 4, 8), createAxis("x"));
contained(createSubAxis("x", 2, 2), createSubAxis("x", 2, 2));
contained(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4));
contained(createSubAxis("x", 2, 2), createSubAxis("x", 1, 4));
contained(createSubAxis("x", 2, 2), createSubAxis("x", 1, 8));

auto overlaps = [](AxisRefAttr a, AxisRefAttr b, AxisRefAttr expected) {
EXPECT_EQ(a.getOverlap(b), expected);
EXPECT_EQ(b.getOverlap(a), expected);
};
overlaps(createSubAxis("x", 1, 4), createSubAxis("x", 2, 4),
createSubAxis("x", 2, 2));
overlaps(createSubAxis("x", 4, 4), createSubAxis("x", 2, 4),
createSubAxis("x", 4, 2));

auto checkNoOverlap = [&](AxisRefAttr a, AxisRefAttr b) {
EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), std::nullopt);
EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), std::nullopt);
auto checkNoOverlap = [](AxisRefAttr a, AxisRefAttr b) {
EXPECT_FALSE(a.overlaps(b));
EXPECT_EQ(a.getOverlap(b), std::nullopt);
EXPECT_EQ(b.getOverlap(a), std::nullopt);
};
checkNoOverlap(createAxis("x"), createAxis("y"));
checkNoOverlap(createAxis("x"), createSubAxis("y", 1, 2));
checkNoOverlap(createSubAxis("x", 1, 4), createSubAxis("x", 4, 2));
checkNoOverlap(createSubAxis("x", 1, 2), createSubAxis("x", 4, 2));

auto checkCannotCoexist = [&](AxisRefAttr a, AxisRefAttr b) {
EXPECT_EQ(a.getPrefixWithOverlap(b, mesh), std::nullopt);
EXPECT_EQ(b.getPrefixWithOverlap(a, mesh), std::nullopt);
auto checkCannotCoexist = [](AxisRefAttr a, AxisRefAttr b) {
EXPECT_FALSE(a.canCoexist(b));
EXPECT_EQ(a.getOverlap(b), std::nullopt);
EXPECT_EQ(b.getOverlap(a), std::nullopt);
};
checkCannotCoexist(createSubAxis("x", 1, 2), createSubAxis("x", 1, 3));
checkCannotCoexist(createSubAxis("x", 3, 2), createSubAxis("x", 2, 3));
checkCannotCoexist(createSubAxis("x", 1, 2), createSubAxis("x", 3, 2));
checkCannotCoexist(createSubAxis("x", 1, 3), createSubAxis("x", 2, 3));
checkCannotCoexist(createSubAxis("x", 2, 3), createSubAxis("x", 3, 2));
}

// The test cases are the same as DialectTest.AxisRefAttrOverlaps.
Expand Down Expand Up @@ -453,36 +443,6 @@ TEST_F(DialectTest, AxisRefAttrGetGreatestCommonPrefix) {
prefix(createSubAxis("x", 2, 4), createSubAxis("x", 2, 8));
}

TEST_F(DialectTest, AxisRefAttrRemoveCommonPrefix) {
auto mesh = MeshAttr::get(&context, {MeshAxisAttr::get(&context, "x", 16),
MeshAxisAttr::get(&context, "y", 4)});
auto isNotPrefix = [&](AxisRefAttr a, AxisRefAttr b) {
EXPECT_EQ(a.removeCommonPrefix(b, mesh), std::nullopt);
EXPECT_EQ(b.removeCommonPrefix(a, mesh), std::nullopt);
};
isNotPrefix(createAxis("x"), createAxis("y"));
isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("y", 1, 2));
isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("x", 2, 4));
isNotPrefix(createSubAxis("x", 1, 2), createSubAxis("x", 1, 3));

auto equals = [&](AxisRefAttr a) {
EXPECT_EQ(a.removeCommonPrefix(a, mesh), std::nullopt);
};
equals(createAxis("x"));
equals(createSubAxis("x", 2, 4));

auto prefix = [&](AxisRefAttr small, AxisRefAttr large,
AxisRefAttr expected) {
EXPECT_EQ(large.removeCommonPrefix(small, mesh), expected);
EXPECT_EQ(small.removeCommonPrefix(large, mesh), std::nullopt);
};
prefix(createSubAxis("x", 1, 4), createAxis("x"), createSubAxis("x", 4, 4));
prefix(createSubAxis("x", 1, 2), createSubAxis("x", 1, 4),
createSubAxis("x", 2, 2));
prefix(createSubAxis("x", 2, 4), createSubAxis("x", 2, 8),
createSubAxis("x", 8, 2));
}

TEST_F(DialectTest, TensorShardingAttrCanShardOrReplicate) {
TensorShardingAttr sharding = createTensorSharding(
{createDimSharding({createAxis("x"), createSubAxis("z", 2, 2)},
Expand Down
Loading

0 comments on commit 9541653

Please sign in to comment.