Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#sdy #debug support debugging tool to save the edge source shardings. #310

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions shardy/dialect/sdy/ir/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,30 @@ inline constexpr StringRef kShardingRuleAttr = "sdy.sharding_rule";
// caused a value to be sharded a certain way.
inline constexpr StringRef kShardingOriginsAttr = "sdy.sharding_origins";

// Attribute name for saving which operand/result sharding of an op caused its
// value to be sharded a certain way. This is the propagation-edge counterpart
// of `kShardingOriginsAttr`.
inline constexpr StringRef kPropagationEdgesAttr = "sdy.propagation_edges";

// Attribute name like `kShardingOriginsAttr`, but for the block arguments of
// a `ShardableDataFlowOpInterface` op.
inline constexpr StringRef kBlockArgShardingOriginsAttr =
"sdy.block_arg_sharding_origins";

// Attribute name like `kPropagationEdgesAttr`, but for the block arguments of
// a `ShardableDataFlowOpInterface` op.
inline constexpr StringRef kBlockArgPropagationEdgesAttr =
"sdy.block_arg_propagation_edges";

// Attribute name like `kShardingOriginsAttr`, but for the results of a
// `ShardableDataFlowOpInterface` op.
inline constexpr StringRef kResultShardingOriginsAttr =
"sdy.result_sharding_origins";

// Attribute name like `kPropagationEdgesAttr`, but for the results of a
// `ShardableDataFlowOpInterface` op.
inline constexpr StringRef kResultPropagationEdgesAttr =
"sdy.result_propagation_edges";

// Attribute name for the unique name of a sharding origin. The origin is
// either an `sdy.sharding_constraint` or an `sdy.ManualComputationOp`
// input/output.
inline constexpr StringRef kShardingOriginNameAttr = "sdy.sharding_origin_name";
Expand Down
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cc_library(
"//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_registry",
"//shardy/dialect/sdy/transforms/propagation:sharding_projection",
"//shardy/dialect/sdy/transforms/propagation:utils",
"//shardy/dialect/sdy/transforms/propagation/debugging:source_sharding",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
13 changes: 13 additions & 0 deletions shardy/dialect/sdy/transforms/export/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
//TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.

let options = [
Option<"sinkDebugShardingOrigins", "sink-debug-sharding-origins", "bool",
/*default=*/"false",
"Whether to sink the debug sharding origins info. See "
"`debug-sharding-origins` option in propagation for more info.">,
Option<"sinkDebugPropagationEdgeSharding",
"sink-debug-propagation-edge-sharding", "bool",
/*default=*/"false",
"Whether to sink the debug propagation edge sharding info. See "
"`debug-propagation-edge-sharding` option in propagation for more "
"info.">
];
}

def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
Expand Down
51 changes: 8 additions & 43 deletions shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ limitations under the License.
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/export/passes.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h"

