From 4fde0906c63ef9035660c36b8f33884c1fe68ee6 Mon Sep 17 00:00:00 2001
From: Bart Chrzaszcz
Date: Mon, 23 Dec 2024 12:17:33 -0800
Subject: [PATCH] #sdy #debug support debugging tool to save the edge source shardings.

This works by having each `Operation*` save a source/target edge, which
always involves at least one operand. This is because results can be used
several times, but an operand's value only ever has one defining op. So these
edges always look "backwards" - never forwards towards a use of a result. As
such, `FuncOp` args don't contain edge source information: only the ops that
use them do.

Similarly, since only the uses save the edge, for `FuncOp` results which have
been updated, we save the edges in each result's `resultAttr`. If the
function has multiple results, each edge on the func result attributes will
have index 0. This is because propagation is run separately on each returned
result. Giving each edge the right index would only make sense if the
`sdy.propagation_edges` were saved as a single top-level attribute on the
func; since one dictionary is saved per result, an index of 0 makes the most
sense.

PiperOrigin-RevId: 709122522
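To make the backwards-edge model concrete, here is a minimal sketch (the `@example` function is made up; the attribute values mirror the `duplicate_operands` case in the new `edge_shardings.mlir` test, and the op-level edges on the add are elided):

```mlir
sdy.mesh @mesh = <["a"=2]>

// The user-specified result sharding is the edge source; the return operand
// that received axis "a" is the target, so the edge points backwards.
func.func @example(%arg0: tensor<8xf32>) -> (tensor<8xf32> {
    sdy.propagation_edges = {a = [{propagation_step = 0 : i64,
                                   source = "result: 0",
                                   target = "operand: 0"}]},
    sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) {
  %0 = stablehlo.add %arg0, %arg0 : tensor<8xf32>
  return %0 : tensor<8xf32>
}
```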
---
 shardy/dialect/sdy/ir/constants.h             |  14 +
 shardy/dialect/sdy/transforms/export/BUILD    |   1 +
 .../dialect/sdy/transforms/export/passes.td   |  13 +
 .../transforms/export/sink_data_flow_edges.cc |  51 +-
 .../propagation/basic_propagation.cc          |   9 +-
 .../propagation/basic_propagation.h           |  10 +-
 .../propagation/debugging/source_sharding.cc  | 548 +++++++++++++-----
 .../propagation/debugging/source_sharding.h   |  74 ++-
 .../debugging/test/edge_shardings.mlir        | 233 ++++++++
 .../debugging/test/sharding_origins.mlir      |   2 +-
 .../sdy/transforms/propagation/passes.h       |   2 +-
 11 files changed, 731 insertions(+), 226 deletions(-)
 create mode 100644 shardy/dialect/sdy/transforms/propagation/debugging/test/edge_shardings.mlir

diff --git a/shardy/dialect/sdy/ir/constants.h b/shardy/dialect/sdy/ir/constants.h
index 0be03f5c..da2c7a41 100644
--- a/shardy/dialect/sdy/ir/constants.h
+++ b/shardy/dialect/sdy/ir/constants.h
@@ -35,16 +35,30 @@ inline constexpr StringRef kShardingRuleAttr = "sdy.sharding_rule";
 // caused a value to be sharded a certain way.
 inline constexpr StringRef kShardingOriginsAttr = "sdy.sharding_origins";
 
+// Attribute name for saving which operand/result sharding of an op caused its
+// value to be sharded a certain way.
+inline constexpr StringRef kPropagationEdgesAttr = "sdy.propagation_edges";
+
 // Attribute name like `kShardingOriginsAttr` but for
 // `ShardableDataFlowOpInterface` op block arguments.
 inline constexpr StringRef kBlockArgShardingOriginsAttr =
     "sdy.block_arg_sharding_origins";
 
+// Attribute name like `kPropagationEdgesAttr` but for
+// `ShardableDataFlowOpInterface` op block arguments.
+inline constexpr StringRef kBlockArgPropagationEdgesAttr =
+    "sdy.block_arg_propagation_edges";
+
 // Attribute name like `kShardingOriginsAttr` but for
 // `ShardableDataFlowOpInterface` op results.
 inline constexpr StringRef kResultShardingOriginsAttr =
     "sdy.result_sharding_origins";
 
+// Attribute name like `kPropagationEdgesAttr` but for
+// `ShardableDataFlowOpInterface` op results.
+inline constexpr StringRef kResultPropagationEdgesAttr =
+    "sdy.result_propagation_edges";
+
 // Attribute name for the unique name of a sharding origin. Is either an
 // `sdy.sharding_constraint`, or `sdy.ManualComputationOp` input/output.
 inline constexpr StringRef kShardingOriginNameAttr = "sdy.sharding_origin_name";
diff --git a/shardy/dialect/sdy/transforms/export/BUILD b/shardy/dialect/sdy/transforms/export/BUILD
index 20d67dca..767b1ab6 100644
--- a/shardy/dialect/sdy/transforms/export/BUILD
+++ b/shardy/dialect/sdy/transforms/export/BUILD
@@ -58,6 +58,7 @@ cc_library(
         "//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_registry",
         "//shardy/dialect/sdy/transforms/propagation:sharding_projection",
         "//shardy/dialect/sdy/transforms/propagation:utils",
+        "//shardy/dialect/sdy/transforms/propagation/debugging:source_sharding",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td
index e454dd7e..9a537374 100644
--- a/shardy/dialect/sdy/transforms/export/passes.td
+++ b/shardy/dialect/sdy/transforms/export/passes.td
@@ -28,6 +28,19 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
   //TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.
+
+  let options = [
+    Option<"sinkDebugShardingOrigins", "sink-debug-sharding-origins", "bool",
+           /*default=*/"false",
+           "Whether to sink the debug sharding origins info. See "
+           "`debug-sharding-origins` option in propagation for more info.">,
+    Option<"sinkDebugPropagationEdgeSharding",
+           "sink-debug-propagation-edge-sharding", "bool",
+           /*default=*/"false",
+           "Whether to sink the debug propagation edge sharding info. See "
+           "`debug-propagation-edge-sharding` option in propagation for more "
+           "info.">
+  ];
 }
 
 def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
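For a usage sketch, the two sink options pair with the corresponding propagation options; the new `edge_shardings.mlir` test below drives the whole pipeline with the following RUN line (split here for readability):

```mlir
// RUN: sdy_opt %s -split-input-file -sdy-add-data-flow-edges \
// RUN:   -sdy-aggressive-propagate=debug-propagation-edge-sharding=true \
// RUN:   -sdy-sink-data-flow-edges="sink-debug-propagation-edge-sharding=true" \
// RUN:   2>&1 | FileCheck %s
```

The `sink-debug-sharding-origins` option is exercised the same way by the updated `sharding_origins.mlir` RUN line.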
diff --git a/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc b/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc
index 6873d22b..e54a0cf6 100644
--- a/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc
+++ b/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc
@@ -30,9 +30,10 @@ limitations under the License.
 #include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"  // IWYU pragma: keep
 #include "mlir/Support/LLVM.h"
-#include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
+#include "shardy/dialect/sdy/transforms/export/passes.h"  // IWYU pragma: keep
+#include "shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h"
 
 namespace mlir {
 namespace sdy {
@@ -81,44 +82,6 @@ SmallVector<TensorShardingAttr> getShardingsFromDataFlowEdges(
   return shardings;
 }
 
-// Saves an array of all the origin sharding dictionaries for the given
-// `edgeOwners` on `op`. If non exist, nothing is saved.
-//
-// For debugging the origin shardings, we want to preserve the origin sharding
-// dictionaries from the `DataFlowEdgeOp`s on the owning op so that they are
-// preserved after the propagation pipeline.
-//
-// See the `debug-sharding-origins` config on propagation for more details.
-//
-// TODO(b/388458831): add `saveDebugPropagationInfo` to the pass and pass it in
-// here. Can then reserve the right size for `originShardingDicts` and not need
-// the `exists` boolean.
-void buildOriginShardingDictsFromDataFlowEdges(ValueRange edgeOwners,
-                                               Operation* op,
-                                               StringRef attrName,
-                                               IRRewriter& rewriter) {
-  SmallVector<Attribute> originShardingDicts;
-  // TODO(b/388458831): pass through a boolean indicating whether the origin
-  // sharding debug information is enabled.
-  bool exists = false;
-  for (Value edgeOwner : edgeOwners) {
-    DictionaryAttr dict;
-    if (auto dataFlowEdgeOp = DataFlowEdgeOp::lookup(edgeOwner)) {
-      dict =
-          dataFlowEdgeOp->getAttrOfType<DictionaryAttr>(kShardingOriginsAttr);
-    }
-    if (!dict) {
-      dict = rewriter.getDictionaryAttr({});
-    } else {
-      exists = true;
-    }
-    originShardingDicts.push_back(dict);
-  }
-  if (exists) {
-    op->setAttr(attrName, rewriter.getArrayAttr(originShardingDicts));
-  }
-}
-
 struct SinkDataFlowEdgesPass
     : public impl::SinkDataFlowEdgesPassBase<SinkDataFlowEdgesPass> {
   using SinkDataFlowEdgesPassBase::SinkDataFlowEdgesPassBase;
@@ -152,8 +115,9 @@ struct SinkDataFlowEdgesPass
           shardableDataFlowOp.setBlockArgumentEdgeOwnerShardings(
               blockArgShardings);
         }
-        buildOriginShardingDictsFromDataFlowEdges(
-            blockArgOwners, op, kBlockArgShardingOriginsAttr, rewriter);
+        saveDebugInfoDictsFromDataFlowEdges(
+            blockArgOwners, op, sinkDebugShardingOrigins,
+            sinkDebugPropagationEdgeSharding, EdgeNodeType::OPERAND, rewriter);
 
         ResultRange resultOwners = shardableDataFlowOp.getOpResultEdgeOwners();
         if (SmallVector<TensorShardingAttr> resultShardings =
                 getShardingsFromDataFlowEdges(resultOwners);
            !resultShardings.empty()) {
          shardableDataFlowOp.setOpResultEdgeOwnerShardings(resultShardings);
        }
-        buildOriginShardingDictsFromDataFlowEdges(
-            resultOwners, op, kResultShardingOriginsAttr, rewriter);
+        saveDebugInfoDictsFromDataFlowEdges(
+            resultOwners, op, sinkDebugShardingOrigins,
+            sinkDebugPropagationEdgeSharding, EdgeNodeType::RESULT, rewriter);
         return WalkResult::advance();
       });
  }
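To show where the sunk dictionaries end up: one dictionary per block-argument/result edge owner is collected into an array attribute on the owning op. A trimmed sketch on an `sdy.manual_computation` (attribute values copied from the `manual_computation_manual_axes` test below; the `...` elision is mine):

```mlir
%1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a", ?}, {"b", ?}, {?}]>]
    out_shardings=[<@mesh, [{"a", ?}, {"b", ?}, {?}]>] manual_axes={"a"}
    (%arg1: tensor<16x32x32xf32>) {
  ...
} {
  sdy.block_arg_propagation_edges = [{
      a = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}],
      b = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}]}],
  sdy.result_propagation_edges = [{
      b = [{propagation_step = 3 : i64, source = "operand: 0", target = "result: 0"}]}]
} : (tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
```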
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
index 05a04e52..07414a38 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -300,8 +300,8 @@ LogicalResult propagateTensorShardings(
   if (context->hasActionHandler()) {
     context->executeAction<SourceShardingAction>(
         updateShardings,
-        /*IRUnits=*/{op}, operandsParams.tensors, resultsParams.tensors, mesh,
-        shardingRule, shardingProjection);
+        /*IRUnits=*/{op}, op, operandsParams.tensors, resultsParams.tensors,
+        mesh, shardingRule, shardingProjection, anyUpdated);
   } else {
     updateShardings();
   }
@@ -649,7 +649,8 @@ void BasicPropagationPassImpl::runOnOperation() {
   MLIRContext& context = getContext();
 
   // Prepare debugging handler for sharding origins and edge sources.
-  ShardingDebugMappings mappings(debugShardingOrigins, debugEdgeSourceSharding);
+  ShardingDebugMappings mappings(debugShardingOrigins,
+                                 debugPropagationEdgeSharding);
   SourceShardingHandler handler(&mappings);
   handler.prepareHandler(moduleOp);
@@ -683,7 +684,7 @@ void BasicPropagationPassImpl::setPropagationOptions(
   dumpDirectory = options.dumpDirectory.str();
   conservativePropagation = options.conservativePropagation;
   debugShardingOrigins = options.debugShardingOrigins;
-  debugEdgeSourceSharding = options.debugEdgeSourceSharding;
+  debugPropagationEdgeSharding = options.debugPropagationEdgeSharding;
 }
 
 std::unique_ptr<Pass> createBasicPropagationPass(
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.h b/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
index 676f3392..7b989353 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.h
@@ -126,12 +126,12 @@ class BasicPropagationPassImpl : public OperationPass<ModuleOp> {
                      "before propagation."),
       llvm::cl::init(false)};
 
-  Option<bool> debugEdgeSourceSharding{
-      *this, "debug-edge-source-sharding",
-      llvm::cl::desc(
-          "whether to save information about the edge source of a sharding "
-          "on the MLIR module. These are from which operand/result a sharding "
-          "was propagated."),
+  Option<bool> debugPropagationEdgeSharding{
+      *this, "debug-propagation-edge-sharding",
+      llvm::cl::desc(
+          "whether to save information about the SSA value edges along which "
+          "shardings were propagated in the MLIR module, i.e., from which "
+          "operand/result a sharding was propagated to a given op."),
       llvm::cl::init(false)};
 
  private:
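As a sketch of what this option ends up recording, here is a fragment of the expected func signature from the `input_output_source_sharding` test below - axis "b" flowed backwards out of the user-specified result sharding at step 0, while "a" and "c" arrived from operand 0 at step 2:

```mlir
) -> (tensor<8x8x8xf32> {
    sdy.propagation_edges = {
      a = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}],
      b = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}],
      c = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}]},
    sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {"c", ?}]>}) {
```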
diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
index cdbd0677..e45a6b2a 100644
--- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <cstdint>
 #include <optional>
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -34,7 +35,9 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
@@ -47,8 +50,8 @@ namespace sdy {
 namespace {
 
 // The map from factor to edge source for operands and results.
-struct FactorsToEdgeSourceMap {
-  llvm::SmallVector<AxisToEdgeSourceMap> operands, results;
+struct FactorsToEdgeMap {
+  llvm::SmallVector<AxisToEdgeMap> operands, results;
 };
 
 // Finds what operand/result the new sharding axes came from for a given
@@ -73,123 +76,120 @@ std::optional<int64_t> findNewAxisRefMatch(
 //
 // This only saves any newly introduced factor shardings, not any pre-existing
 // ones. So if no operand/result sharding changes, the map will be empty.
-FactorsToEdgeSourceMap createSourceMap(
+FactorsToEdgeMap createSourceMap(
     const ShardingProjection& oldShardingProjection,
     const ShardingProjection& newShardingProjection,
-    OpShardingRuleAttr shardingRule, MeshAttr mesh) {
-  FactorsToEdgeSourceMap axisToEdgeSourceMap{
-      llvm::SmallVector<AxisToEdgeSourceMap>(
-          oldShardingProjection.getNumOperands(), AxisToEdgeSourceMap()),
-      llvm::SmallVector<AxisToEdgeSourceMap>(
-          oldShardingProjection.getNumResults(), AxisToEdgeSourceMap())};
+    OpShardingRuleAttr shardingRule, MeshAttr mesh,
+    const int64_t propagationStep) {
+  FactorsToEdgeMap axisToEdgeMap{
+      llvm::SmallVector<AxisToEdgeMap>(oldShardingProjection.getNumOperands(),
+                                       AxisToEdgeMap()),
+      llvm::SmallVector<AxisToEdgeMap>(oldShardingProjection.getNumResults(),
+                                       AxisToEdgeMap())};
 
   // Saves the `axisRefs` to the specified `valueSourceMap` of
-  // `axisToEdgeSourceMap`.
-  auto saveEdgeSources = [&](ArrayRef<AxisRefAttr> newAxisRefs,
-                             ArrayRef<AxisRefAttr> oldAxisRefs,
-                             EdgeSourceType type, int64_t sourceIndex,
-                             AxisToEdgeSourceMap& valueSourceMap) {
+  // `axisToEdgeMap`.
+  auto saveEdges = [&](ArrayRef<AxisRefAttr> newAxisRefs,
+                       ArrayRef<AxisRefAttr> oldAxisRefs, EdgeNode source,
+                       EdgeNode target, AxisToEdgeMap& valueSourceMap) {
     // To avoid iterating over all the new axes, only compare the very last old
     // axis (since there could have been a sub-axis update) and then the
     // trailing new axes.
     int64_t oldAxisIndex = oldAxisRefs.size() - 1;
     if (!oldAxisRefs.empty() &&
         oldAxisRefs[oldAxisIndex] != newAxisRefs[oldAxisIndex]) {
-      valueSourceMap.try_emplace(newAxisRefs[oldAxisIndex],
-                                 EdgeSource{type, sourceIndex});
+      valueSourceMap.try_emplace(
+          newAxisRefs[oldAxisIndex],
+          PropagationEdge{source, target, propagationStep});
     }
     for (AxisRefAttr axisRef : newAxisRefs.drop_front(oldAxisRefs.size())) {
-      valueSourceMap.try_emplace(axisRef, EdgeSource{type, sourceIndex});
+      valueSourceMap.try_emplace(
+          axisRef, PropagationEdge{source, target, propagationStep});
    }
  };

  MLIRContext* context = mesh.getContext();
  ArrayRef<int64_t> factorSizes = shardingRule.getFactorSizes();
-  auto visitValue =
-      [&](const TensorFactorShardings& oldValue,
-          const TensorFactorShardings& newValue, int64_t valueIndex,
-          TensorMappingAttr tensorMapping,
-          llvm::SmallVector<AxisToEdgeSourceMap>& valueSourceMap) {
-        DenseSet<AxisRefAttr> oldAxes;
-        for (const auto& [_, oldFactorSharding] :
-             oldValue.factorIndexToSharding) {
-          oldAxes.insert(oldFactorSharding.axisRefs.begin(),
-                         oldFactorSharding.axisRefs.end());
-        }
-        for (const auto& [factorIndex, oldFactorSharding] :
-             oldValue.factorIndexToSharding) {
-          const FactorSharding& newFactorSharding =
-              newValue.factorIndexToSharding.at(factorIndex);
-          if (oldFactorSharding.axisRefs == newFactorSharding.axisRefs) {
-            continue;
-          }
-          SmallVector<AxisRefAttr> newlyIntroducedAxes;
-          // If multiple sub axes can be merged due to a dimension sharding
-          // having multiple factors, each sharded on a sub axis, make sure we
-          // only save the merged one. This can happen during an
-          // `(A, B) -> (AB,)` reshape.
-          TensorShardingAttr tensorSharding = newValue.createTensorShardingAttr(
-              context, tensorMapping, factorSizes, "", mesh);
-          for (DimensionShardingAttr dimSharding :
-               tensorSharding.getDimShardings()) {
-            llvm::copy_if(
-                dimSharding.getAxes(), std::back_inserter(newlyIntroducedAxes),
-                [&](const AxisRefAttr& axisRef) {
-                  // Don't add any axes that were already in the
-                  // old sharding. We just want new axes.
-                  if (oldAxes.contains(axisRef)) {
-                    return false;
-                  }
-                  // We need to avoid any axes that already existed
-                  // in the old sharding, but aren't in the new
-                  // projection as the conflicted. E.g. for a
-                  // contracting dim matmul, if both the LHS/RHS are
-                  // sharded on the same axis on their respective
-                  // non-contracting dims, the dimension sharding
-                  // will contain the conflicting axes, but the
-                  // factor sharding will not. And we don't want this
-                  // axis as it isn't a newly introduced axis.
-                  for (AxisRefAttr newAxisRef : newFactorSharding.axisRefs) {
-                    if (newAxisRef.prefixOf(axisRef)) {
-                      return true;
-                    }
-                  }
-                  return false;
-                });
-          }
-          // This factor sharding has changed, let's find who changed it.
-          if (std::optional<int64_t> operandSource =
-                  findNewAxisRefMatch(newFactorSharding.axisRefs, factorIndex,
-                                      oldShardingProjection.getOperands())) {
-            saveEdgeSources(newlyIntroducedAxes, oldFactorSharding.axisRefs,
-                            EdgeSourceType::OPERAND, *operandSource,
-                            valueSourceMap[valueIndex]);
-          } else if (std::optional<int64_t> resultSource = findNewAxisRefMatch(
-                         newFactorSharding.axisRefs, factorIndex,
-                         oldShardingProjection.getResults())) {
-            saveEdgeSources(newlyIntroducedAxes, oldFactorSharding.axisRefs,
-                            EdgeSourceType::RESULT, *resultSource,
-                            valueSourceMap[valueIndex]);
-          }
-        }
-      };
+  auto visitValue = [&](const TensorFactorShardings& oldValue,
+                        const TensorFactorShardings& newValue,
+                        EdgeNodeType valueType, int64_t valueIndex,
+                        TensorMappingAttr tensorMapping,
+                        llvm::SmallVector<AxisToEdgeMap>& valueSourceMap) {
+    DenseSet<AxisRefAttr> oldAxes;
+    for (const auto& [_, oldFactorSharding] : oldValue.factorIndexToSharding) {
+      oldAxes.insert(oldFactorSharding.axisRefs.begin(),
+                     oldFactorSharding.axisRefs.end());
+    }
+    for (const auto& [oldFactorSharding, newFactorSharding] : llvm::zip_equal(
+             oldValue.factorIndexToSharding, newValue.factorIndexToSharding)) {
+      if (oldFactorSharding.second.axisRefs ==
+          newFactorSharding.second.axisRefs) {
+        continue;
+      }
+      SmallVector<AxisRefAttr> newlyIntroducedAxes;
+      // If multiple sub axes can be merged due to a dimension sharding having
+      // multiple factors, each sharded on a sub axis, make sure we only save
+      // the merged one. This can happen during an `(A, B) -> (AB,)` reshape.
+      TensorShardingAttr tensorSharding = newValue.createTensorShardingAttr(
+          context, tensorMapping, factorSizes, "", mesh);
+      for (DimensionShardingAttr dimSharding :
+           tensorSharding.getDimShardings()) {
+        llvm::copy_if(
+            dimSharding.getAxes(), std::back_inserter(newlyIntroducedAxes),
+            [&](const AxisRefAttr& axisRef) {
+              // Don't add any axes that were already in the
+              // old sharding. We just want new axes.
+              if (oldAxes.contains(axisRef)) {
+                return false;
+              }
+              // We need to avoid any axes that already existed in the old
+              // sharding, but aren't in the new projection because they
+              // conflicted. E.g. for a contracting dim matmul, if both the
+              // LHS/RHS are sharded on the same axis on their respective
+              // non-contracting dims, the dimension sharding will contain the
+              // conflicting axes, but the factor sharding will not. And we
+              // don't want this axis as it isn't a newly introduced axis.
+              for (AxisRefAttr newAxisRef : newFactorSharding.second.axisRefs) {
+                if (newAxisRef.prefixOf(axisRef)) {
+                  return true;
+                }
+              }
+              return false;
+            });
+      }
+      // This factor sharding has changed, let's find who changed it.
+      if (std::optional<int64_t> operandSource = findNewAxisRefMatch(
+              newFactorSharding.second.axisRefs, oldFactorSharding.first,
+              oldShardingProjection.getOperands())) {
+        saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
+                  EdgeNode{EdgeNodeType::OPERAND, *operandSource},
+                  EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
+      } else if (std::optional<int64_t> resultSource = findNewAxisRefMatch(
+                     newFactorSharding.second.axisRefs, oldFactorSharding.first,
+                     oldShardingProjection.getResults())) {
+        saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
+                  EdgeNode{EdgeNodeType::RESULT, *resultSource},
+                  EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
+      }
+    }
+  };
 
   for (auto [i, packedOperands] :
        llvm::enumerate(llvm::zip_equal(oldShardingProjection.getOperands(),
                                        newShardingProjection.getOperands()))) {
     auto [oldOperand, newOperand] = packedOperands;
-    visitValue(oldOperand, newOperand, i, shardingRule.getOperandMapping(i),
-               axisToEdgeSourceMap.operands);
+    visitValue(oldOperand, newOperand, EdgeNodeType::OPERAND, i,
+               shardingRule.getOperandMapping(i), axisToEdgeMap.operands);
   }
 
   for (auto [i, packedResults] :
        llvm::enumerate(llvm::zip_equal(oldShardingProjection.getResults(),
                                        newShardingProjection.getResults()))) {
     auto [oldResult, newResult] = packedResults;
-    visitValue(oldResult, newResult, i, shardingRule.getResultMapping(i),
-               axisToEdgeSourceMap.results);
+    visitValue(oldResult, newResult, EdgeNodeType::RESULT, i,
+               shardingRule.getResultMapping(i), axisToEdgeMap.results);
   }
 
-  return axisToEdgeSourceMap;
+  return axisToEdgeMap;
 }
 
 std::string manualComputationOriginName(OriginShardingType type, StringRef name,
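The sub-axis merging handled above is easiest to see on a reshape; a condensed sketch of the saved edge, mirroring the `sub_axes_merging_reshape` test at the end of this patch (only the relevant attributes shown):

```mlir
// The operand is sharded on sub-axes "c":(1)4 and "c":(4)2; the reshape merges
// them, so only the merged axis "c" is saved as an operand -> result edge.
%0 = stablehlo.reshape %arg0 {
  sdy.propagation_edges = {c = [{propagation_step = 0 : i64,
                                 source = "operand: 0",
                                 target = "result: 0"}]},
  sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"c", ?}]>]>
} : (tensor<4x4xf32>) -> tensor<16xf32>
```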
@@ -243,6 +243,12 @@ StringAttr shardingOriginToString(OriginSharding source, MLIRContext* context) {
       llvm::formatv("{0}: {1}", typeString, source.index));
 }
 
+// Avoid printing the string with escaping quotes, aka "\22".
+void eraseDoubleQuotesInAxisRefString(std::string& axisRefString) {
+  axisRefString.erase(remove(axisRefString.begin(), axisRefString.end(), '"'),
+                      axisRefString.end());
+}
+
 // Create a list of entries from the `axisToOriginSharding` map to save as a
 // `DictionaryAttr`.
 SmallVector<NamedAttribute> createOriginShardingEntries(
@@ -251,9 +257,7 @@ SmallVector<NamedAttribute> createOriginShardingEntries(
   entries.reserve(axisToOriginSharding.size());
   for (const auto& [axisRef, shardingOrigin] : axisToOriginSharding) {
     std::string axisRefString = axisRef.toString();
-    // Avoid printing the string with escaping quotes, aka "\22".
-    axisRefString.erase(remove(axisRefString.begin(), axisRefString.end(), '"'),
-                        axisRefString.end());
+    eraseDoubleQuotesInAxisRefString(axisRefString);
     entries.emplace_back(
         NamedAttribute(StringAttr::get(context, axisRefString),
                        shardingOriginToString(shardingOrigin, context)));
@@ -273,33 +277,40 @@ SmallVector<Attribute> getOriginShardingDicts(Operation* op, Builder& builder) {
   return SmallVector<Attribute>(resultDicts.getValue());
 }
 
-// Saves the originating sharding debug information on the `moduleOp`.
-void saveShardingOriginsOnModule(ModuleOp moduleOp,
-                                 ShardingDebugMappings* mappings) {
-  MLIRContext* context = moduleOp.getContext();
+// Gets the `OpOperand` of the `value` in the `funcOp` terminator if the
+// `Value` is used in the terminator, else returns `nullptr`.
+OpOperand* getTerminatorOperand(Value value, func::FuncOp funcOp) {
+  ArrayRef<OpOperand> terminatorOperands =
+      getBodyTerminator(funcOp)->getOpOperands();
+  if (auto it = llvm::find_if(value.getUses(),
+                              [&](const OpOperand& use) {
+                                return llvm::is_contained(terminatorOperands,
+                                                          use);
+                              });
+      it != value.getUses().end()) {
+    return it.getOperand();
+  }
+  return nullptr;
+}
+
+// Saves the originating sharding debug information on each `Value` in
+// `valueToOriginShardingMap`.
+void saveShardingOriginsOnModule(
+    MLIRContext* context,
+    const ValueToOriginShardingMap& valueToOriginShardingMap) {
   Builder builder(context);
-  for (auto [value, axisToOriginSharding] :
-       mappings->valueToOriginShardingMap) {
+  for (auto& [value, axisToOriginSharding] : valueToOriginShardingMap) {
     Operation* owningOp = getOwningOp(value);
     func::FuncOp funcOp = getEnclosingOfType<func::FuncOp>(owningOp);
     // TODO(bartchr): Swap the map to store `ValueOrFuncResult` to avoid having
     // to do this terminator finding logic just to set the func result attr.
-    OpOperand* terminatorOperand = nullptr;
-    ArrayRef<OpOperand> terminatorOperands =
-        getBodyTerminator(funcOp)->getOpOperands();
-    if (auto it = llvm::find_if(value.getUses(),
-                                [&](const OpOperand& use) {
-                                  return llvm::is_contained(terminatorOperands,
-                                                            use);
-                                });
-        it != value.getUses().end()) {
-      terminatorOperand = it.getOperand();
-    }
+    OpOperand* terminatorOperand = getTerminatorOperand(value, funcOp);
     SmallVector<NamedAttribute> entries =
         createOriginShardingEntries(axisToOriginSharding, context);
+
     if (terminatorOperand) {
       int64_t operandNumber = terminatorOperand->getOperandNumber();
       funcOp.setResultAttr(operandNumber, kShardingOriginsAttr,
@@ -328,6 +339,131 @@ void saveShardingOriginsOnModule(ModuleOp moduleOp,
   }
 }
 
+// Converts the `node` to a `StringAttr` like "operand: 0" or "result: 1".
+StringAttr edgeNodeToString(EdgeNode node, Builder& builder) {
+  std::string typeString;
+  switch (node.type) {
+    case EdgeNodeType::OPERAND: {
+      typeString = "operand";
+      break;
+    }
+    case EdgeNodeType::RESULT: {
+      typeString = "result";
+      break;
+    }
+  }
+  return builder.getStringAttr(
+      llvm::formatv("{0}: {1}", typeString, node.index));
+}
+
+// In the case where we have a Value used multiple times as an operand, we
+// should only add the edge once. For example:
+// ```mlir
+// %0 = stablehlo.add %arg0, %arg0 <[<@mesh, [{"a", ?}]>]> : tensor<8xf32>
+// return %0 : tensor<8xf32>
+// ```
+// The sharding projection said that both operand 0 and 1 are updated. However,
+// they are the same value, so we only need to add the edge once. This is only
+// the case for the target of the edge, because if the source appears multiple
+// times, then it's because it affects multiple other operands/results in the
+// op.
+bool insertSeenValue(Operation* op, const PropagationEdge& edge,
+                     llvm::SmallDenseSet<Value>& seenValues) {
+  EdgeNode target = edge.target;
+  switch (target.type) {
+    case EdgeNodeType::OPERAND: {
+      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+        return seenValues
+            .insert(getBodyTerminator(funcOp)->getOperand(target.index))
+            .second;
+      }
+      return seenValues.insert(op->getOperand(target.index)).second;
+    }
+    case EdgeNodeType::RESULT: {
+      return true;
+    }
+  }
+  llvm_unreachable("unknown EdgeNodeType");
+}
+
+// Create a list of entries from the `axisToEdges` map to save as a
+// `DictionaryAttr`.
+DictionaryAttr createEdgeEntries(Operation* op,
+                                 const AxisToEdgesMap& axisToEdges,
+                                 MLIRContext* context) {
+  Builder builder(context);
+  SmallVector<NamedAttribute> entries;
+  for (const auto& [axisRef, edges] : axisToEdges) {
+    std::string axisRefString = axisRef.toString();
+    eraseDoubleQuotesInAxisRefString(axisRefString);
+    SmallVector<Attribute> axisEntries;
+    llvm::SmallDenseSet<Value> seenTargetValues;
+    for (const PropagationEdge& edge : edges) {
+      assert(edge.source.type == EdgeNodeType::OPERAND ||
+             edge.source.type == EdgeNodeType::RESULT);
+      if (!insertSeenValue(op, edge, seenTargetValues)) {
+        continue;
+      }
+      StringAttr sourceEntry = edgeNodeToString(edge.source, builder);
+      StringAttr targetEntry = edgeNodeToString(edge.target, builder);
+      DictionaryAttr edgeEntry = builder.getDictionaryAttr({
+          builder.getNamedAttr("source", sourceEntry),
+          builder.getNamedAttr("target", targetEntry),
+          builder.getNamedAttr("propagation_step",
+                               builder.getI64IntegerAttr(edge.propagationStep)),
+      });
+      axisEntries.push_back(edgeEntry);
+    }
+    entries.emplace_back(builder.getStringAttr(axisRefString),
+                         builder.getArrayAttr(axisEntries));
+  }
+  return builder.getDictionaryAttr(entries);
+}
+
+// Saves the propagation edge debug information on each `op` in
+// `operationToEdgesMap`.
+//
+// This works by having each `Operation*` save a source/target edge, which
+// always involves at least one operand. This is because results can be used
+// several times, but an operand's value only ever has one defining op. So
+// these edges always look "backwards" - never forwards towards a use of a
+// result.
+//
+// As such, the `FuncOp` args don't contain edge source information: only the
+// ops that use them do.
+void saveEdgesOnModule(MLIRContext* context,
+                       const OperationToEdgesMap& operationToEdgesMap) {
+  Builder builder(context);
+  for (auto [op, axisToEdges] : operationToEdgesMap) {
+    if (isa<func::FuncOp>(op)) {
+      continue;
+    }
+    op->setAttr(kPropagationEdgesAttr,
+                createEdgeEntries(op, axisToEdges, context));
+  }
+}
+
+// Saves the propagation edge debug information on the result attrs of
+// `funcOp`.
+//
+// Since only the uses of a sharding save the edge, for `FuncOp` results which
+// have been updated, we save the edges in each result's `resultAttr`. If the
+// function has multiple results, each edge on the func result attributes will
+// have index 0. This is because propagation is run separately on each
+// returned result. Giving each edge the right index would only make sense if
+// the `sdy.propagation_edges` were saved as a single top-level attribute on
+// the func, but since one dictionary is saved per result, an index of 0 makes
+// the most sense.
+void saveEdgesOnFuncResults(func::FuncOp funcOp,
+                            const FuncResultToEdgesMap& funcResultToEdgesMap) {
+  auto resultToEdgesIt = funcResultToEdgesMap.find(funcOp);
+  if (resultToEdgesIt == funcResultToEdgesMap.end()) {
+    return;
+  }
+  for (auto [resultIndex, axisToEdgesMap] :
+       llvm::enumerate(resultToEdgesIt->second)) {
+    funcOp.setResultAttr(
+        resultIndex, kPropagationEdgesAttr,
+        createEdgeEntries(funcOp, axisToEdgesMap, funcOp->getContext()));
+  }
+}
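To illustrate the index-0 caveat with a concrete fragment (result attrs copied from the `manual_computation_multiple_results` test below; the argument list is elided): both result dictionaries refer to `result: 0`, because each one is attached to its own result rather than to a single top-level attribute:

```mlir
func.func @manual_computation_multiple_results(...) -> (
    tensor<16x32xf32> {sdy.propagation_edges = {
        a = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"},
             {propagation_step = 6 : i64, source = "operand: 0", target = "result: 0"}],
        b = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}]}},
    tensor<32x32xf32> {sdy.propagation_edges = {}}) {
```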
+
 // Saves the sharding origin information on the `value` to the `handler`.
 void saveShardingOrigins(ValueToOriginShardingMap& valueToOriginShardingMap,
                          TensorShardingAttr sharding, OriginShardingType type,
@@ -428,13 +564,14 @@ void overrideOriginsToSelf(ModuleOp moduleOp) {
   });
 }
 
-// Sets up the `handler` with the initial sharding origin information on
-// the `moduleOp`.
-// The `SourceShardingHandler` will keep `valueToEdgeSourceMap` and
-// `valueToOriginShardingMap` up to date with the source sharding information
-// on the module during the propagation rewrite patterns.
-void prepareShardingOriginsHandler(ModuleOp moduleOp,
-                                   ShardingDebugMappings* mappings) {
+// Sets up `valueToOriginShardingMap` with the initial sharding origin
+// information on the `moduleOp`.
+//
+// The `SourceShardingHandler` will keep `valueToOriginShardingMap` up to date
+// with the origin sharding information on the module during the propagation
+// rewrite patterns.
+void prepareShardingOriginsHandler(
+    ModuleOp moduleOp, ValueToOriginShardingMap& valueToOriginShardingMap) {
   MLIRContext* context = moduleOp.getContext();
   // Build the initial sharding origin map.
   // NOTE(bartchr): This assumes that we do not propagate across different
@@ -442,22 +579,21 @@ void prepareShardingOriginsHandler(ModuleOp moduleOp,
   // this if we do propagate across `FuncOp`s.
   moduleOp.walk([&](func::FuncOp funcOp) {
     for (BlockArgument arg : funcOp.getArguments()) {
-      saveShardingOrigins(mappings->valueToOriginShardingMap, getSharding(arg),
+      saveShardingOrigins(valueToOriginShardingMap, getSharding(arg),
                           OriginShardingType::INPUT, arg, arg.getArgNumber());
     }
     for (OpOperand& returnOperand : getBodyTerminatorOpOperands(funcOp)) {
       int64_t valueIndex = returnOperand.getOperandNumber();
-      saveShardingOrigins(mappings->valueToOriginShardingMap,
-                          getFuncResultSharding(funcOp, valueIndex),
-                          OriginShardingType::OUTPUT, returnOperand.get(),
-                          valueIndex);
+      saveShardingOrigins(
+          valueToOriginShardingMap, getFuncResultSharding(funcOp, valueIndex),
+          OriginShardingType::OUTPUT, returnOperand.get(), valueIndex);
     }
   });
   // NOTE: all `ManualComputationOp`s and `ShardingConstraintOp`s will have a
   // unique source name, no matter if they aren't in the same `FuncOp`.
   int64_t sourceId = 0;
   moduleOp.walk([&](ShardingConstraintOp shardingConstraintOp) {
-    saveShardingOrigins(mappings->valueToOriginShardingMap,
+    saveShardingOrigins(valueToOriginShardingMap,
                         shardingConstraintOp.getSharding(),
                         OriginShardingType::CONSTRAINT,
                         shardingConstraintOp.getResult(), 0, sourceId);
@@ -476,7 +612,7 @@ void prepareShardingOriginsHandler(ModuleOp moduleOp,
       auto edge =
           DataFlowEdgeOp::lookup(manualComputationOp.getBody().getArgument(i));
       assert(edge);
-      saveShardingOrigins(mappings->valueToOriginShardingMap, sharding,
+      saveShardingOrigins(valueToOriginShardingMap, sharding,
                           OriginShardingType::MC_INPUT, edge.getResult(), i,
                           sourceId);
     }
@@ -485,7 +621,7 @@ void prepareShardingOriginsHandler(ModuleOp moduleOp,
       // Assuming that the edges live as the only use of the op results.
       auto edge = DataFlowEdgeOp::lookup(manualComputationOp.getResult(i));
       assert(edge);
-      saveShardingOrigins(mappings->valueToOriginShardingMap, sharding,
+      saveShardingOrigins(valueToOriginShardingMap, sharding,
                           OriginShardingType::MC_OUTPUT, edge.getResult(), i,
                           sourceId);
     }
@@ -496,12 +632,26 @@ void prepareShardingOriginsHandler(ModuleOp moduleOp,
   });
 }
 
+// Sets up `funcResultToEdgesMap` for saving the edge source information
+// on the `moduleOp`.
+//
+// The `SourceShardingHandler` will keep `funcResultToEdgesMap` up to date
+// with the source sharding information on the module during the propagation
+// rewrite patterns.
+void prepareFuncResultToEdgesHandler(
+    ModuleOp moduleOp, FuncResultToEdgesMap& funcResultToEdgesMap) {
+  moduleOp.walk([&](func::FuncOp funcOp) {
+    funcResultToEdgesMap[funcOp] =
+        SmallVector<AxisToEdgesMap>(funcOp.getNumResults());
+  });
+}
+
 OriginSharding lookUpValueOriginSharding(
     Value value, AxisRefAttr axisRef,
     const ValueToOriginShardingMap& valueToOriginShardingMap) {
   // NOTE: need to call `getShardableValue` in case the operand/result is
   // part of a `ShardableDataFlowOpInterface` and the `Value` the sharding
-  // lives on is a `DataFlowEdgeOp` instead of the `edgeSource` itself.
+  // lives on is a `DataFlowEdgeOp` instead of the `edge` itself.
   const AxisToOriginShardingMap& axisToOriginSharding =
       valueToOriginShardingMap.at(getShardableValue(value));
   if (auto it = axisToOriginSharding.find(axisRef);
@@ -524,9 +674,9 @@ OriginSharding lookUpValueOriginSharding(
 }  // namespace
 
 ShardingDebugMappings::ShardingDebugMappings(bool debugShardingOrigins,
-                                             bool debugEdgeSourceSharding)
+                                             bool debugPropagationEdgeSharding)
     : debugShardingOrigins(debugShardingOrigins),
-      debugEdgeSourceSharding(debugEdgeSourceSharding) {}
+      debugPropagationEdgeSharding(debugPropagationEdgeSharding) {}
 
 SourceShardingHandler::SourceShardingHandler(ShardingDebugMappings* mappings)
     : mappings(mappings) {}
@@ -541,68 +691,148 @@ void SourceShardingHandler::operator()(function_ref<void()> transform,
                                        const tracing::Action& action) {
   }
 
   auto sourceShardingAction = cast<SourceShardingAction>(action);
-  FactorsToEdgeSourceMap factorsToEdgeSources = createSourceMap(
-      sourceShardingAction.oldShardingProjection,
-      sourceShardingAction.newShardingProjection,
-      sourceShardingAction.shardingRule, sourceShardingAction.mesh);
+  if (!sourceShardingAction.anyUpdated) {
+    return;
+  }
+  FactorsToEdgeMap factorsToEdges =
+      createSourceMap(sourceShardingAction.oldShardingProjection,
+                      sourceShardingAction.newShardingProjection,
+                      sourceShardingAction.shardingRule,
+                      sourceShardingAction.mesh, propagationStep);
+  propagationStep++;
   // If the new and old shardings are different, something was propagated to
   // it. Find and save it.
-  auto lookUpOriginSharding = [&](EdgeSource edgeSource,
+  auto lookUpOriginSharding = [&](EdgeNode edgeNode,
                                   AxisRefAttr axisRef) -> OriginSharding {
-    switch (edgeSource.type) {
+    switch (edgeNode.type) {
       case OPERAND:
         return lookUpValueOriginSharding(
-            sourceShardingAction.operands[edgeSource.index], axisRef,
+            sourceShardingAction.operands[edgeNode.index], axisRef,
             mappings->valueToOriginShardingMap);
       case RESULT:
         return lookUpValueOriginSharding(
-            sourceShardingAction.results[edgeSource.index], axisRef,
+            sourceShardingAction.results[edgeNode.index], axisRef,
             mappings->valueToOriginShardingMap);
     }
-    llvm_unreachable("unknown EdgeSource");
+    llvm_unreachable("unknown EdgeNode");
   };
 
-  auto updateMappings = [&](ShardingDebugMappings* mappings,
-                            AxisToEdgeSourceMap axisToEdgeSource, Value value) {
-    for (auto [axisRef, edgeSource] : axisToEdgeSource) {
-      if (mappings->debugEdgeSourceSharding) {
-        mappings->valueToEdgeSourceMap[value].try_emplace(axisRef, edgeSource);
-      }
-      if (mappings->debugShardingOrigins) {
-        mappings->valueToOriginShardingMap[value].try_emplace(
-            axisRef, lookUpOriginSharding(edgeSource, axisRef));
+  auto updateMappings = [&](int64_t i, AxisRefAttr axisRef,
+                            PropagationEdge edge, Value value) {
+    if (mappings->debugPropagationEdgeSharding) {
+      if (auto funcOp = dyn_cast<func::FuncOp>(sourceShardingAction.op)) {
+        mappings->funcResultToEdgesMap[funcOp][i][axisRef].push_back(edge);
+      } else {
+        mappings->operationToEdgesMap[sourceShardingAction.op][axisRef]
+            .push_back(edge);
       }
     }
+    if (mappings->debugShardingOrigins) {
+      mappings->valueToOriginShardingMap[value].try_emplace(
+          axisRef, lookUpOriginSharding(edge.source, axisRef));
+    }
   };
-  for (auto [operand, axisToEdgeSource] : llvm::zip_equal(
-           sourceShardingAction.operands, factorsToEdgeSources.operands)) {
-    updateMappings(mappings, axisToEdgeSource, operand);
+
+  for (auto [i, operand] : llvm::enumerate(sourceShardingAction.operands)) {
+    for (auto [axisRef, edge] : factorsToEdges.operands[i]) {
+      updateMappings(i, axisRef, edge, operand);
+    }
   }
-  for (auto [result, axisToEdgeSource] : llvm::zip_equal(
-           sourceShardingAction.results, factorsToEdgeSources.results)) {
-    updateMappings(mappings, axisToEdgeSource, result);
+
+  for (auto [i, result] : llvm::enumerate(sourceShardingAction.results)) {
+    for (auto [axisRef, edge] : factorsToEdges.results[i]) {
+      updateMappings(i, axisRef, edge, result);
+    }
   }
 }
 
 void SourceShardingHandler::prepareHandler(ModuleOp moduleOp) {
   if (mappings->debugShardingOrigins) {
-    prepareShardingOriginsHandler(moduleOp, mappings);
+    prepareShardingOriginsHandler(moduleOp, mappings->valueToOriginShardingMap);
   }
-  if (mappings->debugEdgeSourceSharding) {
-    llvm_unreachable("edge sharding not implemented yet");
+  if (mappings->debugPropagationEdgeSharding) {
+    prepareFuncResultToEdgesHandler(moduleOp, mappings->funcResultToEdgesMap);
   }
-  if (mappings->debugShardingOrigins || mappings->debugEdgeSourceSharding) {
+  if (mappings->debugShardingOrigins ||
+      mappings->debugPropagationEdgeSharding) {
     moduleOp->getContext()->registerActionHandler(*this);
   }
 }
 
 void SourceShardingHandler::saveOnModule(ModuleOp moduleOp) {
+  MLIRContext* context = moduleOp.getContext();
   if (mappings->debugShardingOrigins) {
-    saveShardingOriginsOnModule(moduleOp, mappings);
+    saveShardingOriginsOnModule(context, mappings->valueToOriginShardingMap);
     overrideOriginsToSelf(moduleOp);
   }
-  if (mappings->debugEdgeSourceSharding) {
-    llvm_unreachable("edge sharding not implemented yet");
+  if (mappings->debugPropagationEdgeSharding) {
+    saveEdgesOnModule(context, mappings->operationToEdgesMap);
+    moduleOp.walk([&](func::FuncOp funcOp) {
+      saveEdgesOnFuncResults(funcOp, mappings->funcResultToEdgesMap);
+    });
   }
 }
 
+namespace {
+
+// Looks for the debug info dictionary on the `dataFlowEdgeOp` called
+// `debugAttrName` and pushes it back to the `debugInfoDict`. If the dictionary
+// doesn't exist, pushes an empty dictionary.
+void pushBackToDebugInfoDict(DataFlowEdgeOp dataFlowEdgeOp,
+                             StringRef debugAttrName,
+                             SmallVector<Attribute>& debugInfoDict,
+                             IRRewriter& rewriter) {
+  assert(dataFlowEdgeOp);
+  if (auto edgeDebugInfo =
+          dataFlowEdgeOp->getAttrOfType<DictionaryAttr>(debugAttrName)) {
+    debugInfoDict.push_back(edgeDebugInfo);
+  } else {
+    debugInfoDict.push_back(rewriter.getDictionaryAttr({}));
+  }
+}
+
+}  // namespace
+
+void saveDebugInfoDictsFromDataFlowEdges(ValueRange edgeOwners, Operation* op,
+                                         bool sinkDebugShardingOrigins,
+                                         bool sinkDebugPropagationEdgeSharding,
+                                         EdgeNodeType edgeNodeType,
+                                         IRRewriter& rewriter) {
+  if (!sinkDebugShardingOrigins && !sinkDebugPropagationEdgeSharding) {
+    return;
+  }
+  SmallVector<Attribute> originShardingDicts;
+  if (sinkDebugShardingOrigins) {
+    originShardingDicts.reserve(edgeOwners.size());
+  }
+  SmallVector<Attribute> propagationEdgeDicts;
+  if (sinkDebugPropagationEdgeSharding) {
+    propagationEdgeDicts.reserve(edgeOwners.size());
+  }
+  for (Value edgeOwner : edgeOwners) {
+    if (auto dataFlowEdgeOp = DataFlowEdgeOp::lookup(edgeOwner)) {
+      if (sinkDebugShardingOrigins) {
+        pushBackToDebugInfoDict(dataFlowEdgeOp, kShardingOriginsAttr,
+                                originShardingDicts, rewriter);
+      }
+      if (sinkDebugPropagationEdgeSharding) {
+        pushBackToDebugInfoDict(dataFlowEdgeOp, kPropagationEdgesAttr,
+                                propagationEdgeDicts, rewriter);
+      }
+    }
+  }
+
+  if (sinkDebugShardingOrigins) {
+    op->setAttr(edgeNodeType == EdgeNodeType::OPERAND
+                    ? kBlockArgShardingOriginsAttr
+                    : kResultShardingOriginsAttr,
+                rewriter.getArrayAttr(originShardingDicts));
+  }
+  if (sinkDebugPropagationEdgeSharding) {
+    op->setAttr(edgeNodeType == EdgeNodeType::OPERAND
+                    ? kBlockArgPropagationEdgesAttr
+                    : kResultPropagationEdgesAttr,
+                rewriter.getArrayAttr(propagationEdgeDicts));
+  }
+}
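The `insertSeenValue` dedup above is visible in the `duplicate_operands` test below: `%arg0` feeds both operands of the add, so only one `result: 0 -> operand: 0` edge is kept for axis "a" (op attributes trimmed to the relevant ones):

```mlir
%0 = stablehlo.add %arg0, %arg0 {
  sdy.propagation_edges = {a = [{propagation_step = 1 : i64,
                                 source = "result: 0",
                                 target = "operand: 0"}]},
  sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}]>]>
} : tensor<8xf32>
```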
diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h
index b21207d5..991194c3 100644
--- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h
@@ -20,9 +20,12 @@ limitations under the License.
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Action.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Unit.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
@@ -53,24 +56,39 @@ struct OriginSharding {
 };
 
 // Specifies whether a sharding came from an operand or a result.
-enum EdgeSourceType { OPERAND, RESULT };
+enum EdgeNodeType { OPERAND, RESULT };
 
 // The operand/result a sharding came from through an `Operation` to modify the
 // sharding of some `Value` in the `Operation`.
-struct EdgeSource {
-  EdgeSourceType type;
+struct EdgeNode {
+  EdgeNodeType type;
   int64_t index;
 };
 
-using AxisToEdgeSourceMap = llvm::DenseMap<AxisRefAttr, EdgeSource>;
+// The source and target of a source sharding edge.
+struct PropagationEdge {
+  EdgeNode source;
+  EdgeNode target;
+  int64_t propagationStep;
+};
+
+// Types used for `debugPropagationEdgeSharding`.
+using AxisToEdgeMap = llvm::DenseMap<AxisRefAttr, PropagationEdge>;
+using AxisToEdgesMap =
+    llvm::DenseMap<AxisRefAttr, SmallVector<PropagationEdge>>;
+using OperationToEdgesMap = llvm::DenseMap<Operation*, AxisToEdgesMap>;
+// Mapping from `FuncOp` to the edges for each result.
+using FuncResultToEdgesMap =
+    llvm::DenseMap<func::FuncOp, SmallVector<AxisToEdgesMap>>;
+
+// Types used for `debugShardingOrigins`.
 using AxisToOriginShardingMap = llvm::DenseMap<AxisRefAttr, OriginSharding>;
-using ValueToEdgeSourceMap = llvm::DenseMap<Value, AxisToEdgeSourceMap>;
 using ValueToOriginShardingMap = llvm::DenseMap<Value, AxisToOriginShardingMap>;
 
 // The mappings used for debugging sharding origins and edge sources.
 struct ShardingDebugMappings {
   ShardingDebugMappings(bool debugShardingOrigins,
-                        bool debugEdgeSourceSharding);
+                        bool debugPropagationEdgeSharding);
 
   // We do not allow copying of the mappings, as we don't want the mappings
   // to be copied over to the new instance. There should only ever be one
@@ -78,8 +96,13 @@ struct ShardingDebugMappings {
   ShardingDebugMappings(const ShardingDebugMappings&) = delete;
   ShardingDebugMappings& operator=(const ShardingDebugMappings&) = delete;
 
-  bool debugShardingOrigins, debugEdgeSourceSharding;
-  ValueToEdgeSourceMap valueToEdgeSourceMap;
+  bool debugShardingOrigins, debugPropagationEdgeSharding;
+  OperationToEdgesMap operationToEdgesMap;
+  // NOTE: we need a separate map for `FuncOp` results as propagation is run
+  // per terminator operand/result pair, so we need to figure out which index
+  // the `FuncOp` result is. So this saves the edges for each `FuncOp` result
+  // separately.
+  FuncResultToEdgesMap funcResultToEdgesMap;
   ValueToOriginShardingMap valueToOriginShardingMap;
 };
 
@@ -91,17 +114,20 @@ class SourceShardingAction : public tracing::ActionImpl<SourceShardingAction> {
  public:
   using Base = tracing::ActionImpl<SourceShardingAction>;
 
-  SourceShardingAction(ArrayRef<IRUnit> irUnits, ValueRange operands,
-                       ValueRange results, MeshAttr mesh,
+  SourceShardingAction(ArrayRef<IRUnit> irUnits, Operation* op,
+                       ValueRange operands, ValueRange results, MeshAttr mesh,
                        OpShardingRuleAttr shardingRule,
-                       const ShardingProjection& shardingProjection)
+                       const ShardingProjection& shardingProjection,
+                       const bool& anyUpdated)
       : Base(irUnits),
+        op(op),
        operands(operands),
        results(results),
        mesh(mesh),
        shardingRule(shardingRule),
        oldShardingProjection(shardingProjection),
-        newShardingProjection(shardingProjection) {}
+        newShardingProjection(shardingProjection),
+        anyUpdated(anyUpdated) {}
 
   static constexpr StringLiteral tag = "SourceShardingAction";
   static constexpr StringLiteral desc =
@@ -109,6 +135,7 @@ class SourceShardingAction : public tracing::ActionImpl<SourceShardingAction> {
       "a user defined sharding either on `FuncOp` inputs/outputs, an "
       "`sdy.sharding_constraint`, or `sdy.ManualComputationOp` input/output.";
 
+  Operation* op;
   ValueRange operands, results;
   MeshAttr mesh;
   OpShardingRuleAttr shardingRule;
@@ -117,12 +144,14 @@ class SourceShardingAction : public tracing::ActionImpl<SourceShardingAction> {
   // new sharding projections differ.
   const ShardingProjection oldShardingProjection;
   const ShardingProjection& newShardingProjection;
+  // Whether any of the operands/results were updated.
+  const bool& anyUpdated;
 };
 
 // Handles `SourceShardingAction`s, figuring out what operand/result shardings
 // have been propagated through due to new axes. Saves what was the source of
 // the axis to appear on the sharding of a given `Value` to
-// `valueToEdgeSourceMap` and `valueToOriginShardingMap`.
+// `operationToEdgesMap`/`funcResultToEdgesMap` and `valueToOriginShardingMap`.
 struct SourceShardingHandler {
   SourceShardingHandler(ShardingDebugMappings* mappings);
 
@@ -140,8 +169,27 @@ struct SourceShardingHandler {
 
  private:
   ShardingDebugMappings* mappings;
+  int64_t propagationStep = 0;
 };
 
+// Saves an array of all the origin sharding and propagation edge dictionaries
+// for the given `edgeOwners` on `op`. If none exist, nothing is saved.
+//
+// Saving the info depends on whether the corresponding
+// `sinkDebugShardingOrigins` and `sinkDebugPropagationEdgeSharding` flags are
+// true.
+//
+// For debugging the origin shardings and propagation edges, we want to
+// preserve the debugging dictionaries from the `DataFlowEdgeOp`s on the owning
+// op so that they are preserved after the propagation pipeline.
+//
+// See the `debug-sharding-origins` and `debug-propagation-edge-sharding`
+// configs on propagation for more details.
+void saveDebugInfoDictsFromDataFlowEdges(ValueRange edgeOwners, Operation* op,
+                                         bool sinkDebugShardingOrigins,
+                                         bool sinkDebugPropagationEdgeSharding,
+                                         EdgeNodeType edgeNodeType,
+                                         IRRewriter& rewriter);
+
 }  // namespace sdy
 }  // namespace mlir
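Tying the header types to the serialized IR: a `PropagationEdge{EdgeNode{OPERAND, 0}, EdgeNode{RESULT, 0}, /*propagationStep=*/1}` recorded for axis "a" prints, via `createEdgeEntries`, as the dictionary entry below - the format all the CHECK lines in the following test match against:

```mlir
{a = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}]}
```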
diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/test/edge_shardings.mlir b/shardy/dialect/sdy/transforms/propagation/debugging/test/edge_shardings.mlir
new file mode 100644
index 00000000..9a27a63f
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/test/edge_shardings.mlir
@@ -0,0 +1,233 @@
+// RUN: sdy_opt %s -split-input-file -sdy-add-data-flow-edges -sdy-aggressive-propagate=debug-propagation-edge-sharding=true -sdy-sink-data-flow-edges="sink-debug-propagation-edge-sharding=true" 2>&1 | FileCheck %s
+
+sdy.mesh @mesh = <["a"=2, "b"=2, "c"=8]>
+
+// CHECK-LABEL: input_output_source_sharding
+// CHECK-SAME: %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {"c", ?}]>},
+// CHECK-SAME: %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {"c", ?}]>}
+// CHECK-SAME: ) -> (tensor<8x8x8xf32> {sdy.propagation_edges = {a = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}],
+// CHECK-SAME:                                                   b = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}],
+// CHECK-SAME:                                                   c = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}]},
+// CHECK-SAME:                         sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {"c", ?}]>}) {
+func.func @input_output_source_sharding(
+    %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}, {?}]>},
+    %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {?}, {"c", ?}]>}
+) -> (tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}, {?}]>}) {
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 {
+  // CHECK-SAME:   sdy.propagation_edges = {a = [{propagation_step = 1 : i64, source = "operand: 0", target = "operand: 1"},
+  // CHECK-SAME:                                 {propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}],
+  // CHECK-SAME:                            b = [{propagation_step = 1 : i64, source = "result: 0", target = "operand: 0"},
+  // CHECK-SAME:                                 {propagation_step = 1 : i64, source = "result: 0", target = "operand: 1"}],
+  // CHECK-SAME:                            c = [{propagation_step = 1 : i64, source = "operand: 1", target = "operand: 0"},
+  // CHECK-SAME:                                 {propagation_step = 1 : i64, source = "operand: 1", target = "result: 0"}]},
+  // CHECK-SAME:   sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}, {"c", ?}]>]>
+  // CHECK-SAME: } : tensor<8x8x8xf32>
+  // CHECK-NEXT: return %[[ADD]] : tensor<8x8x8xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8x8xf32>
+  return %0 : tensor<8x8x8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2]>
+
+// NOTE: Instead of saving `{source = "result: 0", target = "operand: 0"}` and
+// `{source = "result: 0", target = "operand: 1"}` on the add due to the same
+// value being used twice as an operand, we only save the edge once.
+//
+// CHECK-LABEL: duplicate_operands
+// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}
+// CHECK-SAME: ) -> (tensor<8xf32> {sdy.propagation_edges = {a = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}]},
+// CHECK-SAME:                      sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) {
+func.func @duplicate_operands(
+    %arg0: tensor<8xf32>
+) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) {
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 {
+  // CHECK-SAME:   sdy.propagation_edges = {a = [{propagation_step = 1 : i64, source = "result: 0", target = "operand: 0"}]},
+  // CHECK-SAME:   sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}]>]>
+  // CHECK-SAME: } : tensor<8xf32>
+  // CHECK-NEXT: return %[[ADD]] : tensor<8xf32>
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// NOTE: since the definition of an edge always contains an operand as a source
+// or target, even though the result sharding added `a` before the arg
+// sharding, the edge where axis `a` was added to the add is stored on the func
+// result sharding.
+//
+// CHECK-LABEL: multiple_axes
+// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "b", ?}, {?}]>}
+// CHECK-SAME: ) -> (tensor<8x8xf32> {sdy.propagation_edges = {a = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}],
+// CHECK-SAME:                                                 b = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}]},
+// CHECK-SAME:                        sdy.sharding = #sdy.sharding<@mesh, [{"a", "b", ?}, {?}]>}) {
+func.func @multiple_axes(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "b", ?}, {?}]>}
+) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}) {
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 {
+  // CHECK-SAME:   sdy.propagation_edges = {b = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}]},
+  // CHECK-SAME:   sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", ?}, {?}]>]>
+  // CHECK-SAME: } : tensor<8x8xf32>
+  // CHECK-NEXT: return %[[ADD]] : tensor<8x8xf32>
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
+  return %0 : tensor<8x8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["c"=8]>
+
+// NOTE(b/385908435): note how we save the smaller and larger sub axes on the
+// func result sharding. Maybe this behavior is good, or should change? To be
+// seen.
+// +// CHECK-LABEL: sub_axis_update +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(1)4, ?}]>} +// CHECK-SAME: ) -> (tensor<8x8xf32> {sdy.propagation_edges = {"c:(1)2" = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}], +// CHECK-SAME: "c:(1)4" = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}]}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(1)4, ?}]>}) { +func.func @sub_axis_update( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(1)4, ?}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(1)2, ?}]>}) { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 { + // CHECK-SAME: sdy.propagation_edges = {"c:(1)4" = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"c":(1)4, ?}]>]> + // CHECK-SAME: } : tensor<8x8xf32> + // CHECK-NEXT: return %[[ADD]] : tensor<8x8xf32> + %1 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> + return %1 : tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: manual_computation_manual_axes +// CHECK-SAME: %arg0: tensor<32x32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {?}]>} +// CHECK-SAME: -> (tensor<32x32x32xf32> {sdy.propagation_edges = {a = [{propagation_step = 5 : i64, source = "operand: 0", target = "result: 0"}], +// CHECK-SAME: b = [{propagation_step = 5 : i64, source = "operand: 0", target = "result: 0"}]}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}, {?}]>}) { +func.func @manual_computation_manual_axes(%arg0: tensor<32x32x32xf32>) -> tensor<32x32x32xf32> { + // CHECK-NEXT: %[[SUB:.*]] = stablehlo.subtract %arg0, %arg0 { + // CHECK-SAME: sdy.propagation_edges = {a = [{propagation_step = 1 : i64, source = "result: 0", target = "operand: 0"}], + // CHECK-SAME: b = [{propagation_step = 1 : i64, source = "result: 0", target = "operand: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}, {?}]>]> + // CHECK-SAME: } : tensor<32x32x32xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%[[SUB]]) + // CHECK-SAME: in_shardings=[<@mesh, [{"a", ?}, {"b", ?}, {?}]>] + // CHECK-SAME: out_shardings=[<@mesh, [{"a", ?}, {"b", ?}, {?}]>] + // CHECK-SAME: manual_axes={"a"} (%arg1: tensor<16x32x32xf32>) { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 { + // CHECK-SAME: sdy.propagation_edges = {b = [{propagation_step = 2 : i64, source = "operand: 0", target = "result: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"b", ?}, {?}]>]>} + // CHECK-NEXT: sdy.return %[[ADD]] + // CHECK-NEXT: } { + // CHECK-SAME: sdy.block_arg_propagation_edges = [{ + // CHECK-SAME: a = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}], + // CHECK-SAME: b = [{propagation_step = 0 : i64, source = "result: 0", target = "operand: 0"}]}], + // CHECK-SAME: sdy.result_propagation_edges = [{ + // CHECK-SAME: b = [{propagation_step = 3 : i64, source = "operand: 0", target = "result: 0"}]}] + // CHECK-SAME: } : (tensor<32x32x32xf32>) -> tensor<32x32x32xf32> + // CHECK-NEXT: %[[SUB_2:.*]] = stablehlo.subtract %[[MC]], %[[MC]] { + // CHECK-SAME: sdy.propagation_edges = {a = [{propagation_step = 4 : i64, source = "operand: 0", target = "result: 0"}], + // CHECK-SAME: b = [{propagation_step = 4 : i64, source = "operand: 0", target = "result: 0"}]}, + // 
CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}, {?}]>]> + // CHECK-SAME: } : tensor<32x32x32xf32> + // CHECK-NEXT: return %[[SUB_2]] + %0 = stablehlo.subtract %arg0, %arg0 : tensor<32x32x32xf32> + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a", ?}, {"b", ?}, {?}]>] out_shardings=[<@mesh, [{"a", ?}, {?}, {?}]>] manual_axes={"a"} (%arg1: tensor<16x32x32xf32>) { + %3 = stablehlo.add %arg1, %arg1 : tensor<16x32x32xf32> + sdy.return %3 : tensor<16x32x32xf32> + } : (tensor<32x32x32xf32>) -> tensor<32x32x32xf32> + %2 = stablehlo.subtract %1, %1 : tensor<32x32x32xf32> + return %2: tensor<32x32x32xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// TODO(b/391840483): If the function has multiple results, then each +// `edge_source` on the func result attributes will have index 0. This is +// because of how propagation works with running propagation on each returned +// result. Reconsider this behavior. +// +// CHECK-LABEL: manual_computation_multiple_results +// CHECK-SAME: %arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", ?}, {"a", ?}]>}) +// CHECK-SAME: -> (tensor<16x32xf32> {sdy.propagation_edges = {a = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}, +// CHECK-SAME: {propagation_step = 6 : i64, source = "operand: 0", target = "result: 0"}], +// CHECK-SAME: b = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}]}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>}, +// CHECK-SAME: tensor<32x32xf32> {sdy.propagation_edges = {}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"b", ?}, {"a", ?}]>}) { +func.func @manual_computation_multiple_results(%arg0: tensor<32x32xf32>) -> (tensor<16x32xf32>, tensor<32x32xf32>) { + // CHECK-NEXT: %[[MC:.*]]:2 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"b", ?}, {"a", ?}]>] out_shardings=[<@mesh, [{?}, {"a", ?}], replicated={"b"}>, <@mesh, [{"b", ?}, {"a", ?}]>] manual_axes={"b"} (%arg1: tensor<16x32xf32>) { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 { + // CHECK-SAME: sdy.propagation_edges = {a = [{propagation_step = 4 : i64, source = "result: 0", target = "operand: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]> + // CHECK-SAME: } : tensor<16x32xf32> + // CHECK-NEXT: sdy.return %[[ADD]], %[[ADD]] : tensor<16x32xf32>, tensor<16x32xf32> + // CHECK-NEXT: } { + // CHECK-SAME: sdy.block_arg_propagation_edges = [{ + // CHECK-SAME: a = [{propagation_step = 5 : i64, source = "result: 0", target = "operand: 0"}], + // CHECK-SAME: b = [{propagation_step = 1 : i64, source = "result: 0", target = "operand: 0"}]}], + // CHECK-SAME: sdy.result_propagation_edges = [ + // CHECK-SAME: {a = [{propagation_step = 3 : i64, source = "operand: 0", target = "result: 0"}]}, + // CHECK-SAME: {a = [{propagation_step = 2 : i64, source = "result: 0", target = "operand: 0"}]}] + // CHECK-SAME: } : (tensor<32x32xf32>) -> (tensor<16x32xf32>, tensor<32x32xf32>) + // CHECK-NEXT: return %[[MC]]#0, %[[MC]]#1 : tensor<16x32xf32>, tensor<32x32xf32> + %0:2 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"b", ?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}], replicated={"b"}>, <@mesh, [{"b", ?}, {"a", ?}]>] manual_axes={"b"} (%arg1: tensor<16x32xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<16x32xf32> + sdy.return %1, %1 : tensor<16x32xf32>, tensor<16x32xf32> + } : (tensor<32x32xf32>) -> (tensor<16x32xf32>, tensor<32x32xf32>) + return %0#0, %0#1 : tensor<16x32xf32>, 
tensor<32x32xf32> +} + +// ----- + +sdy.mesh @mesh = <["c"=8]> + +// CHECK-LABEL: sub_axes_splitting_reshape +// CHECK-SAME: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}]>} +// CHECK-SAME: ) -> (tensor<4x4xf32> {sdy.propagation_edges = {"c:(1)4" = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}], +// CHECK-SAME: "c:(4)2" = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}]}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"c":(1)4, ?}, {"c":(4)2, ?}]>}) { +func.func @sub_axes_splitting_reshape( + %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}]>} +) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 { + // CHECK-SAME: sdy.propagation_edges = {"c:(1)4" = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}], + // CHECK-SAME: "c:(4)2" = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"c":(1)4, ?}, {"c":(4)2, ?}]>]> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: return %[[RESHAPE]] + %0 = stablehlo.reshape %arg0 : (tensor<16xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +// ----- + +sdy.mesh @mesh = <["c"=8]> + +// NOTE: since the reshape combines the two sub axes into one, we only save the +// merged axis on the reshape as an edge between the operand and result. +// +// CHECK-LABEL: sub_axes_merging_reshape +// CHECK-SAME: %arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c":(1)4, ?}, {"c":(4)2, ?}]>} +// CHECK-SAME: ) -> (tensor<16xf32> {sdy.propagation_edges = {c = [{propagation_step = 1 : i64, source = "operand: 0", target = "result: 0"}]}, +// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}]>}) { +func.func @sub_axes_merging_reshape( + %arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c":(1)4, ?}, {"c":(4)2, ?}]>}) + -> tensor<16xf32> { + // CHECK-NEXT: stablehlo.reshape %arg0 { + // CHECK-SAME: sdy.propagation_edges = {c = [{propagation_step = 0 : i64, source = "operand: 0", target = "result: 0"}]}, + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"c", ?}]>]> + // CHECK-SAME: } : (tensor<4x4xf32>) -> tensor<16xf32> + %0 = stablehlo.reshape %arg0 : (tensor<4x4xf32>) -> tensor<16xf32> + return %0 : tensor<16xf32> +} diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir b/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir index 89cc0dc0..9935aed6 100644 --- a/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir +++ b/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s -sdy-add-data-flow-edges -sdy-aggressive-propagate="debug-sharding-origins=true" -sdy-sink-data-flow-edges 2>&1 | FileCheck %s +// RUN: sdy_opt %s -sdy-add-data-flow-edges -sdy-aggressive-propagate=debug-sharding-origins=true -sdy-sink-data-flow-edges="sink-debug-sharding-origins=true" 2>&1 | FileCheck %s sdy.mesh @mesh = <["a"=2, "b"=2, "c"=8]> diff --git a/shardy/dialect/sdy/transforms/propagation/passes.h b/shardy/dialect/sdy/transforms/propagation/passes.h index aa751680..d56485ec 100644 --- a/shardy/dialect/sdy/transforms/propagation/passes.h +++ b/shardy/dialect/sdy/transforms/propagation/passes.h @@ -46,7 +46,7 @@ struct PropagationOptions { // Whether to save debug information about the sharding origins on the module. 
bool debugShardingOrigins = false; // Whether to save debug information about the edge shardings on the module. - bool debugEdgeSourceSharding = false; + bool debugPropagationEdgeSharding = false; // Whether to avoid converting `sdy::ShardingConstraintOp` to // `sdy::ReshardOp`. bool skipConvertToReshard = false;