#sdy Update GetDirectionToPropagateFn such that it determines in which direction propagation should happen for an op and factor index.

This enables fine-grained control in the op/factor priority propagation schedule.
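In effect, the direction callback gains a factor index. Below is a minimal sketch of the new callback shape plus a hypothetical per-factor schedule entry; `propagateOnlyFirstFactor` is illustrative only and not part of this commit, and `Operation` and `PropagationDirection` are assumed to come from the MLIR/SDY headers, as in `basic_propagation.h` further down.

```cpp
#include <cstdint>
#include <functional>
// Assumes the MLIR/SDY headers defining Operation and PropagationDirection
// are included, as they are in basic_propagation.h below.

// The new callback shape: a direction per (op, factor index) pair.
using GetDirectionToPropagateFn =
    std::function<PropagationDirection(Operation* op, int64_t factorIndex)>;

// Hypothetical schedule entry: open only factor 0 of each op in both
// directions, and keep every other factor closed.
PropagationDirection propagateOnlyFirstFactor(Operation*,
                                              int64_t factorIndex) {
  return factorIndex == 0 ? PropagationDirection::BOTH
                          : PropagationDirection::NONE;
}
```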

PiperOrigin-RevId: 721532701
ZixuanJiang authored and copybara-github committed Feb 5, 2025
1 parent 2aaa85e commit 9fcc99d
Showing 3 changed files with 50 additions and 85 deletions.
70 changes: 27 additions & 43 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -245,7 +245,8 @@ void updateTensorShardings(const PropagationTensorParams& operandsParams,
 LogicalResult propagateTensorShardings(
     const PropagationTensorParams& operandsParams,
     const PropagationTensorParams& resultsParams,
-    OpShardingRuleAttr shardingRule, PropagationDirection direction,
+    OpShardingRuleAttr shardingRule,
+    PropagationDirectionAlongFactor directionAlongFactor,
     const FactorPropagation& factorPropagation, bool conservativePropagation,
     Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
     ShardingGroupMap shardingGroupMap) {
@@ -280,11 +281,6 @@ LogicalResult propagateTensorShardings(
   ShardingProjection shardingProjection = ShardingProjection::build(
       operandsParams.shardings, resultsParams.shardings, shardingRule, mesh);
   bool anyUpdated = false;
-
-  PropagationDirectionAlongFactor directionAlongFactor = [direction](int64_t) {
-    return direction;
-  };
-
   auto updateShardings = [&]() {
     auto [updateOperand, updateResult] =
         factorPropagation.propagateFactorShardings(
@@ -318,29 +314,14 @@ LogicalResult propagateTensorShardings(
   return success(anyUpdated);
 }
 
-// Same as the overload above, except there is a single operand and result.
-LogicalResult propagateTensorShardings(
-    const PropagationTensorParams& operandsParams,
-    const PropagationTensorParams& resultsParams,
-    OpShardingRuleAttr shardingRule, Operation* op,
-    const SymbolTable& symbolTable, PatternRewriter* rewriter,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap,
-    PropagationDirection direction = PropagationDirection::BOTH,
-    bool conservativePropagation = false) {
-  return propagateTensorShardings(
-      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
-      conservativePropagation, op, symbolTable, rewriter, shardingGroupMap);
-}
-
 // Same as the overload above, except the operand and result shardings are
 // extracted using `getSharding` and set using `setSharding`.
 LogicalResult propagateTensorShardings(
     ValueRange operands, ValueRange results, OpShardingRuleAttr shardingRule,
     Operation* op, const SymbolTable& symbolTable, PatternRewriter& rewriter,
+    PropagationDirectionAlongFactor directionAlongFactor,
     const FactorPropagation& factorPropagation,
     const ShardingGroupMap& shardingGroupMap,
-    PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
   SmallVector<TensorShardingAttr> operandsShardings = getShardings(operands);
   SmallVector<TensorShardingAttr> resultsShardings = getShardings(results);
@@ -357,9 +338,10 @@ LogicalResult propagateTensorShardings(
         setSharding(results[index], sharding);
       });
 
-  return propagateTensorShardings(
-      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
-      conservativePropagation, op, symbolTable, &rewriter, shardingGroupMap);
+  return propagateTensorShardings(operandsParams, resultsParams, shardingRule,
+                                  directionAlongFactor, factorPropagation,
+                                  conservativePropagation, op, symbolTable,
+                                  &rewriter, shardingGroupMap);
 }
 
 // Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -406,10 +388,12 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
     (void)propagateTensorShardings(
         operandsParams, resultsParams,
         // Treat the sharding data flow b/w the `funcOp` terminator and func
-        // result attrs as an identity op. Create an equivalent sharding
-        // rule.
-        createIdentityShardingRule(tensorType), funcOp, symbolTable,
-        /*rewriter=*/nullptr, factorPropagation, shardingGroupMap);
+        // result attrs as an identity op. Create an equivalent sharding rule.
+        createIdentityShardingRule(tensorType),
+        std::bind(propagateAny, funcOp, std::placeholders::_1),
+        factorPropagation,
+        /*conservativePropagation=*/false, funcOp, symbolTable,
+        /*rewriter=*/nullptr, shardingGroupMap);
   }
   return success();
 }
@@ -455,19 +439,13 @@ class PropagateRegisteredOp : public RewritePattern {
         diag << "op doesn't have a registered sharding rule";
       });
     }
-    PropagationDirection direction = getDirectionToPropagate(op);
-    if (direction == PropagationDirection::NONE) {
-      // No need to continue to propagate if the direction is `NONE`, as
-      // neither operands nor results can be updated.
-      return rewriter.notifyMatchFailure(op, [](Diagnostic& diag) {
-        diag << "propagation direction on op is NONE";
-      });
-    }
 
+    PropagationDirectionAlongFactor directionAlongFactor =
+        std::bind(getDirectionToPropagate, op, std::placeholders::_1);
     return propagateTensorShardings(op->getOperands(), op->getResults(),
                                     shardingRule, op, symbolTable, rewriter,
-                                    factorPropagation, shardingGroupMap,
-                                    direction, conservativePropagation);
+                                    directionAlongFactor, factorPropagation,
+                                    shardingGroupMap, conservativePropagation);
   }
 
  private:
@@ -522,14 +500,19 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
           sharding, DataFlowShardingTransformType::kAfterEdgePropagation));
     });
+
+    // TODO(b/394390827). We may pass getDirectionToPropagate so we can decide
+    // to change the priority of data flow ops.
+    PropagationDirectionAlongFactor directionAlongFactor =
+        std::bind(propagateAny, dataFlowEdgeOp, std::placeholders::_1);
     return propagateTensorShardings(
         operandsParams, resultsParams,
         createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
                                    sources.size()),
-        PropagationDirection::BOTH, factorPropagation,
+        directionAlongFactor, factorPropagation,
         /*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable,
         &rewriter, shardingGroupMap);
   }
 
  private:
   const SymbolTable& symbolTable;
   const FactorPropagation& factorPropagation;
@@ -556,8 +539,9 @@ class PropagatePropagationBarrier
         propagationBarrierOp.getInput(), propagationBarrierOp.getResult(),
         createIdentityShardingRule(
             cast<RankedTensorType>(propagationBarrierOp.getType())),
-        propagationBarrierOp, symbolTable, rewriter, factorPropagation,
-        shardingGroupMap, propagationBarrierOp.getAllowedDirection());
+        propagationBarrierOp, symbolTable, rewriter,
+        [&](int64_t) { return propagationBarrierOp.getAllowedDirection(); },
+        factorPropagation, shardingGroupMap);
   }
 
  private:
@@ -604,7 +588,7 @@ bool allValidShapes(ModuleOp moduleOp) {
 
 }  // namespace
 
-PropagationDirection propagateAny(Operation*) {
+PropagationDirection propagateAny(Operation*, int64_t) {
   return PropagationDirection::BOTH;
 }
 
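The call sites in this file adapt per-op information to the new per-factor callback in two ways: `std::bind`-ing the op argument of a `GetDirectionToPropagateFn`, or using a lambda that ignores the factor index. Here is a self-contained sketch of both patterns with stand-in types; in particular, `PropagationDirectionAlongFactor` is assumed to be `std::function<PropagationDirection(int64_t)>`, matching its use above.

```cpp
#include <cstdint>
#include <functional>

// Stand-ins for the SDY types, just to make the adaptation patterns concrete.
enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };
struct Operation {};
using GetDirectionToPropagateFn =
    std::function<PropagationDirection(Operation*, int64_t)>;
using PropagationDirectionAlongFactor =
    std::function<PropagationDirection(int64_t)>;

int main() {
  GetDirectionToPropagateFn getDirectionToPropagate =
      [](Operation*, int64_t) { return PropagationDirection::BOTH; };
  Operation op;

  // Pattern 1 (PropagateRegisteredOp): fix the op, leave the factor free.
  PropagationDirectionAlongFactor directionAlongFactor =
      std::bind(getDirectionToPropagate, &op, std::placeholders::_1);

  // Pattern 2 (PropagatePropagationBarrier): ignore the factor index and
  // return a fixed direction.
  PropagationDirectionAlongFactor fixedDirection =
      [](int64_t) { return PropagationDirection::FORWARD; };

  return (directionAlongFactor(0) == PropagationDirection::BOTH &&
          fixedDirection(7) == PropagationDirection::FORWARD)
             ? 0
             : 1;
}
```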
13 changes: 8 additions & 5 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.h
@@ -17,6 +17,8 @@ limitations under the License.
 #define SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_BASIC_PROPAGATION_H_
 
 #include <stdbool.h>
+
+#include <cstdint>
 #include <functional>
 #include <memory>
 #include <string>
@@ -29,21 +31,22 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/transforms/propagation/passes.h"
 #include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
+#include "shardy/dialect/sdy/transforms/propagation/passes.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_group_map.h"
 
 namespace mlir {
 namespace sdy {
 
 // A function that determines in which direction propagation should happen for a
-// given op.
+// given op and factor index.
 using GetDirectionToPropagateFn =
-    std::function<PropagationDirection(Operation*)>;
+    std::function<PropagationDirection(Operation* op, int64_t factorIndex)>;
 
-// A function that returns `PropagationDirection::BOTH` for all operations.
-PropagationDirection propagateAny(Operation* op);
+// A function that returns `PropagationDirection::BOTH` for all operations and
+// factor indices.
+PropagationDirection propagateAny(Operation* op, int64_t factorIndex);
 
 // The implementation class for the basic propagation pass.
 //
52 changes: 15 additions & 37 deletions shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc
@@ -46,34 +46,11 @@ namespace sdy {
 namespace {
 
 // A function that determines in which direction propagation should happen for a
-// given op.
-using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*);
+// given op and factor index.
+using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*,
+                                                              int64_t);
 
-template <typename... OpTs>
-PropagationDirection isaBoth(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::BOTH
-                          : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isNotABoth(Operation* op) {
-  return !isa<OpTs...>(op) ? PropagationDirection::BOTH
-                           : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isaForward(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::FORWARD
-                          : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isaBackward(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::BACKWARD
-                          : PropagationDirection::NONE;
-}
-
-PropagationDirection isPassThrough(Operation* op) {
+PropagationDirection isPassThrough(Operation* op, int64_t) {
   if (isElementwise(op) ||
       isa<stablehlo::ReshapeOp, stablehlo::TransposeOp>(op)) {
     return PropagationDirection::BOTH;
@@ -93,18 +70,19 @@ constexpr std::array<GetDirectionToPropagateFnPtr, 2> opPropagationSchedule = {
 // a caller. It will return the intersection of the passed in
 // `getDirectionToPropagate` and the op based direction.
 GetDirectionToPropagateFn getOpBasedDirectionToPropagate(
-    int64_t currentOpPriority,
+    int64_t currentPriority,
     GetDirectionToPropagateFn getDirectionToPropagate) {
-  return [currentOpPriority, getDirectionToPropagate](Operation* op) {
+  return [currentPriority, getDirectionToPropagate](Operation* op,
+                                                    int64_t factorIndex) {
     PropagationDirection opBasedDirection = std::accumulate(
         opPropagationSchedule.begin(),
-        opPropagationSchedule.begin() + currentOpPriority + 1,
+        opPropagationSchedule.begin() + currentPriority + 1,
         PropagationDirection::NONE,
         [&](PropagationDirection acc, GetDirectionToPropagateFnPtr dirFn) {
-          return unionOfPropagationDirections(acc, dirFn(op));
+          return unionOfPropagationDirections(acc, dirFn(op, factorIndex));
         });
-    return intersectionOfPropagationDirections(opBasedDirection,
-                                               getDirectionToPropagate(op));
+    return intersectionOfPropagationDirections(
+        opBasedDirection, getDirectionToPropagate(op, factorIndex));
   };
 }
 
@@ -129,13 +107,13 @@ LogicalResult OpPriorityPropagationPassImpl::propagate(
     return AggressivePropagationPassImpl::propagate(
         moduleOp, symbolTable, shardingGroupMap, getDirectionToPropagate);
   }
-  // Reset currentOpPriority to 0. Before running the pass. This same instance
+  // Reset currentPriority to 0. Before running the pass. This same instance
   // could have been run earlier already (e.g. with a different user priority).
-  for (int64_t currentOpPriority = 0;
-       currentOpPriority < opPropagationSchedule.size(); currentOpPriority++) {
+  for (int64_t currentPriority = 0;
+       currentPriority < opPropagationSchedule.size(); currentPriority++) {
     if (AggressivePropagationPassImpl::propagate(
             moduleOp, symbolTable, shardingGroupMap,
-            getOpBasedDirectionToPropagate(currentOpPriority,
+            getOpBasedDirectionToPropagate(currentPriority,
                                            getDirectionToPropagate))
             .failed()) {
       return failure();
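For the schedule above, `getOpBasedDirectionToPropagate` unions the directions permitted by every schedule entry up to the current priority, then intersects the result with the caller's own callback. The standalone sketch below models that direction lattice under the assumption (not stated in this diff) that `NONE`/`FORWARD`/`BACKWARD`/`BOTH` behave as subsets of {forward, backward}, so that union and intersection reduce to bitwise ops; the stand-in helpers are illustrative, not the SDY implementations.

```cpp
#include <array>
#include <cstdint>
#include <numeric>

// Stand-in for sdy::PropagationDirection, modeled as a two-bit set.
enum class PropagationDirection : uint8_t {
  NONE = 0, FORWARD = 1, BACKWARD = 2, BOTH = 3
};

PropagationDirection unionOf(PropagationDirection a, PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<uint8_t>(a) |
                                           static_cast<uint8_t>(b));
}

PropagationDirection intersectionOf(PropagationDirection a,
                                    PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<uint8_t>(a) &
                                           static_cast<uint8_t>(b));
}

int main() {
  // Two-entry schedule mirroring opPropagationSchedule's shape: pretend the
  // priority-0 entry allows FORWARD only and the priority-1 entry allows BOTH.
  std::array<PropagationDirection, 2> schedule = {
      PropagationDirection::FORWARD, PropagationDirection::BOTH};

  for (int64_t currentPriority = 0; currentPriority < 2; ++currentPriority) {
    // Union of everything allowed up to and including currentPriority.
    PropagationDirection opBased = std::accumulate(
        schedule.begin(), schedule.begin() + currentPriority + 1,
        PropagationDirection::NONE, unionOf);
    // Intersect with a caller restriction of BACKWARD-only: priority 0
    // yields NONE, priority 1 yields BACKWARD.
    PropagationDirection result =
        intersectionOf(opBased, PropagationDirection::BACKWARD);
    (void)result;
  }
  return 0;
}
```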
