Skip to content

Commit

Permalink
#sdy #debug support debugging tool to save the edge source shardings.
Browse files Browse the repository at this point in the history
This works by having each `Operation*` save a source/target edge, which is always composed of at least one operand. This is because results can be used several times, but an operand only ever has one defining op. So these edges always look "backwards" - never forwards towards a use of a result.

As such, the `FuncOp` args don't contain edge source information: only the ops that use them.

Similarly, since only the uses save the edge, for `FuncOp` results which have been updated, we save the edges in each result's `resultAttr`. If the function has multiple results, then each `edge_source` on the func result attributes will have index 0. This is because of how propagation works with running propagation on each returned result. Having them have the right index would make sense if the `sdy.edge_sources` were saved as a top level attribute on the func, but since one is saved per result, then an index of 0 makes most sense.

PiperOrigin-RevId: 709122522
  • Loading branch information
bartchr808 authored and copybara-github committed Feb 5, 2025
1 parent 9541653 commit 4fde090
Show file tree
Hide file tree
Showing 11 changed files with 731 additions and 226 deletions.
14 changes: 14 additions & 0 deletions shardy/dialect/sdy/ir/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,30 @@ inline constexpr StringRef kShardingRuleAttr = "sdy.sharding_rule";
// caused a value to be sharded a certain way.
inline constexpr StringRef kShardingOriginsAttr = "sdy.sharding_origins";

// Attribute name for saving which operand/result sharding of an op caused its
// value to be sharded a certain way.
inline constexpr StringRef kPropagationEdgesAttr = "sdy.propagation_edges";

// Attribute name like `kShardingOriginsAttr` but for
// `ShardableDataFlowOpInterface` op block arguments.
inline constexpr StringRef kBlockArgShardingOriginsAttr =
"sdy.block_arg_sharding_origins";

// Attribute name like `kPropagationEdgesAttr` but for
// `ShardableDataFlowOpInterface` op block arguments.
inline constexpr StringRef kBlockArgPropagationEdgesAttr =
"sdy.block_arg_propagation_edges";

// Attribute name like `kShardingOriginsAttr` but for
// `ShardableDataFlowOpInterface` op results.
inline constexpr StringRef kResultShardingOriginsAttr =
"sdy.result_sharding_origins";

// Attribute name like `kPropagationEdgesAttr` but for
// `ShardableDataFlowOpInterface` op results.
inline constexpr StringRef kResultPropagationEdgesAttr =
"sdy.result_propagation_edges";

// Attribute name for the unique name of a sharding origin. Is either an
// `sdy.sharding_constraint`, or `sdy.ManualComputationOp` input/output.
inline constexpr StringRef kShardingOriginNameAttr = "sdy.sharding_origin_name";
Expand Down
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cc_library(
"//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_registry",
"//shardy/dialect/sdy/transforms/propagation:sharding_projection",
"//shardy/dialect/sdy/transforms/propagation:utils",
"//shardy/dialect/sdy/transforms/propagation/debugging:source_sharding",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
13 changes: 13 additions & 0 deletions shardy/dialect/sdy/transforms/export/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
//TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.

let options = [
Option<"sinkDebugShardingOrigins", "sink-debug-sharding-origins", "bool",
/*default=*/"false",
"Whether to sink the debug sharding origins info. See "
"`debug-sharding-origins` option in propagation for more info.">,
Option<"sinkDebugPropagationEdgeSharding",
"sink-debug-propagation-edge-sharding", "bool",
/*default=*/"false",
"Whether to sink the debug propagation edge sharding info. See "
"`debug-propagation-edge-sharding` option in propagation for more "
"info.">
];
}

def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
Expand Down
51 changes: 8 additions & 43 deletions shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ limitations under the License.
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h"

