#sdy Update GetDirectionToPropagateFn such that it determines in which direction propagation should happen for an op and factor index. #359

Merged
1 commit merged on Feb 5, 2025
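At a glance, the change is to the `GetDirectionToPropagateFn` callback signature, which now receives the factor index in addition to the op. The following is an editorial sketch of the interface change and a hypothetical policy written against it, not code from this PR; the real typedef is in basic_propagation.h in the diff below.

#include <cstdint>

#include "shardy/dialect/sdy/transforms/propagation/basic_propagation.h"

// Before this PR the callback saw only the op:
//   std::function<PropagationDirection(Operation*)>
// After this PR it also receives the factor index:
//   std::function<PropagationDirection(Operation* op, int64_t factorIndex)>
//
// Hypothetical policy written against the new signature: propagate factor 0
// in both directions and every other factor forward only.
mlir::sdy::PropagationDirection propagateFirstFactorBoth(mlir::Operation* op,
                                                         int64_t factorIndex) {
  return factorIndex == 0 ? mlir::sdy::PropagationDirection::BOTH
                          : mlir::sdy::PropagationDirection::FORWARD;
}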
70 changes: 27 additions & 43 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -245,7 +245,8 @@ void updateTensorShardings(const PropagationTensorParams& operandsParams,
LogicalResult propagateTensorShardings(
const PropagationTensorParams& operandsParams,
const PropagationTensorParams& resultsParams,
OpShardingRuleAttr shardingRule, PropagationDirection direction,
OpShardingRuleAttr shardingRule,
PropagationDirectionAlongFactor directionAlongFactor,
const FactorPropagation& factorPropagation, bool conservativePropagation,
Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
ShardingGroupMap shardingGroupMap) {
@@ -280,11 +281,6 @@ LogicalResult propagateTensorShardings(
ShardingProjection shardingProjection = ShardingProjection::build(
operandsParams.shardings, resultsParams.shardings, shardingRule, mesh);
bool anyUpdated = false;

PropagationDirectionAlongFactor directionAlongFactor = [direction](int64_t) {
return direction;
};

auto updateShardings = [&]() {
auto [updateOperand, updateResult] =
factorPropagation.propagateFactorShardings(
@@ -318,29 +314,14 @@ LogicalResult propagateTensorShardings(
return success(anyUpdated);
}

// Same as the overload above, except there is a single operand and result.
LogicalResult propagateTensorShardings(
const PropagationTensorParams& operandsParams,
const PropagationTensorParams& resultsParams,
OpShardingRuleAttr shardingRule, Operation* op,
const SymbolTable& symbolTable, PatternRewriter* rewriter,
const FactorPropagation& factorPropagation,
const ShardingGroupMap& shardingGroupMap,
PropagationDirection direction = PropagationDirection::BOTH,
bool conservativePropagation = false) {
return propagateTensorShardings(
operandsParams, resultsParams, shardingRule, direction, factorPropagation,
conservativePropagation, op, symbolTable, rewriter, shardingGroupMap);
}

// Same as the overload above, except the operand and result shardings are
// extracted using `getSharding` and set using `setSharding`.
LogicalResult propagateTensorShardings(
ValueRange operands, ValueRange results, OpShardingRuleAttr shardingRule,
Operation* op, const SymbolTable& symbolTable, PatternRewriter& rewriter,
PropagationDirectionAlongFactor directionAlongFactor,
const FactorPropagation& factorPropagation,
const ShardingGroupMap& shardingGroupMap,
PropagationDirection direction = PropagationDirection::BOTH,
bool conservativePropagation = false) {
SmallVector<TensorShardingAttr> operandsShardings = getShardings(operands);
SmallVector<TensorShardingAttr> resultsShardings = getShardings(results);
@@ -357,9 +338,10 @@ LogicalResult propagateTensorShardings(
setSharding(results[index], sharding);
});

return propagateTensorShardings(
operandsParams, resultsParams, shardingRule, direction, factorPropagation,
conservativePropagation, op, symbolTable, &rewriter, shardingGroupMap);
return propagateTensorShardings(operandsParams, resultsParams, shardingRule,
directionAlongFactor, factorPropagation,
conservativePropagation, op, symbolTable,
&rewriter, shardingGroupMap);
}

// Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -406,10 +388,12 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
(void)propagateTensorShardings(
operandsParams, resultsParams,
// Treat the sharding data flow b/w the `funcOp` terminator and func
// result attrs as an identity op. Create an equivalent sharding
// rule.
createIdentityShardingRule(tensorType), funcOp, symbolTable,
/*rewriter=*/nullptr, factorPropagation, shardingGroupMap);
// result attrs as an identity op. Create an equivalent sharding rule.
createIdentityShardingRule(tensorType),
std::bind(propagateAny, funcOp, std::placeholders::_1),
factorPropagation,
/*conservativePropagation=*/false, funcOp, symbolTable,
/*rewriter=*/nullptr, shardingGroupMap);
}
return success();
}
@@ -455,19 +439,13 @@ class PropagateRegisteredOp : public RewritePattern {
diag << "op doesn't have a registered sharding rule";
});
}
PropagationDirection direction = getDirectionToPropagate(op);
if (direction == PropagationDirection::NONE) {
// No need to continue to propagate if the direction is `NONE`, as
// neither operands nor results can be updated.
return rewriter.notifyMatchFailure(op, [](Diagnostic& diag) {
diag << "propagation direction on op is NONE";
});
}

PropagationDirectionAlongFactor directionAlongFactor =
std::bind(getDirectionToPropagate, op, std::placeholders::_1);
return propagateTensorShardings(op->getOperands(), op->getResults(),
shardingRule, op, symbolTable, rewriter,
factorPropagation, shardingGroupMap,
direction, conservativePropagation);
directionAlongFactor, factorPropagation,
shardingGroupMap, conservativePropagation);
}

private:
@@ -522,14 +500,19 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
sharding, DataFlowShardingTransformType::kAfterEdgePropagation));
});

// TODO(b/394390827). We may pass getDirectionToPropagate so we can decide
// to change the priority of data flow ops.
PropagationDirectionAlongFactor directionAlongFactor =
std::bind(propagateAny, dataFlowEdgeOp, std::placeholders::_1);
return propagateTensorShardings(
operandsParams, resultsParams,
createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
sources.size()),
PropagationDirection::BOTH, factorPropagation,
directionAlongFactor, factorPropagation,
/*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable,
&rewriter, shardingGroupMap);
}

private:
const SymbolTable& symbolTable;
const FactorPropagation& factorPropagation;
@@ -556,8 +539,9 @@ class PropagatePropagationBarrier
propagationBarrierOp.getInput(), propagationBarrierOp.getResult(),
createIdentityShardingRule(
cast<RankedTensorType>(propagationBarrierOp.getType())),
propagationBarrierOp, symbolTable, rewriter, factorPropagation,
shardingGroupMap, propagationBarrierOp.getAllowedDirection());
propagationBarrierOp, symbolTable, rewriter,
[&](int64_t) { return propagationBarrierOp.getAllowedDirection(); },
factorPropagation, shardingGroupMap);
}

private:
@@ -604,7 +588,7 @@ bool allValidShapes(ModuleOp moduleOp) {

} // namespace

PropagationDirection propagateAny(Operation*) {
PropagationDirection propagateAny(Operation*, int64_t) {
return PropagationDirection::BOTH;
}

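A note on the call sites above: `propagateTensorShardings` now takes a per-factor callback, and patterns such as `PropagateRegisteredOp` build it by binding the op into the op-and-factor function. Below is a minimal sketch of the two equivalent ways to do that. It assumes `PropagationDirectionAlongFactor` has the shape of the removed lambda in the diff (a `std::function<PropagationDirection(int64_t)>`); the real alias lives in the SDY sources.

#include <cstdint>
#include <functional>

#include "shardy/dialect/sdy/transforms/propagation/basic_propagation.h"

// Assumed shape of the per-factor callback, for illustration only.
using PropagationDirectionAlongFactor =
    std::function<mlir::sdy::PropagationDirection(int64_t)>;

// As in PropagateRegisteredOp above: bind the op into the callback.
PropagationDirectionAlongFactor viaBind(
    mlir::Operation* op, mlir::sdy::GetDirectionToPropagateFn getDirection) {
  return std::bind(getDirection, op, std::placeholders::_1);
}

// An equivalent capture-based form.
PropagationDirectionAlongFactor viaLambda(
    mlir::Operation* op, mlir::sdy::GetDirectionToPropagateFn getDirection) {
  return [op, getDirection](int64_t factorIndex) {
    return getDirection(op, factorIndex);
  };
}

Either form yields the same per-factor callback for a fixed op.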
13 changes: 8 additions & 5 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.h
@@ -17,6 +17,8 @@ limitations under the License.
#define SHARDY_DIALECT_SDY_TRANSFORMS_PROPAGATION_BASIC_PROPAGATION_H_

#include <stdbool.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
@@ -29,21 +31,22 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/passes.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/passes.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_group_map.h"

namespace mlir {
namespace sdy {

// A function that determines in which direction propagation should happen for a
// given op.
// given op and factor index.
using GetDirectionToPropagateFn =
std::function<PropagationDirection(Operation*)>;
std::function<PropagationDirection(Operation* op, int64_t factorIndex)>;

// A function that returns `PropagationDirection::BOTH` for all operations.
PropagationDirection propagateAny(Operation* op);
// A function that returns `PropagationDirection::BOTH` for all operations and
// factor indices.
PropagationDirection propagateAny(Operation* op, int64_t factorIndex);

// The implementation class for the basic propagation pass.
//
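Given the updated typedef and `propagateAny` above, a caller-side policy can now vary by factor as well as by op. A hypothetical sketch follows; the op type and the factor index used here are illustrative, not taken from this PR.

#include <cstdint>

#include "mlir/IR/Operation.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_propagation.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical caller-side policy: keep full propagation everywhere, except
// that stablehlo.dot_general only propagates forward along one factor (the
// factor index 2 used here is purely illustrative).
mlir::sdy::PropagationDirection dotGeneralForwardOnOneFactor(
    mlir::Operation* op, int64_t factorIndex) {
  if (mlir::isa<mlir::stablehlo::DotGeneralOp>(op) && factorIndex == 2) {
    return mlir::sdy::PropagationDirection::FORWARD;
  }
  // propagateAny returns BOTH for every op and factor index.
  return mlir::sdy::propagateAny(op, factorIndex);
}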
shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc
@@ -46,34 +46,11 @@ namespace sdy {
namespace {

// A function that determines in which direction propagation should happen for a
// given op.
using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*);
// given op and factor index.
using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*,
int64_t);

template <typename... OpTs>
PropagationDirection isaBoth(Operation* op) {
return isa<OpTs...>(op) ? PropagationDirection::BOTH
: PropagationDirection::NONE;
}

template <typename... OpTs>
PropagationDirection isNotABoth(Operation* op) {
return !isa<OpTs...>(op) ? PropagationDirection::BOTH
: PropagationDirection::NONE;
}

template <typename... OpTs>
PropagationDirection isaForward(Operation* op) {
return isa<OpTs...>(op) ? PropagationDirection::FORWARD
: PropagationDirection::NONE;
}

template <typename... OpTs>
PropagationDirection isaBackward(Operation* op) {
return isa<OpTs...>(op) ? PropagationDirection::BACKWARD
: PropagationDirection::NONE;
}

PropagationDirection isPassThrough(Operation* op) {
PropagationDirection isPassThrough(Operation* op, int64_t) {
if (isElementwise(op) ||
isa<stablehlo::ReshapeOp, stablehlo::TransposeOp>(op)) {
return PropagationDirection::BOTH;
@@ -93,18 +70,19 @@ constexpr std::array<GetDirectionToPropagateFnPtr, 2> opPropagationSchedule = {
// a caller. It will return the intersection of the passed in
// `getDirectionToPropagate` and the op based direction.
GetDirectionToPropagateFn getOpBasedDirectionToPropagate(
int64_t currentOpPriority,
int64_t currentPriority,
GetDirectionToPropagateFn getDirectionToPropagate) {
return [currentOpPriority, getDirectionToPropagate](Operation* op) {
return [currentPriority, getDirectionToPropagate](Operation* op,
int64_t factorIndex) {
PropagationDirection opBasedDirection = std::accumulate(
opPropagationSchedule.begin(),
opPropagationSchedule.begin() + currentOpPriority + 1,
opPropagationSchedule.begin() + currentPriority + 1,
PropagationDirection::NONE,
[&](PropagationDirection acc, GetDirectionToPropagateFnPtr dirFn) {
return unionOfPropagationDirections(acc, dirFn(op));
return unionOfPropagationDirections(acc, dirFn(op, factorIndex));
});
return intersectionOfPropagationDirections(opBasedDirection,
getDirectionToPropagate(op));
return intersectionOfPropagationDirections(
opBasedDirection, getDirectionToPropagate(op, factorIndex));
};
}

@@ -129,13 +107,13 @@ LogicalResult OpPriorityPropagationPassImpl::propagate(
return AggressivePropagationPassImpl::propagate(
moduleOp, symbolTable, shardingGroupMap, getDirectionToPropagate);
}
// Reset currentOpPriority to 0. Before running the pass. This same instance
// Reset currentPriority to 0 before running the pass. This same instance
// could have been run earlier already (e.g. with a different user priority).
for (int64_t currentOpPriority = 0;
currentOpPriority < opPropagationSchedule.size(); currentOpPriority++) {
for (int64_t currentPriority = 0;
currentPriority < opPropagationSchedule.size(); currentPriority++) {
if (AggressivePropagationPassImpl::propagate(
moduleOp, symbolTable, shardingGroupMap,
getOpBasedDirectionToPropagate(currentOpPriority,
getOpBasedDirectionToPropagate(currentPriority,
getDirectionToPropagate))
.failed()) {
return failure();
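For the op-priority schedule above, the composed direction is the union of every schedule entry up to the current priority, intersected with the direction from the caller's callback. The sketch below illustrates that composition with simplified stand-ins: the bitmask enum and the `unionOf` / `intersectionOf` helpers are assumptions for readability, while the real enum and the `unionOfPropagationDirections` / `intersectionOfPropagationDirections` helpers live in the SDY propagation utilities.

#include <array>
#include <cstdint>
#include <numeric>

// Simplified stand-in: NONE = 0, FORWARD = 1, BACKWARD = 2,
// BOTH = FORWARD | BACKWARD.
enum class PropagationDirection : int {
  NONE = 0,
  FORWARD = 1,
  BACKWARD = 2,
  BOTH = 3
};

PropagationDirection unionOf(PropagationDirection a, PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<int>(a) |
                                           static_cast<int>(b));
}

PropagationDirection intersectionOf(PropagationDirection a,
                                    PropagationDirection b) {
  return static_cast<PropagationDirection>(static_cast<int>(a) &
                                           static_cast<int>(b));
}

// One schedule entry: a direction for a given factor index (the op argument
// is omitted to keep the sketch self-contained).
using ScheduleEntryFn = PropagationDirection (*)(int64_t factorIndex);

// Mirrors getOpBasedDirectionToPropagate: union the schedule entries up to
// and including `currentPriority`, then intersect with the caller's direction.
PropagationDirection composeDirection(
    const std::array<ScheduleEntryFn, 2>& schedule, int64_t currentPriority,
    PropagationDirection callerDirection, int64_t factorIndex) {
  PropagationDirection scheduleDirection = std::accumulate(
      schedule.begin(), schedule.begin() + currentPriority + 1,
      PropagationDirection::NONE,
      [&](PropagationDirection acc, ScheduleEntryFn entry) {
        return unionOf(acc, entry(factorIndex));
      });
  return intersectionOf(scheduleDirection, callerDirection);
}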