diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
index f843578a..05a04e52 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -245,7 +245,8 @@ void updateTensorShardings(const PropagationTensorParams& operandsParams,
 LogicalResult propagateTensorShardings(
     const PropagationTensorParams& operandsParams,
     const PropagationTensorParams& resultsParams,
-    OpShardingRuleAttr shardingRule, PropagationDirection direction,
+    OpShardingRuleAttr shardingRule,
+    PropagationDirectionAlongFactor directionAlongFactor,
     const FactorPropagation& factorPropagation, bool conservativePropagation,
     Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
     ShardingGroupMap shardingGroupMap) {
@@ -280,11 +281,6 @@ LogicalResult propagateTensorShardings(
   ShardingProjection shardingProjection = ShardingProjection::build(
       operandsParams.shardings, resultsParams.shardings, shardingRule, mesh);
   bool anyUpdated = false;
-
-  PropagationDirectionAlongFactor directionAlongFactor = [direction](int64_t) {
-    return direction;
-  };
-
   auto updateShardings = [&]() {
     auto [updateOperand, updateResult] =
         factorPropagation.propagateFactorShardings(
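Note: the lambda deleted in the hunk above was the old bridge from a single `PropagationDirection` to the per-factor callback; the patch now takes that callback directly. A minimal, self-contained sketch of the bridge and of the factor-dependent policies the new parameter enables (the enum and alias below are stand-ins mirroring the patch, not the Shardy headers):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>

enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };

// Mirrors the per-factor callback type threaded through the patch.
using PropagationDirectionAlongFactor =
    std::function<PropagationDirection(int64_t factorIndex)>;

int main() {
  // What the deleted lambda did: ignore the factor index and return one
  // fixed direction for every factor.
  PropagationDirection direction = PropagationDirection::BOTH;
  PropagationDirectionAlongFactor uniform = [direction](int64_t) {
    return direction;
  };

  // What the new parameter additionally allows: a factor-dependent policy
  // (hypothetical example: only factor 0 is restricted to forward).
  PropagationDirectionAlongFactor perFactor = [](int64_t factorIndex) {
    return factorIndex == 0 ? PropagationDirection::FORWARD
                            : PropagationDirection::BOTH;
  };

  std::cout << static_cast<int>(uniform(7)) << ' '
            << static_cast<int>(perFactor(0)) << '\n';  // prints "3 1"
}
```

Callers that want the old uniform behavior keep passing a callback that ignores the factor index, as the call sites below do via `std::bind`.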
@@ -318,29 +314,14 @@ LogicalResult propagateTensorShardings(
   return success(anyUpdated);
 }
 
-// Same as the overload above, except there is a single operand and result.
-LogicalResult propagateTensorShardings(
-    const PropagationTensorParams& operandsParams,
-    const PropagationTensorParams& resultsParams,
-    OpShardingRuleAttr shardingRule, Operation* op,
-    const SymbolTable& symbolTable, PatternRewriter* rewriter,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap,
-    PropagationDirection direction = PropagationDirection::BOTH,
-    bool conservativePropagation = false) {
-  return propagateTensorShardings(
-      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
-      conservativePropagation, op, symbolTable, rewriter, shardingGroupMap);
-}
-
 // Same as the overload above, except the operand and result shardings are
 // extracted using `getSharding` and set using `setSharding`.
 LogicalResult propagateTensorShardings(
     ValueRange operands, ValueRange results, OpShardingRuleAttr shardingRule,
     Operation* op, const SymbolTable& symbolTable, PatternRewriter& rewriter,
+    PropagationDirectionAlongFactor directionAlongFactor,
     const FactorPropagation& factorPropagation,
     const ShardingGroupMap& shardingGroupMap,
-    PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
   SmallVector<TensorShardingAttr> operandsShardings = getShardings(operands);
   SmallVector<TensorShardingAttr> resultsShardings = getShardings(results);
@@ -357,9 +338,10 @@ LogicalResult propagateTensorShardings(
         setSharding(results[index], sharding);
       });
 
-  return propagateTensorShardings(
-      operandsParams, resultsParams, shardingRule, direction, factorPropagation,
-      conservativePropagation, op, symbolTable, &rewriter, shardingGroupMap);
+  return propagateTensorShardings(operandsParams, resultsParams, shardingRule,
+                                  directionAlongFactor, factorPropagation,
+                                  conservativePropagation, op, symbolTable,
+                                  &rewriter, shardingGroupMap);
 }
 
 // Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -406,10 +388,12 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
     (void)propagateTensorShardings(
         operandsParams, resultsParams,
         // Treat the sharding data flow b/w the `funcOp` terminator and func
-        // result attrs as an identity op. Create an equivalent sharding
-        // rule.
-        createIdentityShardingRule(tensorType), funcOp, symbolTable,
-        /*rewriter=*/nullptr, factorPropagation, shardingGroupMap);
+        // result attrs as an identity op. Create an equivalent sharding rule.
+        createIdentityShardingRule(tensorType),
+        std::bind(propagateAny, funcOp, std::placeholders::_1),
+        factorPropagation,
+        /*conservativePropagation=*/false, funcOp, symbolTable,
+        /*rewriter=*/nullptr, shardingGroupMap);
   }
   return success();
 }
@@ -455,19 +439,13 @@ class PropagateRegisteredOp : public RewritePattern {
         diag << "op doesn't have a registered sharding rule";
       });
     }
-    PropagationDirection direction = getDirectionToPropagate(op);
-    if (direction == PropagationDirection::NONE) {
-      // No need to continue to propagate if the direction is `NONE`, as
-      // neither operands nor results can be updated.
-      return rewriter.notifyMatchFailure(op, [](Diagnostic& diag) {
-        diag << "propagation direction on op is NONE";
-      });
-    }
+    PropagationDirectionAlongFactor directionAlongFactor =
+        std::bind(getDirectionToPropagate, op, std::placeholders::_1);
 
     return propagateTensorShardings(op->getOperands(), op->getResults(),
                                     shardingRule, op, symbolTable, rewriter,
-                                    factorPropagation, shardingGroupMap,
-                                    direction, conservativePropagation);
+                                    directionAlongFactor, factorPropagation,
+                                    shardingGroupMap, conservativePropagation);
   }
 
  private:
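Note: `std::bind(fn, op, std::placeholders::_1)` as used above fixes the `Operation*` argument and leaves the factor index open, turning a `GetDirectionToPropagateFn` into a `PropagationDirectionAlongFactor`. A compilable sketch with stand-in types (`Operation` below is a placeholder struct, not `mlir::Operation`):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };
struct Operation {};  // stand-in for mlir::Operation

using GetDirectionToPropagateFn =
    std::function<PropagationDirection(Operation*, int64_t)>;
using PropagationDirectionAlongFactor =
    std::function<PropagationDirection(int64_t)>;

PropagationDirection propagateAny(Operation*, int64_t) {
  return PropagationDirection::BOTH;
}

int main() {
  Operation op;
  // Fix the op argument, leave the factor index open, as in the patch.
  PropagationDirectionAlongFactor bound =
      std::bind(propagateAny, &op, std::placeholders::_1);
  // An equivalent capture-by-value lambda; either spelling works.
  PropagationDirectionAlongFactor lambda = [opPtr = &op](int64_t factorIndex) {
    return propagateAny(opPtr, factorIndex);
  };
  assert(bound(0) == lambda(0) && bound(0) == PropagationDirection::BOTH);
}
```

A side effect of this rewrite is that the early bail-out for `PropagationDirection::NONE` in `PropagateRegisteredOp` goes away: the direction is no longer a single per-op value that can be checked up front, since each factor may now answer differently.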
@@ -522,14 +500,19 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
               sharding, DataFlowShardingTransformType::kAfterEdgePropagation));
         });
 
+    // TODO(b/394390827). We may pass getDirectionToPropagate so we can decide
+    // to change the priority of data flow ops.
+    PropagationDirectionAlongFactor directionAlongFactor =
+        std::bind(propagateAny, dataFlowEdgeOp, std::placeholders::_1);
     return propagateTensorShardings(
         operandsParams, resultsParams,
         createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
                                    sources.size()),
-        PropagationDirection::BOTH, factorPropagation,
+        directionAlongFactor, factorPropagation,
        /*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable,
         &rewriter, shardingGroupMap);
   }
+
  private:
   const SymbolTable& symbolTable;
   const FactorPropagation& factorPropagation;
@@ -556,8 +539,9 @@ class PropagatePropagationBarrier
         propagationBarrierOp.getInput(), propagationBarrierOp.getResult(),
         createIdentityShardingRule(
             cast<RankedTensorType>(propagationBarrierOp.getType())),
-        propagationBarrierOp, symbolTable, rewriter, factorPropagation,
-        shardingGroupMap, propagationBarrierOp.getAllowedDirection());
+        propagationBarrierOp, symbolTable, rewriter,
+        [&](int64_t) { return propagationBarrierOp.getAllowedDirection(); },
+        factorPropagation, shardingGroupMap);
   }
 
  private:
@@ -604,7 +588,7 @@ bool allValidShapes(ModuleOp moduleOp) {
 
 }  // namespace
 
-PropagationDirection propagateAny(Operation*) {
+PropagationDirection propagateAny(Operation*, int64_t) {
   return PropagationDirection::BOTH;
 }
 
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.h b/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
index 732565dc..676f3392 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
@@ -17,6 +17,8 @@ limitations under the License.
 #define SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_BASIC_PROPAGATION_H_
 
 #include <cstdint>
+
+#include <functional>
 #include <memory>
 #include <string>
 #include <utility>
@@ -29,21 +31,22 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/transforms/propagation/passes.h"
 #include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
 #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
+#include "shardy/dialect/sdy/transforms/propagation/passes.h"
 #include "shardy/dialect/sdy/transforms/propagation/sharding_group_map.h"
 
 namespace mlir {
 namespace sdy {
 
 // A function that determines in which direction propagation should happen for a
-// given op.
+// given op and factor index.
 using GetDirectionToPropagateFn =
-    std::function<PropagationDirection(Operation*)>;
+    std::function<PropagationDirection(Operation*, int64_t)>;
 
-// A function that returns `PropagationDirection::BOTH` for all operations.
-PropagationDirection propagateAny(Operation* op);
+// A function that returns `PropagationDirection::BOTH` for all operations and
+// factor indices.
+PropagationDirection propagateAny(Operation* op, int64_t factorIndex);
 
 // The implementation class for the basic propagation pass.
 //
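Note: the header change above is API-breaking for out-of-tree callers: any callback written against the old one-argument `GetDirectionToPropagateFn` must now accept the factor index. A hedged migration sketch with stand-in types (the `onlyForward` callback is hypothetical user code, not part of Shardy):

```cpp
#include <cstdint>
#include <functional>

enum class PropagationDirection { NONE, FORWARD, BACKWARD, BOTH };
struct Operation {};  // stand-in for mlir::Operation

// New signature from basic_propagation.h.
using GetDirectionToPropagateFn =
    std::function<PropagationDirection(Operation*, int64_t)>;

// Old-style, per-op-only callback written before this patch (hypothetical).
PropagationDirection onlyForward(Operation*) {
  return PropagationDirection::FORWARD;
}

int main() {
  // Adapt by accepting and ignoring the factor index when the decision is
  // still per-op only.
  GetDirectionToPropagateFn adapted =
      [](Operation* op, int64_t /*factorIndex*/) { return onlyForward(op); };
  Operation op;
  (void)adapted(&op, /*factorIndex=*/0);
}
```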
diff --git a/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc b/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc
index 4528ac56..8fa28948 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc
@@ -46,34 +46,11 @@ namespace sdy {
 namespace {
 
 // A function that determines in which direction propagation should happen for a
-// given op.
-using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*);
+// given op and factor index.
+using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*,
+                                                              int64_t);
 
-template <typename... OpTs>
-PropagationDirection isaBoth(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::BOTH
-                          : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isNotABoth(Operation* op) {
-  return !isa<OpTs...>(op) ? PropagationDirection::BOTH
-                           : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isaForward(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::FORWARD
-                          : PropagationDirection::NONE;
-}
-
-template <typename... OpTs>
-PropagationDirection isaBackward(Operation* op) {
-  return isa<OpTs...>(op) ? PropagationDirection::BACKWARD
-                          : PropagationDirection::NONE;
-}
-
-PropagationDirection isPassThrough(Operation* op) {
+PropagationDirection isPassThrough(Operation* op, int64_t) {
   if (isElementwise(op) ||
       isa<stablehlo::ReshapeOp, stablehlo::TransposeOp>(op)) {
     return PropagationDirection::BOTH;
@@ -93,18 +70,19 @@ constexpr std::array opPropagationSchedule = {
 // a caller. It will return the intersection of the passed in
 // `getDirectionToPropagate` and the op based direction.
 GetDirectionToPropagateFn getOpBasedDirectionToPropagate(
-    int64_t currentOpPriority,
+    int64_t currentPriority,
     GetDirectionToPropagateFn getDirectionToPropagate) {
-  return [currentOpPriority, getDirectionToPropagate](Operation* op) {
+  return [currentPriority, getDirectionToPropagate](Operation* op,
+                                                    int64_t factorIndex) {
     PropagationDirection opBasedDirection = std::accumulate(
         opPropagationSchedule.begin(),
-        opPropagationSchedule.begin() + currentOpPriority + 1,
+        opPropagationSchedule.begin() + currentPriority + 1,
         PropagationDirection::NONE,
         [&](PropagationDirection acc, GetDirectionToPropagateFnPtr dirFn) {
-          return unionOfPropagationDirections(acc, dirFn(op));
+          return unionOfPropagationDirections(acc, dirFn(op, factorIndex));
         });
-    return intersectionOfPropagationDirections(opBasedDirection,
-                                               getDirectionToPropagate(op));
+    return intersectionOfPropagationDirections(
+        opBasedDirection, getDirectionToPropagate(op, factorIndex));
   };
 }
 
@@ -129,13 +107,13 @@ LogicalResult OpPriorityPropagationPassImpl::propagate(
     return AggressivePropagationPassImpl::propagate(
         moduleOp, symbolTable, shardingGroupMap, getDirectionToPropagate);
   }
-  // Reset currentOpPriority to 0. Before running the pass. This same instance
+  // Reset currentPriority to 0 before running the pass. This same instance
   // could have been run earlier already (e.g. with a different user priority).
-  for (int64_t currentOpPriority = 0;
-       currentOpPriority < opPropagationSchedule.size(); currentOpPriority++) {
+  for (int64_t currentPriority = 0;
+       currentPriority < opPropagationSchedule.size(); currentPriority++) {
     if (AggressivePropagationPassImpl::propagate(
             moduleOp, symbolTable, shardingGroupMap,
-            getOpBasedDirectionToPropagate(currentOpPriority,
+            getOpBasedDirectionToPropagate(currentPriority,
                                            getDirectionToPropagate))
             .failed()) {
       return failure();
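Note: `getOpBasedDirectionToPropagate` keeps the same fold after this change, now threading the factor index through: union the directions of every schedule entry up to `currentPriority`, then intersect with the caller-provided callback. A self-contained sketch of that fold; the bitmask encoding of `PropagationDirection` (`BOTH == FORWARD | BACKWARD`) and the two-entry schedule are assumptions for illustration, not the Shardy definitions:

```cpp
#include <array>
#include <cstdint>
#include <numeric>

enum class PropagationDirection : uint8_t {
  NONE = 0, FORWARD = 1, BACKWARD = 2, BOTH = 3
};
struct Operation {};  // stand-in for mlir::Operation

// Assumed bitmask semantics for the union/intersection helpers.
PropagationDirection unionOf(PropagationDirection a, PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<uint8_t>(a) |
                                           static_cast<uint8_t>(b));
}
PropagationDirection intersectionOf(PropagationDirection a,
                                    PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<uint8_t>(a) &
                                           static_cast<uint8_t>(b));
}

using FnPtr = PropagationDirection (*)(Operation*, int64_t);
PropagationDirection forwardOnly(Operation*, int64_t) {
  return PropagationDirection::FORWARD;
}
PropagationDirection backwardOnly(Operation*, int64_t) {
  return PropagationDirection::BACKWARD;
}
constexpr std::array<FnPtr, 2> schedule = {forwardOnly, backwardOnly};

int main() {
  Operation op;
  int64_t currentPriority = 1;  // include both schedule entries
  // Union over all schedule entries up to the current priority.
  PropagationDirection opBased = std::accumulate(
      schedule.begin(), schedule.begin() + currentPriority + 1,
      PropagationDirection::NONE,
      [&](PropagationDirection acc, FnPtr dirFn) {
        return unionOf(acc, dirFn(&op, /*factorIndex=*/0));
      });
  // opBased == BOTH; intersecting with a FORWARD-only caller yields FORWARD.
  PropagationDirection result =
      intersectionOf(opBased, PropagationDirection::FORWARD);
  return result == PropagationDirection::FORWARD ? 0 : 1;
}
```

The union-then-intersect order means the schedule can only widen what earlier priorities allowed, while the caller's callback retains final veto power per op and per factor.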