Skip to content

Commit

Permalink
#sdy Support fine-grained control on the propagation direction for factors.
Browse files Browse the repository at this point in the history

For example, we can now support each of the following at the same time:
* forward-only propagation along factor 0
* both forward and backward propagation along factor 1
* no propagation along factor 2

PiperOrigin-RevId: 720760068
  • Loading branch information
ZixuanJiang authored and copybara-github committed Jan 29, 2025
1 parent 7216f63 commit a2ef0ac
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 236 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,20 @@ SmallVector<TensorIndexSize> getFactorToSourceTensor(
} // namespace

UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor,
ShardingProjection& projection,
PropagationDirectionAlongFactor directionAlongFactor,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result(projection.getNumOperands(),
projection.getNumResults());
if (direction == PropagationDirection::NONE) {
return result;
}

// Find the compatible major axes ignoring conflicts.
AxesPerFactor axesPerFactor;
axesPerFactor.reserve(factorSizes.size());
bool allElementsAreEmpty = true;
for (int64_t i = 0; i < factorSizes.size(); ++i) {
SmallVector<AxisRefAttr>& axes =
axesPerFactor.emplace_back(getCompatibleMajorAxes(
projection, i, direction, propagateAlongFactor, op));
SmallVector<AxisRefAttr>& axes = axesPerFactor.emplace_back(
getCompatibleMajorAxes(projection, i, directionAlongFactor(i), op));
if (!axes.empty()) {
allElementsAreEmpty = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ namespace sdy {
class AggressiveFactorPropagation : public BasicFactorPropagation {
public:
UpdateTensorShardings propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor,
ShardingProjection& projection,
PropagationDirectionAlongFactor directionAlongFactor,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,19 @@ namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;

// Test helper: builds a per-factor direction callback that ignores the factor
// index and always answers PropagationDirection::BOTH, i.e. every factor may
// be propagated in both directions.
PropagationDirectionAlongFactor propagateAnything() {
  return [](int64_t /*factorIndex*/) -> PropagationDirection {
    return PropagationDirection::BOTH;
  };
}

class AggressiveFactorPropagationTest : public PropagationTestBase {
protected:
UpdateTensorShardings propagateFactorShardings(
ShardingProjection& projection, int64_t numFactors,
PropagateAlongFactorPred propagateAlongFactor = [](int64_t) {
return true;
}) {
PropagationDirectionAlongFactor directionAlongFactor =
propagateAnything()) {
return AggressiveFactorPropagation().propagateFactorShardings(
projection, /*direction=*/PropagationDirection::BOTH,
propagateAlongFactor, SmallVector<int64_t>(numFactors, 1),
/*mesh=*/nullptr, /*op=*/nullptr, /*conservativePropagation*/ false);
projection, directionAlongFactor, SmallVector<int64_t>(numFactors, 1),
/*mesh=*/nullptr, /*op=*/nullptr, /*conservativePropagation=*/false);
}
};

Expand Down Expand Up @@ -357,7 +359,7 @@ TEST_F(AggressiveFactorPropagationTest, PropagateAlongSpecificFactor) {
.factorIndexToSharding = {{0, {}}, {1, {}}}};

auto propagateAlongFactor =
[&](PropagateAlongFactorPred propagateAlongFactor,
[&](PropagationDirectionAlongFactor propagateAlongFactor,
const ShardingProjection& projectionExpected) {
ShardingProjection projection(
/*operands=*/{factor0IsSharded, factor1IsSharded},
Expand All @@ -370,25 +372,27 @@ TEST_F(AggressiveFactorPropagationTest, PropagateAlongSpecificFactor) {
EXPECT_EQ(projection, projectionExpected);
};

auto propagateAlongFactorOnly = [](int64_t factorIndex) {
return [factorIndex](int64_t i) {
return i == factorIndex ? PropagationDirection::BOTH
: PropagationDirection::NONE;
};
};

ShardingProjection propagateAlongFactor0Expected(
/*operands=*/{factor0IsSharded, factor1IsSharded},
/*results=*/{factor0IsSharded});
propagateAlongFactor([](int64_t factorIndex) { return factorIndex == 0; },
propagateAlongFactor0Expected);
propagateAlongFactor([](int64_t factorIndex) { return factorIndex != 1; },
propagateAlongFactor0Expected);
// When we propagate along all factors, we propagate "a" to the result along
// factor 0.
propagateAlongFactor([](int64_t factorIndex) { return true; },
propagateAlongFactor0Expected);

ShardingProjection propagateAlongFactor1Expected(
/*operands=*/{factor0IsSharded, factor1IsSharded},
/*results=*/{factor1IsSharded});
propagateAlongFactor([](int64_t factorIndex) { return factorIndex == 1; },
propagateAlongFactor1Expected);
propagateAlongFactor([](int64_t factorIndex) { return factorIndex != 0; },
propagateAlongFactor(propagateAlongFactorOnly(0),
propagateAlongFactor0Expected);
propagateAlongFactor(propagateAlongFactorOnly(1),
propagateAlongFactor1Expected);

// When we propagate along all factors, we propagate "a" to the result along
// factor 0.
propagateAlongFactor(propagateAnything(), propagateAlongFactor0Expected);
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,8 @@ std::pair<SmallVector<AxisRefAttr>, bool> getCompatibleMajorAxesInternal(

SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
const ShardingProjection& projection, int64_t factorIndex,
PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor, Operation* op) const {
if (!propagateAlongFactor(factorIndex)) {
PropagationDirection direction, Operation* op) const {
if (direction == PropagationDirection::NONE) {
return {};
}

Expand Down Expand Up @@ -387,16 +386,11 @@ std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(

SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes(
const ShardingProjection& projection, int64_t factorIndex,
PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor, int64_t factorSize,
MeshAttr mesh, Operation* op, bool conservativePropagation) const {
if (direction == PropagationDirection::NONE) {
return SmallVector<AxisRefAttr>();
}

PropagationDirection direction, int64_t factorSize, MeshAttr mesh,
Operation* op, bool conservativePropagation) const {
// Finds the compatible major axes ignoring conflicts.
SmallVector<AxisRefAttr> resultAxes = getCompatibleMajorAxes(
projection, factorIndex, direction, propagateAlongFactor, op);
SmallVector<AxisRefAttr> resultAxes =
getCompatibleMajorAxes(projection, factorIndex, direction, op);

// Removes the major-most axis that isn't compatible w.r.t. other factors or
// the replicated axes, and all axes that are minor to it.
Expand All @@ -412,8 +406,8 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorShardingAxes(
}

UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor,
ShardingProjection& projection,
PropagationDirectionAlongFactor directionAlongFactor,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result(projection.getNumOperands(),
Expand All @@ -425,7 +419,7 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
// that factor for all tensors, those are the axes we will propagate to
// tensors that aren't already sharded.
SmallVector<AxisRefAttr> axesToPropagate = getCompatibleMajorShardingAxes(
projection, factorIndex, direction, propagateAlongFactor, factorSize,
projection, factorIndex, directionAlongFactor(factorIndex), factorSize,
mesh, op, conservativePropagation);

// Update all shardings along this factor if possible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class BasicFactorPropagation : public FactorPropagation {

// Propagates the factor shardings in `projection`.
UpdateTensorShardings propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor,
ShardingProjection& projection,
PropagationDirectionAlongFactor directionAlongFactor,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const override;

Expand Down Expand Up @@ -87,9 +87,8 @@ class BasicFactorPropagation : public FactorPropagation {
// ["a":(1)2].
SmallVector<AxisRefAttr> getCompatibleMajorShardingAxes(
const ShardingProjection& projection, int64_t factorIndex,
PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor, int64_t factorSize,
MeshAttr mesh, Operation* op, bool conservativePropagation) const;
PropagationDirection direction, int64_t factorSize, MeshAttr mesh,
Operation* op, bool conservativePropagation) const;

// Finds the longest prefix of axes that shard the given factor, such that all
// tensors either:
Expand All @@ -100,8 +99,7 @@ class BasicFactorPropagation : public FactorPropagation {
// This method does not resolve conflicts across factors or replicated axes.
SmallVector<AxisRefAttr> getCompatibleMajorAxes(
const ShardingProjection& projection, int64_t factorIndex,
PropagationDirection direction,
PropagateAlongFactorPred propagateAlongFactor, Operation* op) const;
PropagationDirection direction, Operation* op) const;

// Returns the largest prefix of `axisRef`, which does not overlap with
// sharding axes and overflow axes for all other factors.
Expand Down
Loading

0 comments on commit a2ef0ac

Please sign in to comment.