diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc index 6ae35e2..aaded5f 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc @@ -85,24 +85,20 @@ SmallVector getFactorToSourceTensor( } // namespace UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, + ShardingProjection& projection, + PropagationDirectionAlongFactor directionAlongFactor, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const { UpdateTensorShardings result(projection.getNumOperands(), projection.getNumResults()); - if (direction == PropagationDirection::NONE) { - return result; - } // Find the compatible major axes ignoring conflicts. AxesPerFactor axesPerFactor; axesPerFactor.reserve(factorSizes.size()); bool allElementsAreEmpty = true; for (int64_t i = 0; i < factorSizes.size(); ++i) { - SmallVector& axes = - axesPerFactor.emplace_back(getCompatibleMajorAxes( - projection, i, direction, propagateAlongFactor, op)); + SmallVector& axes = axesPerFactor.emplace_back( + getCompatibleMajorAxes(projection, i, directionAlongFactor(i), op)); if (!axes.empty()) { allElementsAreEmpty = false; } diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h index f51209b..cf44b14 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h @@ -77,8 +77,8 @@ namespace sdy { class AggressiveFactorPropagation : public BasicFactorPropagation { public: UpdateTensorShardings propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, + ShardingProjection& projection, + PropagationDirectionAlongFactor directionAlongFactor, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const override; }; diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc index e495834..45e12e1 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation_test.cc @@ -34,17 +34,19 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +PropagationDirectionAlongFactor propagateAnything() { + return [](int64_t) { return PropagationDirection::BOTH; }; +} + class AggressiveFactorPropagationTest : public PropagationTestBase { protected: UpdateTensorShardings propagateFactorShardings( ShardingProjection& projection, int64_t numFactors, - PropagateAlongFactorPred propagateAlongFactor = [](int64_t) { - return true; - }) { + PropagationDirectionAlongFactor directionAlongFactor = + propagateAnything()) { return AggressiveFactorPropagation().propagateFactorShardings( - projection, /*direction=*/PropagationDirection::BOTH, - propagateAlongFactor, SmallVector(numFactors, 1), - /*mesh=*/nullptr, /*op=*/nullptr, /*conservativePropagation*/ false); + projection, directionAlongFactor, SmallVector(numFactors, 1), + /*mesh=*/nullptr, /*op=*/nullptr, /*conservativePropagation=*/false); } }; @@ -357,7 +359,7 @@ TEST_F(AggressiveFactorPropagationTest, PropagateAlongSpecificFactor) { .factorIndexToSharding = {{0, {}}, {1, {}}}}; auto propagateAlongFactor = - [&](PropagateAlongFactorPred propagateAlongFactor, + [&](PropagationDirectionAlongFactor propagateAlongFactor, const ShardingProjection& projectionExpected) { ShardingProjection projection( /*operands=*/{factor0IsSharded, factor1IsSharded}, @@ -370,25 +372,27 @@ TEST_F(AggressiveFactorPropagationTest, PropagateAlongSpecificFactor) { EXPECT_EQ(projection, projectionExpected); }; + auto propagateAlongFactorOnly = [](int64_t factorIndex) { + return [factorIndex](int64_t i) { + return i == factorIndex ? PropagationDirection::BOTH + : PropagationDirection::NONE; + }; + }; + ShardingProjection propagateAlongFactor0Expected( /*operands=*/{factor0IsSharded, factor1IsSharded}, /*results=*/{factor0IsSharded}); - propagateAlongFactor([](int64_t factorIndex) { return factorIndex == 0; }, - propagateAlongFactor0Expected); - propagateAlongFactor([](int64_t factorIndex) { return factorIndex != 1; }, - propagateAlongFactor0Expected); - // When we propagate along all factors, we propagate "a" to the result along - // factor 0. - propagateAlongFactor([](int64_t factorIndex) { return true; }, - propagateAlongFactor0Expected); - ShardingProjection propagateAlongFactor1Expected( /*operands=*/{factor0IsSharded, factor1IsSharded}, /*results=*/{factor1IsSharded}); - propagateAlongFactor([](int64_t factorIndex) { return factorIndex == 1; }, - propagateAlongFactor1Expected); - propagateAlongFactor([](int64_t factorIndex) { return factorIndex != 0; }, + propagateAlongFactor(propagateAlongFactorOnly(0), + propagateAlongFactor0Expected); + propagateAlongFactor(propagateAlongFactorOnly(1), propagateAlongFactor1Expected); + + // When we propagate along all factors, we propagate "a" to the result along + // factor 0. + propagateAlongFactor(propagateAnything(), propagateAlongFactor0Expected); } } // namespace diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc index 72edb86..1743279 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc @@ -305,9 +305,8 @@ std::pair, bool> getCompatibleMajorAxesInternal( SmallVector BasicFactorPropagation::getCompatibleMajorAxes( const ShardingProjection& projection, int64_t factorIndex, - PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, Operation* op) const { - if (!propagateAlongFactor(factorIndex)) { + PropagationDirection direction, Operation* op) const { + if (direction == PropagationDirection::NONE) { return {}; } @@ -387,16 +386,11 @@ std::optional BasicFactorPropagation::compatiblePrefix( SmallVector BasicFactorPropagation::getCompatibleMajorShardingAxes( const ShardingProjection& projection, int64_t factorIndex, - PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, int64_t factorSize, - MeshAttr mesh, Operation* op, bool conservativePropagation) const { - if (direction == PropagationDirection::NONE) { - return SmallVector(); - } - + PropagationDirection direction, int64_t factorSize, MeshAttr mesh, + Operation* op, bool conservativePropagation) const { // Finds the compatible major axes ignoring conflicts. - SmallVector resultAxes = getCompatibleMajorAxes( - projection, factorIndex, direction, propagateAlongFactor, op); + SmallVector resultAxes = + getCompatibleMajorAxes(projection, factorIndex, direction, op); // Removes the major-most axis that isn't compatible w.r.t. other factors or // the replicated axes, and all axes that are minor to it. @@ -412,8 +406,8 @@ SmallVector BasicFactorPropagation::getCompatibleMajorShardingAxes( } UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, + ShardingProjection& projection, + PropagationDirectionAlongFactor directionAlongFactor, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const { UpdateTensorShardings result(projection.getNumOperands(), @@ -425,7 +419,7 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings( // that factor for all tensors, those are the axes we will propagate to // tensors that aren't already sharded. SmallVector axesToPropagate = getCompatibleMajorShardingAxes( - projection, factorIndex, direction, propagateAlongFactor, factorSize, + projection, factorIndex, directionAlongFactor(factorIndex), factorSize, mesh, op, conservativePropagation); // Update all shardings along this factor if possible. diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h index 8a7570a..f74e6d9 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h @@ -42,8 +42,8 @@ class BasicFactorPropagation : public FactorPropagation { // Propagates the factor shardings in `projection`. UpdateTensorShardings propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, + ShardingProjection& projection, + PropagationDirectionAlongFactor directionAlongFactor, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const override; @@ -87,9 +87,8 @@ class BasicFactorPropagation : public FactorPropagation { // ["a":(1)2]. SmallVector getCompatibleMajorShardingAxes( const ShardingProjection& projection, int64_t factorIndex, - PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, int64_t factorSize, - MeshAttr mesh, Operation* op, bool conservativePropagation) const; + PropagationDirection direction, int64_t factorSize, MeshAttr mesh, + Operation* op, bool conservativePropagation) const; // Finds the longest prefix of axes that shard the given factor, such that all // tensors either: @@ -100,8 +99,7 @@ class BasicFactorPropagation : public FactorPropagation { // This method does not resolve conflicts across factors or replicated axes. SmallVector getCompatibleMajorAxes( const ShardingProjection& projection, int64_t factorIndex, - PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, Operation* op) const; + PropagationDirection direction, Operation* op) const; // Returns the largest prefix of `axisRef`, which does not overlap with // sharding axes and overflow axes for all other factors. diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc index dc6883e..cb6f1ca 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc @@ -36,29 +36,30 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +PropagationDirectionAlongFactor propagateAnything() { + return [](int64_t) { return PropagationDirection::BOTH; }; +} + class BasicFactorPropagationTest : public PropagationTestBase { protected: UpdateTensorShardings propagateFactorShardings( ShardingProjection& projection, int64_t numFactors, - PropagationDirection direction = PropagationDirection::BOTH, - MeshAttr mesh = nullptr, bool conservativePropagation = false, - PropagateAlongFactorPred propagateAlongFactor = [](int64_t) { - return true; - }) { + PropagationDirectionAlongFactor directionAlongFactor = + propagateAnything(), + MeshAttr mesh = nullptr, bool conservativePropagation = false) { return BasicFactorPropagation().propagateFactorShardings( - projection, direction, propagateAlongFactor, - SmallVector(numFactors, 1), mesh, /*op=*/nullptr, - conservativePropagation); + projection, directionAlongFactor, SmallVector(numFactors, 1), + mesh, /*op=*/nullptr, conservativePropagation); } UpdateTensorShardings propagateFactorShardings( ShardingProjection& projection, ArrayRef factorSizes, - PropagationDirection direction = PropagationDirection::BOTH, - MeshAttr mesh = nullptr, bool conservativePropagation = false) { + PropagationDirectionAlongFactor directionAlongFactor = + propagateAnything(), + MeshAttr mesh = nullptr) { return BasicFactorPropagation().propagateFactorShardings( - projection, direction, - /*propagateAlongFactor=*/[](int64_t) { return true; }, factorSizes, - mesh, /*op=*/nullptr, conservativePropagation); + projection, directionAlongFactor, factorSizes, mesh, /*op=*/nullptr, + /*conservativePropagation=*/false); } }; @@ -441,7 +442,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) { ShardingProjection projectionBefore({operand}, {resultBefore}); ShardingProjection projectionAfter({operand}, {resultAfter}); auto [updateOperands, updateResults] = propagateFactorShardings( - projectionBefore, factorSizes, PropagationDirection::BOTH, mesh); + projectionBefore, factorSizes, propagateAnything(), mesh); EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0)); EXPECT_EQ(projectionBefore, projectionAfter); @@ -484,136 +485,135 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) { } } -TEST_F(BasicFactorPropagationTest, UniDirectionalPropagation) { - TensorFactorShardings operandBefore0 = { - .factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, - {1, {.axisRefs = {createAxis("d"), createAxis("e")}}}, - }}; - TensorFactorShardings operandBefore1 = { - .factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a")}}}, - {1, {.axisRefs = {createAxis("d")}}}, - }}; - TensorFactorShardings result0 = { - .factorIndexToSharding = { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}}, - {1, {.axisRefs = {createAxis("d")}}}, - }}; - - TensorFactorShardings operandAfter0 = { - .factorIndexToSharding = { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}}, - {1, {.axisRefs = {createAxis("d"), createAxis("e")}}}, - }}; - TensorFactorShardings operandAfter1 = { - .factorIndexToSharding = { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}}, - {1, {.axisRefs = {createAxis("d")}}}, - }}; - - { - // Test that we only propagate backwards. Since we are only propagating - // backwards, we can expand both operands to have ["a", "b", "c"] along - // factor 0. - // - // Since we are only propagating backwards, we do not push "e" forwards - // along factor 1. We do not propagate sideways to each operand as our - // current behavior with BACKWARD closes all operands for factor expansion. - - ShardingProjection projection({operandBefore0, operandBefore1}, {result0}); - ShardingProjection projectionExpected({operandAfter0, operandAfter1}, - {result0}); - - auto [updateOperands, updateResults] = - propagateFactorShardings(projection, 2, PropagationDirection::BACKWARD); - EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(0, 1)); - EXPECT_THAT(toSetBitsVector(updateResults), IsEmpty()); - EXPECT_EQ(projection, projectionExpected); - } +TEST_F(BasicFactorPropagationTest, DifferentDirectionsForDifferentFactors) { + ShardingProjection projection( + /*operands=*/ + {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {createAxis("c")}}}, + {3, {.axisRefs = {createAxis("d")}}}, + {4, {.axisRefs = {}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {}}}, + {7, {.axisRefs = {}}}}}, + {.factorIndexToSharding = {{0, {.axisRefs = {}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {}}}, + {7, {.axisRefs = {}}}}}}, + /*results=*/ + {{.factorIndexToSharding = {{0, {.axisRefs = {}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {createAxis("e")}}}, + {5, {.axisRefs = {createAxis("f")}}}, + {6, {.axisRefs = {createAxis("g")}}}, + {7, {.axisRefs = {createAxis("h")}}}}}, + {.factorIndexToSharding = {{0, {.axisRefs = {}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {}}}, + {7, {.axisRefs = {}}}}}}); + + PropagationDirectionAlongFactor directionAlongFactor = + [](int64_t factorIndex) { + if (factorIndex == 0 || factorIndex == 4) { + return PropagationDirection::BOTH; + } + if (factorIndex == 1 || factorIndex == 5) { + return PropagationDirection::FORWARD; + } + if (factorIndex == 2 || factorIndex == 6) { + return PropagationDirection::BACKWARD; + } + return PropagationDirection::NONE; + }; - { - // Test that we only propagate forwards. - ShardingProjection projection({result0}, {operandBefore0, operandBefore1}); - ShardingProjection projectionExpected({result0}, - {operandAfter0, operandAfter1}); + ShardingProjection projectionExpected( + /*operands=*/ + {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {createAxis("c")}}}, + {3, {.axisRefs = {createAxis("d")}}}, + {4, {.axisRefs = {createAxis("e")}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {createAxis("g")}}}, + {7, {.axisRefs = {}}}}}, + {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {createAxis("e")}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {createAxis("g")}}}, + {7, {.axisRefs = {}}}}}}, + /*results=*/ + {{.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {createAxis("e")}}}, + {5, {.axisRefs = {createAxis("f")}}}, + {6, {.axisRefs = {createAxis("g")}}}, + {7, {.axisRefs = {createAxis("h")}}}}}, + {.factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}}, + {1, {.axisRefs = {createAxis("b")}}}, + {2, {.axisRefs = {}}}, + {3, {.axisRefs = {}}}, + {4, {.axisRefs = {createAxis("e")}}}, + {5, {.axisRefs = {}}}, + {6, {.axisRefs = {createAxis("g")}}}, + {7, {.axisRefs = {}}}}}}); - auto [updateOperands, updateResults] = - propagateFactorShardings(projection, 2, PropagationDirection::FORWARD); - EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); - EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 1)); - EXPECT_EQ(projection, projectionExpected); - } + auto [updateOperands, updateResults] = + propagateFactorShardings(projection, 8, directionAlongFactor); + EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(0, 1)); + EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0, 1)); + EXPECT_EQ(projection, projectionExpected); } TEST_F(BasicFactorPropagationTest, UniDirectionalPropagationWithConflict) { - TensorFactorShardings operand0 = { - .factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, - }}; - TensorFactorShardings operand1 = {.factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a")}}}, - }}; - TensorFactorShardings result = { - .factorIndexToSharding = { - {0, - {.axisRefs = {createAxis("z"), createAxis("a"), createAxis("b")}}}, - }}; - - { - // Even though we are propagating backwards, we still need to account for - // conflicts. The "z" blocks any propagation. - ShardingProjection projection({operand0, operand1}, {result}); - auto [updateOperands, updateResults] = - propagateFactorShardings(projection, 1, PropagationDirection::BACKWARD); - EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); - EXPECT_THAT(toSetBitsVector(updateResults), IsEmpty()); - } - { - ShardingProjection projection({result}, {operand0, operand1}); - auto [updateOperands, updateResults] = - propagateFactorShardings(projection, 1, PropagationDirection::FORWARD); - EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); - EXPECT_THAT(toSetBitsVector(updateResults), IsEmpty()); - } -} - -TEST_F(BasicFactorPropagationTest, NonePropagationDirection) { ShardingProjection projection( /*operands=*/ - { - {.factorIndexToSharding = - { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), - createAxis("c")}}}, - }}, - }, - /*results=*/{ - {.factorIndexToSharding = - { - {0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, - }}, - {.factorIndexToSharding = - { - {0, {.axisRefs = {createAxis("a")}}}, - }}, - {.factorIndexToSharding = - { - {0, - {.axisRefs = {createAxis("a"), createAxis("b"), - createAxis("c")}}}, - }}, - }); + {{.factorIndexToSharding = + {{0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, + {1, {.axisRefs = {}}}, + {2, {.axisRefs = {createAxis("d")}}}}}}, + /*results=*/ + {{.factorIndexToSharding = { + {0, {.axisRefs = {createAxis("b")}}}, + {1, {.axisRefs = {createAxis("c"), createAxis("d")}}}, + {2, {.axisRefs = {}}}}}}); + + PropagationDirectionAlongFactor directionAlongFactor = + [](int64_t factorIndex) { + if (factorIndex == 0) { + return PropagationDirection::FORWARD; + } + if (factorIndex == 1) { + return PropagationDirection::BACKWARD; + } + return PropagationDirection::NONE; + }; + + ShardingProjection projectionExpected( + /*operands=*/ + {{.factorIndexToSharding = + {{0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, + {1, {.axisRefs = {createAxis("c")}}}, + {2, {.axisRefs = {createAxis("d")}}}}}}, + /*results=*/{projection.getResult(0)}); - // Even though [a, b, c] is the most compatible, since we aren't propagating, - // we don't update any operands or results. auto [updateOperands, updateResults] = - propagateFactorShardings(projection, 1, PropagationDirection::NONE); - EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); + propagateFactorShardings(projection, 3, directionAlongFactor); + EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(0)); EXPECT_THAT(toSetBitsVector(updateResults), IsEmpty()); } @@ -639,51 +639,13 @@ TEST_F(BasicFactorPropagationTest, {0, {.axisRefs = {createAxis("a")}}}, {1, {}}}}}); auto [updateOperands, updateResults] = propagateFactorShardings( - projection, 2, PropagationDirection::BOTH, /*mesh=*/nullptr, + projection, 2, propagateAnything(), /*mesh=*/nullptr, /*conservativePropagation=*/true); EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0)); EXPECT_EQ(projection, projectionExpected); } -TEST_F(BasicFactorPropagationTest, PropagateAlongPartialFactors) { - ShardingProjection projection( - /*operands=*/ - {{.factorIndexToSharding = - { - {0, {.axisRefs = {createAxis("a"), createAxis("b")}}}, - {1, {.axisRefs = {createAxis("c")}}}, - {2, {.axisRefs = {}}}, - }}}, - /*results=*/ - {{.factorIndexToSharding = { - {0, {.axisRefs = {}}}, - {1, {.axisRefs = {}}}, - {2, {.axisRefs = {createAxis("b")}}}, - }}}); - - // We do not propagate along factor 1. Factor 1 is still considered as - // conflict when we propagate along other factors. Thus, we only propagate - // ["a"] along factor 0. - ShardingProjection projectionExpected( - /*operands=*/{projection.getOperand(0)}, - /*results=*/{{.factorIndexToSharding = { - {0, {.axisRefs = {createAxis("a")}}}, - {1, {.axisRefs = {}}}, - {2, {.axisRefs = {createAxis("b")}}}, - }}}); - - PropagateAlongFactorPred doNotPropagateAlongFactor1 = - [](int64_t factorIndex) { return factorIndex != 1; }; - - auto [updateOperands, updateResults] = propagateFactorShardings( - projection, 2, PropagationDirection::BOTH, /*mesh=*/nullptr, - /*conservativePropagation=*/false, doNotPropagateAlongFactor1); - EXPECT_THAT(toSetBitsVector(updateOperands), IsEmpty()); - EXPECT_THAT(toSetBitsVector(updateResults), ElementsAre(0)); - EXPECT_EQ(projection, projectionExpected); -} - } // namespace } // namespace sdy } // namespace mlir diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc index 772e7ef..c2de9ec 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc @@ -270,11 +270,16 @@ LogicalResult propagateTensorShardings( ShardingProjection shardingProjection = ShardingProjection::build( operandShardings, resultShardings, shardingRule, mesh); bool anyUpdated = false; + // TODO(zixuanjiang). We apply the same propagation direction to all factors. + // We may want to consider propagating along different factors with different + // directions in the future. + PropagationDirectionAlongFactor directionAlongFactor = [direction](int64_t) { + return direction; + }; auto updateShardings = [&]() { auto [updateOperand, updateResult] = factorPropagation.propagateFactorShardings( - shardingProjection, direction, - /*propagateAlongFactor=*/[](int64_t) { return true; }, + shardingProjection, directionAlongFactor, shardingRule.getFactorSizes(), mesh, op, conservativePropagation); // We need to update the tensor sharding attributes explicitly, as we diff --git a/shardy/dialect/sdy/transforms/propagation/factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/factor_propagation.h index f2f56d3..c4cd3a5 100644 --- a/shardy/dialect/sdy/transforms/propagation/factor_propagation.h +++ b/shardy/dialect/sdy/transforms/propagation/factor_propagation.h @@ -27,9 +27,10 @@ limitations under the License. namespace mlir { namespace sdy { -// A predicate taking a factor index and returning whether sharding axes should -// be propagated along that factor. -using PropagateAlongFactorPred = std::function; +// A predicate taking a factor index and returning the propagation direction +// along that factor. +using PropagationDirectionAlongFactor = + std::function; // An interface for propagating factor shardings. class FactorPropagation { @@ -38,7 +39,8 @@ class FactorPropagation { // Propagates the factor shardings in `projection`. // - // * `direction` specifies the direction of propagation. + // * `directionAlongFactor` is a predicate that determines in which direction + // propagation should happen for a given factor. // * `factorSizes` is the size of each factor. // * `mesh` is the mesh that the factors are sharded over. // * `op` is the operation that the factor shardings are propagated through. @@ -47,12 +49,9 @@ class FactorPropagation { // calculating the compatible major axes. If the projection contains a // sub-axis, then the axes (and any axes further sharding the factor) is // excluded from the result. - // - // TODO(b/392971621). Unify `PropagationDirection` and - // `PropagateAlongFactorPred`. virtual UpdateTensorShardings propagateFactorShardings( - ShardingProjection& projection, PropagationDirection direction, - PropagateAlongFactorPred propagateAlongFactor, + ShardingProjection& projection, + PropagationDirectionAlongFactor directionAlongFactor, ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const = 0; };