#sdy Add permutation_factor in op sharding rule. #364

Merged (1 commit) on Feb 7, 2025
24 changes: 18 additions & 6 deletions shardy/dialect/sdy/ir/attrs.td
@@ -916,6 +916,9 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
such as the contracting dimensions in a dot operation.
`need_replication_factors` contains the indices of factors requiring full
replication, such as the sorted dimension in a sort operation.
`permutation_factors` contains the indices of factors requiring
collective-permute if they are sharded, such as the padding dimensions in a
pad operation.

`is_custom_rule` describes whether this is a rule defined by a user for a
`stablehlo.custom_call` op. The partitioner doesn't know how to partition
@@ -930,7 +933,8 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
operands/results).
- Rank of each `TensorMappingAttr` matches the rank of the corresponding
tensor type.
- For each group of factors (`reduction_factors`, `need_replication_factors`):
- For each group of factors (`reduction_factors`,
`need_replication_factors`, `permutation_factors`):
* Elements must be in range [0, size of `$factor_sizes`).
* No duplicate factor indices within each group and across groups.
}];
@@ -946,6 +950,8 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
"factors requiring reduction">:$reduction_factors,
OptionalArrayRefParameter<"int64_t",
"factors requiring full replication">:$need_replication_factors,
OptionalArrayRefParameter<"int64_t",
"factors corresponding to multiple sizes">:$permutation_factors,
DefaultValuedParameter<"bool", "false",
"whether the rule is for a stablehlo.custom_call">:$is_custom_rule
);
@@ -956,8 +962,9 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
`` `->` ``
`(`$result_mappings`)` ``
custom<FactorSizes>($factor_sizes) ``
custom<ReductionFactors>($reduction_factors) ``
custom<NeedReplicationFactors>($need_replication_factors) ``
custom<FactorsWithType>($reduction_factors, "\"reduction\"") ``
custom<FactorsWithType>($need_replication_factors, "\"need_replication\"") ``
custom<FactorsWithType>($permutation_factors, "\"permutation\"") ``
custom<IsCustomRule>($is_custom_rule)
`>`
}];
@@ -967,10 +974,11 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
"ArrayRef<TensorMappingAttr>":$operand_mappings,
"ArrayRef<TensorMappingAttr>":$result_mappings,
"ArrayRef<int64_t>":$reduction_factors,
"ArrayRef<int64_t>":$need_replication_factors), [{
"ArrayRef<int64_t>":$need_replication_factors,
"ArrayRef<int64_t>":$permutation_factors), [{
return $_get($_ctxt, factor_sizes, operand_mappings, result_mappings,
reduction_factors, need_replication_factors,
/*is_custom_rule=*/false);
permutation_factors, /*is_custom_rule=*/false);
}]>
];

@@ -1001,13 +1009,17 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
// Returns true if the `factorIndex` is a factor requiring full replication.
bool isNeedReplicationFactor(int64_t factorIndex) const;

// Returns true if the `factorIndex` is a permutation factor.
bool isPermutationFactor(int64_t factorIndex) const;

// Returns true if the `factorIndex` is a factor in all non-scalar tensors.
bool isFactorInAllNonScalarTensors(int64_t factorIndex) const;

// Returns true if the `factorIndex` is a batching factor, which satisfies:
// 1. It is not a reduction factor.
// 2. It is not a need replication factor.
// 3. It is used in all non-scalar tensors.
// 3. It is not a permutation factor.
// 4. It is used in all non-scalar tensors.
bool isBatchingFactor(int64_t factorIndex) const;

// Returns a vector of tensor indices that are non-scalar, of all operand
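For orientation, a minimal sketch (not part of this PR) of constructing the attribute through the builder declared above; `ctx`, `operandMappings`, and `resultMappings` stand in for values built elsewhere:

  // Hypothetical usage of the builder added in attrs.td; the factor sizes
  // and mapping arrays are placeholders.
  OpShardingRuleAttr rule = OpShardingRuleAttr::get(
      ctx, /*factor_sizes=*/{2, 3, 5, 7}, operandMappings, resultMappings,
      /*reduction_factors=*/{1},
      /*need_replication_factors=*/{0},
      /*permutation_factors=*/{2});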
6 changes: 6 additions & 0 deletions shardy/dialect/sdy/ir/dialect.cc
@@ -954,6 +954,10 @@ bool OpShardingRuleAttr::isNeedReplicationFactor(int64_t factorIndex) const {
return llvm::is_contained(getNeedReplicationFactors(), factorIndex);
}

bool OpShardingRuleAttr::isPermutationFactor(int64_t factorIndex) const {
return llvm::is_contained(getPermutationFactors(), factorIndex);
}

bool OpShardingRuleAttr::isFactorInAllNonScalarTensors(
int64_t factorIndex) const {
for (const TensorMappingAttr& tensorMapping :
@@ -970,9 +974,11 @@ bool OpShardingRuleAttr::isFactorInAllNonScalarTensors(
return true;
}

// TODO(b/394881597): Add a method that returns the factor type given the index.
bool OpShardingRuleAttr::isBatchingFactor(int64_t factorIndex) const {
return !isReductionFactor(factorIndex) &&
!isNeedReplicationFactor(factorIndex) &&
!isPermutationFactor(factorIndex) &&
isFactorInAllNonScalarTensors(factorIndex);
}

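A hedged sketch of how a caller might consult the new predicate alongside the existing ones; the loop is illustrative and not from the PR, though getNumFactors() is existing API used elsewhere in this diff:

  // Classify each factor of `rule`, an OpShardingRuleAttr.
  for (int64_t i = 0; i < rule.getNumFactors(); ++i) {
    if (rule.isPermutationFactor(i)) {
      // A sharding along factor i is resolved with collective-permute
      // rather than reduction or full replication.
    }
  }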
11 changes: 9 additions & 2 deletions shardy/dialect/sdy/ir/dialect_test.cc
@@ -80,10 +80,11 @@ class DialectTest : public ::testing::Test {
ArrayRef<TensorMappingAttr> resultMappings,
ArrayRef<int64_t> reductionFactors = {},
ArrayRef<int64_t> needReplicationFactors = {},
bool isCustomRule = false) {
ArrayRef<int64_t> permutationFactors = {}, bool isCustomRule = false) {
return OpShardingRuleAttr::get(&context, factorSizes, operandMappings,
resultMappings, reductionFactors,
needReplicationFactors, isCustomRule);
needReplicationFactors, permutationFactors,
isCustomRule);
}

MLIRContext context;
@@ -608,6 +609,7 @@ TEST_F(DialectTest, OpShardingRuleAttrElementWiseOperation) {
auto verifyBatchingFactor = [&](int64_t factorIndex) {
EXPECT_FALSE(rule.isReductionFactor(factorIndex));
EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex));
EXPECT_FALSE(rule.isPermutationFactor(factorIndex));
EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(factorIndex));
EXPECT_TRUE(rule.isBatchingFactor(factorIndex));
};
@@ -636,12 +638,14 @@ TEST_F(DialectTest, OpShardingRuleAttrDotGeneralOperation) {
// Verify the first factor is a batching factor.
EXPECT_FALSE(rule.isReductionFactor(0));
EXPECT_FALSE(rule.isNeedReplicationFactor(0));
EXPECT_FALSE(rule.isPermutationFactor(0));
EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(0));
EXPECT_TRUE(rule.isBatchingFactor(0));

auto verifyNonContractingDimension = [&](int64_t factorIndex) {
EXPECT_FALSE(rule.isReductionFactor(factorIndex));
EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex));
EXPECT_FALSE(rule.isPermutationFactor(factorIndex));
EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(factorIndex));
EXPECT_FALSE(rule.isBatchingFactor(factorIndex));
};
@@ -651,6 +655,7 @@ TEST_F(DialectTest, OpShardingRuleAttrDotGeneralOperation) {
// Verify the contracting dimension is a reduction factor.
EXPECT_TRUE(rule.isReductionFactor(3));
EXPECT_FALSE(rule.isNeedReplicationFactor(3));
EXPECT_FALSE(rule.isPermutationFactor(3));
EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(3));
EXPECT_FALSE(rule.isBatchingFactor(3));
}
@@ -675,12 +680,14 @@ TEST_F(DialectTest, OpShardingRuleAttrDynamicSlice) {
// Verify the first factor is a batching factor.
EXPECT_FALSE(rule.isReductionFactor(0));
EXPECT_FALSE(rule.isNeedReplicationFactor(0));
EXPECT_FALSE(rule.isPermutationFactor(0));
EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(0));
EXPECT_TRUE(rule.isBatchingFactor(0));

auto verifyNonBatchingFactor = [&](int64_t factorIndex) {
EXPECT_FALSE(rule.isReductionFactor(factorIndex));
EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex));
EXPECT_FALSE(rule.isPermutationFactor(factorIndex));
EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(factorIndex));
EXPECT_FALSE(rule.isBatchingFactor(factorIndex));
};
15 changes: 0 additions & 15 deletions shardy/dialect/sdy/ir/parsers.cc
@@ -289,8 +289,6 @@ ParseResult parseFactorSizes(AsmParser& parser,
return success();
}

namespace {

// Parses a list of factors of a given `type`. In an OpShardingRule, you
// could have `, type={k, i}`.
// `k` is index 2, while `i` is index 0. Thus factors would be set to [2, 0].
ParseResult parseFactorsWithType(AsmParser& parser,
@@ -320,19 +318,6 @@ ParseResult parseFactorsWithType(AsmParser& parser,
return success();
}

} // namespace

ParseResult parseReductionFactors(AsmParser& parser,
SmallVector<int64_t>& reductionFactors) {
return parseFactorsWithType(parser, reductionFactors, "reduction");
}

ParseResult parseNeedReplicationFactors(
AsmParser& parser, SmallVector<int64_t>& needReplicationFactors) {
return parseFactorsWithType(parser, needReplicationFactors,
"need_replication");
}

ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule) {
isCustomRule = false;
if (!parser.parseOptionalComma()) {
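To make the consolidated parser concrete, the textual forms exercised in the tests below map to index vectors as follows (factors i, j, k, l are indices 0 through 3):

  //   reduction={j}           -> factors == {1}
  //   need_replication={i, l} -> factors == {0, 3}
  //   permutation={k}         -> factors == {2}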
13 changes: 4 additions & 9 deletions shardy/dialect/sdy/ir/parsers.h
@@ -38,15 +38,10 @@ ParseResult parseMeshOrRef(AsmParser& parser, Attribute& meshOrRef);
ParseResult parseFactorSizes(AsmParser& parser,
SmallVector<int64_t>& factorSizes);

// Parses the reduction factors of an OpShardingRule. We expect to parse
// `reduction={i, k}` into a vector [0, 2].
ParseResult parseReductionFactors(AsmParser& parser,
SmallVector<int64_t>& reductionFactors);

// Parses the factors needing replication of an OpShardingRule. We expect to
// parse `need_replication={i, k}` into a vector [0, 2].
ParseResult parseNeedReplicationFactors(
AsmParser& parser, SmallVector<int64_t>& needReplicationFactors);
// Parses a list of `factors` of `type` in an OpShardingRule. We expect to parse
// `type={i, k}` into a vector [0, 2].
ParseResult parseFactorsWithType(AsmParser& parser,
SmallVector<int64_t>& factors, StringRef type);

ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule);

15 changes: 0 additions & 15 deletions shardy/dialect/sdy/ir/printers.cc
@@ -74,8 +74,6 @@ void printFactorSizes(AsmPrinter& printer, ArrayRef<int64_t> factorSizes) {
printer << "}";
}

namespace {

void printFactorsWithType(AsmPrinter& printer, ArrayRef<int64_t> factors,
StringRef type) {
if (factors.empty()) {
@@ -88,19 +86,6 @@ void printFactorsWithType(AsmPrinter& printer, ArrayRef<int64_t> factors,
printer << "}";
}

} // namespace

void printReductionFactors(AsmPrinter& printer,
ArrayRef<int64_t> reductionFactors) {
return printFactorsWithType(printer, reductionFactors, "reduction");
}

void printNeedReplicationFactors(AsmPrinter& printer,
ArrayRef<int64_t> needReplicationFactors) {
return printFactorsWithType(printer, needReplicationFactors,
"need_replication");
}

void printIsCustomRule(AsmPrinter& printer, bool isCustomRule) {
if (isCustomRule) {
printer << ", custom";
13 changes: 4 additions & 9 deletions shardy/dialect/sdy/ir/printers.h
@@ -38,15 +38,10 @@ void printMeshOrRef(AsmPrinter& printer, Attribute meshOrRef);
// printed as `{i=6, j=2, k=4}`.
void printFactorSizes(AsmPrinter& printer, ArrayRef<int64_t> factorSizes);

// Prints the reduction factors of an OpShardingRule. Given a vector [0, 2], we
// print `reduction={i, k}`.
void printReductionFactors(AsmPrinter& printer,
ArrayRef<int64_t> reductionFactors);

// Prints the factors needing replication of an OpShardingRule. Given a vector
// [0, 2], we print `need_replication={i, k}`.
void printNeedReplicationFactors(AsmPrinter& printer,
ArrayRef<int64_t> needReplicationFactors);
// Prints the `factors` of `type` in an OpShardingRule. Given a vector [0, 2],
// we print `type={i, k}`.
void printFactorsWithType(AsmPrinter& printer, ArrayRef<int64_t> factors,
StringRef type);

void printIsCustomRule(AsmPrinter& printer, bool isCustomRule);

10 changes: 5 additions & 5 deletions shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir
@@ -47,9 +47,9 @@ func.func @custom_call_custom_rule(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32
func.return %0: tensor<16x32xf32>
}

// CHECK-LABEL: func @reduction_and_need_replication_factors
func.func @reduction_and_need_replication_factors(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>}
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
func.return %0: tensor<2x5x7xf32>
// CHECK-LABEL: func @special_factors
func.func @special_factors(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x11x7xf32> {
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l} permutation={k}, custom>}
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l} permutation={k}, custom>} : (tensor<2x3x5x7xf32>) -> tensor<2x11x7xf32>
func.return %0: tensor<2x11x7xf32>
}
12 changes: 10 additions & 2 deletions shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir
@@ -136,8 +136,16 @@ func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8x

// -----

func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
// expected-error@+1 {{reduction and need_replication factors must be disjoint}}
func.func @a_factor_in_two_special_factor_sets(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
// expected-error@+1 {{a factor can only be in one of the special factor sets}}
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} reduction={j} need_replication={j}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
func.return %0: tensor<2x8xf32>
}

// -----

func.func @a_factor_in_three_special_factor_sets(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
// expected-error@+1 {{a factor can only be in one of the special factor sets}}
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} reduction={j} need_replication={j} permutation={j}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
func.return %0: tensor<2x8xf32>
}
21 changes: 14 additions & 7 deletions shardy/dialect/sdy/ir/verifiers.cc
@@ -610,6 +610,7 @@ LogicalResult verifyOpShardingRuleAttr(OpShardingRuleAttr shardingRule,
ArrayRef<int64_t> reductionFactors = shardingRule.getReductionFactors();
ArrayRef<int64_t> needReplicationFactors =
shardingRule.getNeedReplicationFactors();
ArrayRef<int64_t> permutationFactors = shardingRule.getPermutationFactors();

if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(),
reductionFactors))) {
@@ -619,15 +620,21 @@ LogicalResult verifyOpShardingRuleAttr(OpShardingRuleAttr shardingRule,
needReplicationFactors))) {
return failure();
}
if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(),
permutationFactors))) {
return failure();
}

SmallVector<int64_t> intersection;
std::set_intersection(reductionFactors.begin(), reductionFactors.end(),
needReplicationFactors.begin(),
needReplicationFactors.end(),
std::back_inserter(intersection));
if (!intersection.empty()) {
SmallDenseSet<int64_t> specialFactors;
specialFactors.insert(reductionFactors.begin(), reductionFactors.end());
specialFactors.insert(needReplicationFactors.begin(),
needReplicationFactors.end());
specialFactors.insert(permutationFactors.begin(), permutationFactors.end());
if (specialFactors.size() != reductionFactors.size() +
needReplicationFactors.size() +
permutationFactors.size()) {
return op->emitOpError(
"reduction and need_replication factors must be disjoint");
"a factor can only be in one of the special factor sets");
}

return success();
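The set-based check works because inserting a duplicate into a set is a no-op: the set ends up smaller than the summed group sizes exactly when two groups share a factor. A standalone sketch (not from the PR):

  // Factor 1 appears in two groups, so the set holds 2 elements, not 3.
  llvm::SmallDenseSet<int64_t> special;
  llvm::SmallVector<int64_t> reduction = {1}, needReplication = {1, 3};
  special.insert(reduction.begin(), reduction.end());
  special.insert(needReplication.begin(), needReplication.end());
  assert(special.size() <
         reduction.size() + needReplication.size());  // overlap detected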
@@ -117,7 +117,7 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() {

auto result = OpShardingRuleAttr::get(
context, factorSizes, operandMappingAttrs, resultMappingAttrs,
reductionFactors, needReplicationFactors);
reductionFactors, needReplicationFactors, permutationFactors);

// Erase all added factors, to return the builder to its original state before
// calling this method.
@@ -153,6 +153,9 @@ void OpShardingRuleBuilder::updateFactorType(FactorType factorType,
case FactorType::kNeedReplication:
needReplicationFactors.push_back(factorIndex);
return;
case FactorType::kPermutation:
permutationFactors.push_back(factorIndex);
return;
case FactorType::kDefault:
return;
}
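A hedged sketch of how an op's rule construction might tag a padded dimension; whether updateFactorType is callable from outside the builder, and the surrounding calls, are assumptions beyond what this diff shows:

  // Hypothetical: mark factor 2 (e.g., a padded dimension) as a permutation
  // factor before materializing the rule.
  builder.updateFactorType(FactorType::kPermutation, /*factorIndex=*/2);
  OpShardingRuleAttr rule = builder.build();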
@@ -50,6 +50,10 @@ enum class FactorType {
// If we have sharding along a dimension that needs replication, the
// partitioner will make this dimension replicated.
kNeedReplication,

// If we have sharding along a dimension that needs permutation, the
// partitioner will add collective-permute operations.
kPermutation,
};

// The factor mappings that compose a dimension of a tensor.
@@ -139,8 +143,8 @@ class OpShardingRuleBuilder {
SmallVector<TensorMapping> operandMappings;
SmallVector<TensorMapping> resultMappings;

SmallVector<int64_t> reductionFactors;
SmallVector<int64_t> needReplicationFactors;
SmallVector<int64_t> reductionFactors, needReplicationFactors,
permutationFactors;
};

// Creates an identity mapping for an op with `numOperands` operands and