namespace mlir {
namespace sdy {
Expand Down Expand Up @@ -81,44 +82,6 @@ SmallVector<TensorShardingAttr> getShardingsFromDataFlowEdges(
return shardings;
}

// Saves an array of all the origin sharding dictionaries for the given
// `edgeOwners` on `op`. If non exist, nothing is saved.
//
// For debugging the origin shardings, we want to preserve the origin sharding
// dictionaries from the `DataFlowEdgeOp`s on the owning op so that they are
// preserved after the propagation pipeline.
//
// See the `debug-sharding-origins` config on propagation for more details.
//
// TODO(b/388458831): add `saveDebugPropagationInfo` to the pass and pass it in
// here. Can then reserve the right size for `originShardingDicts` and not need
// the `exists` boolean.
void buildOriginShardingDictsFromDataFlowEdges(ValueRange edgeOwners,
Operation* op,
StringRef attrName,
IRRewriter& rewriter) {
SmallVector<Attribute> originShardingDicts;
// TODO(b/388458831): pass through a boolean indicating whether the origin
// sharding debug information is enabled.
bool exists = false;
for (Value edgeOwner : edgeOwners) {
DictionaryAttr dict;
if (auto dataFlowEdgeOp = DataFlowEdgeOp::lookup(edgeOwner)) {
dict =
dataFlowEdgeOp->getAttrOfType<DictionaryAttr>(kShardingOriginsAttr);
}
if (!dict) {
dict = rewriter.getDictionaryAttr({});
} else {
exists = true;
}
originShardingDicts.push_back(dict);
}
if (exists) {
op->setAttr(attrName, rewriter.getArrayAttr(originShardingDicts));
}
}

struct SinkDataFlowEdgesPass
: public impl::SinkDataFlowEdgesPassBase<SinkDataFlowEdgesPass> {
using SinkDataFlowEdgesPassBase::SinkDataFlowEdgesPassBase;
Expand Down Expand Up @@ -152,17 +115,19 @@ struct SinkDataFlowEdgesPass
shardableDataFlowOp.setBlockArgumentEdgeOwnerShardings(
blockArgShardings);
}
buildOriginShardingDictsFromDataFlowEdges(
blockArgOwners, op, kBlockArgShardingOriginsAttr, rewriter);
saveDebugInfoDictsFromDataFlowEdges(
blockArgOwners, op, sinkDebugShardingOrigins,
sinkDebugPropagationEdgeSharding, EdgeNodeType::OPERAND, rewriter);

ResultRange resultOwners = shardableDataFlowOp.getOpResultEdgeOwners();
if (SmallVector<TensorShardingAttr> resultShardings =
getShardingsFromDataFlowEdges(resultOwners);
!resultShardings.empty()) {
shardableDataFlowOp.setOpResultEdgeOwnerShardings(resultShardings);
}
buildOriginShardingDictsFromDataFlowEdges(
resultOwners, op, kResultShardingOriginsAttr, rewriter);
saveDebugInfoDictsFromDataFlowEdges(
resultOwners, op, sinkDebugShardingOrigins,
sinkDebugPropagationEdgeSharding, EdgeNodeType::RESULT, rewriter);
return WalkResult::advance();
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ LogicalResult propagateTensorShardings(
if (context->hasActionHandler()) {
context->executeAction<SourceShardingAction>(
updateShardings,
/*IRUnits=*/{op}, operandsParams.tensors, resultsParams.tensors, mesh,
shardingRule, shardingProjection);
/*IRUnits=*/{op}, op, operandsParams.tensors, resultsParams.tensors,
mesh, shardingRule, shardingProjection, anyUpdated);
} else {
updateShardings();
}
Expand Down Expand Up @@ -649,7 +649,8 @@ void BasicPropagationPassImpl::runOnOperation() {
MLIRContext& context = getContext();

// Prepare debugging handler for sharding origins and edge sources.
ShardingDebugMappings mappings(debugShardingOrigins, debugEdgeSourceSharding);
ShardingDebugMappings mappings(debugShardingOrigins,
debugPropagationEdgeSharding);
SourceShardingHandler handler(&mappings);
handler.prepareHandler(moduleOp);

Expand Down Expand Up @@ -683,7 +684,7 @@ void BasicPropagationPassImpl::setPropagationOptions(
dumpDirectory = options.dumpDirectory.str();
conservativePropagation = options.conservativePropagation;
debugShardingOrigins = options.debugShardingOrigins;
debugEdgeSourceSharding = options.debugEdgeSourceSharding;
debugPropagationEdgeSharding = options.debugPropagationEdgeSharding;
}

std::unique_ptr<Pass> createBasicPropagationPass(
Expand Down
10 changes: 5 additions & 5 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ class BasicPropagationPassImpl : public OperationPass<ModuleOp> {
"before propagation."),
llvm::cl::init(false)};

Option<bool> debugEdgeSourceSharding{
*this, "debug-edge-source-sharding",
Option<bool> debugPropagationEdgeSharding{
*this, "debug-propagation-edge-sharding",
llvm::cl::desc(
"whether to save information about the edge source of a sharding "
"on the MLIR module. These are from which operand/result a sharding "
"was propagated."),
"whether to save information about the SSA value edges of how a "
"sharding on the MLIR module propagated around. These are from which "
"operand/result a sharding was propagated to a given op."),
llvm::cl::init(false)};

private:
Expand Down
Loading

0 comments on commit 4fde090

Please sign in to comment.