namespace mlir {
namespace sdy {
Expand Down Expand Up @@ -81,44 +82,6 @@ SmallVector<TensorShardingAttr> getShardingsFromDataFlowEdges(
return shardings;
}

// Gathers the sharding-origin dictionary of every value in `edgeOwners` and
// stores them, as a single array attribute named `attrName`, on `op`.
//
// For each edge owner, the dictionary is read from the `kShardingOriginsAttr`
// attribute of the `DataFlowEdgeOp` returned by `DataFlowEdgeOp::lookup`. An
// owner without an edge op, or whose edge op lacks that attribute, contributes
// an empty dictionary so that the array stays aligned with `edgeOwners`. If
// none of the owners has an origin dictionary, `op` is left untouched.
//
// The purpose is to keep the origin sharding dictionaries alive on the owning
// op once the `DataFlowEdgeOp`s are gone, so they survive the propagation
// pipeline. See the `debug-sharding-origins` config on propagation for more
// details.
//
// TODO(b/388458831): add `saveDebugPropagationInfo` to the pass and pass it in
// here. The array can then be reserved up front and the `anyOriginFound` flag
// dropped.
void buildOriginShardingDictsFromDataFlowEdges(ValueRange edgeOwners,
                                               Operation* op,
                                               StringRef attrName,
                                               IRRewriter& rewriter) {
  SmallVector<Attribute> dictsPerOwner;
  // TODO(b/388458831): pass through a boolean indicating whether the origin
  // sharding debug information is enabled.
  bool anyOriginFound = false;
  for (Value owner : edgeOwners) {
    DictionaryAttr originDict;
    if (auto edgeOp = DataFlowEdgeOp::lookup(owner)) {
      originDict = edgeOp->getAttrOfType<DictionaryAttr>(kShardingOriginsAttr);
    }
    if (originDict) {
      anyOriginFound = true;
    } else {
      // Placeholder entry so indices keep matching `edgeOwners`.
      originDict = rewriter.getDictionaryAttr({});
    }
    dictsPerOwner.push_back(originDict);
  }
  if (!anyOriginFound) {
    return;
  }
  op->setAttr(attrName, rewriter.getArrayAttr(dictsPerOwner));
}

struct SinkDataFlowEdgesPass
: public impl::SinkDataFlowEdgesPassBase<SinkDataFlowEdgesPass> {
using SinkDataFlowEdgesPassBase::SinkDataFlowEdgesPassBase;
Expand Down Expand Up @@ -152,17 +115,19 @@ struct SinkDataFlowEdgesPass
shardableDataFlowOp.setBlockArgumentEdgeOwnerShardings(
blockArgShardings);
}
buildOriginShardingDictsFromDataFlowEdges(
blockArgOwners, op, kBlockArgShardingOriginsAttr, rewriter);
saveDebugInfoDictsFromDataFlowEdges(
blockArgOwners, op, sinkDebugShardingOrigins,
sinkDebugPropagationEdgeSharding, EdgeNodeType::OPERAND, rewriter);

ResultRange resultOwners = shardableDataFlowOp.getOpResultEdgeOwners();
if (SmallVector<TensorShardingAttr> resultShardings =
getShardingsFromDataFlowEdges(resultOwners);
!resultShardings.empty()) {
shardableDataFlowOp.setOpResultEdgeOwnerShardings(resultShardings);
}
buildOriginShardingDictsFromDataFlowEdges(
resultOwners, op, kResultShardingOriginsAttr, rewriter);
saveDebugInfoDictsFromDataFlowEdges(
resultOwners, op, sinkDebugShardingOrigins,
sinkDebugPropagationEdgeSharding, EdgeNodeType::RESULT, rewriter);
return WalkResult::advance();
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ LogicalResult propagateTensorShardings(
if (context->hasActionHandler()) {
context->executeAction<SourceShardingAction>(
updateShardings,
/*IRUnits=*/{op}, operandsParams.tensors, resultsParams.tensors, mesh,
shardingRule, shardingProjection);
/*IRUnits=*/{op}, op, operandsParams.tensors, resultsParams.tensors,
mesh, shardingRule, shardingProjection, anyUpdated);
} else {
updateShardings();
}
Expand Down Expand Up @@ -649,7 +649,8 @@ void BasicPropagationPassImpl::runOnOperation() {
MLIRContext& context = getContext();

// Prepare debugging handler for sharding origins and edge sources.
ShardingDebugMappings mappings(debugShardingOrigins, debugEdgeSourceSharding);
ShardingDebugMappings mappings(debugShardingOrigins,
debugPropagationEdgeSharding);
SourceShardingHandler handler(&mappings);
handler.prepareHandler(moduleOp);

Expand Down Expand Up @@ -683,7 +684,7 @@ void BasicPropagationPassImpl::setPropagationOptions(
dumpDirectory = options.dumpDirectory.str();
conservativePropagation = options.conservativePropagation;
debugShardingOrigins = options.debugShardingOrigins;
debugEdgeSourceSharding = options.debugEdgeSourceSharding;
debugPropagationEdgeSharding = options.debugPropagationEdgeSharding;
}

std::unique_ptr<Pass> createBasicPropagationPass(
Expand Down
10 changes: 5 additions & 5 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ class BasicPropagationPassImpl : public OperationPass<ModuleOp> {
"before propagation."),
llvm::cl::init(false)};

Option<bool> debugEdgeSourceSharding{
*this, "debug-edge-source-sharding",
Option<bool> debugPropagationEdgeSharding{
*this, "debug-propagation-edge-sharding",
llvm::cl::desc(
"whether to save information about the edge source of a sharding "
"on the MLIR module. These are from which operand/result a sharding "
"was propagated."),
"whether to save information about the SSA value edges of how a "
"sharding on the MLIR module propagated around. These are from which "
"operand/result a sharding was propagated to a given op."),
llvm::cl::init(false)};

private:
Expand Down
Loading