From 24442f94262d1e455e70eff4e4082436a29f2c91 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 7 Feb 2025 02:06:17 -0800 Subject: [PATCH] #sdy Add `permutation_factor` in op sharding rule. If we partition along this factor, the partitioner will add collective-permute operations. Several potential examples * slicing dimensions in slice, dynamic-slice, dynamic-update-slice * padding dimensions in pad operations PiperOrigin-RevId: 724262915 --- shardy/dialect/sdy/ir/attrs.td | 24 ++++++++++++++----- shardy/dialect/sdy/ir/dialect.cc | 6 +++++ shardy/dialect/sdy/ir/dialect_test.cc | 11 +++++++-- shardy/dialect/sdy/ir/parsers.cc | 15 ------------ shardy/dialect/sdy/ir/parsers.h | 13 ++++------ shardy/dialect/sdy/ir/printers.cc | 15 ------------ shardy/dialect/sdy/ir/printers.h | 13 ++++------ .../ir/test/sharding_rule_parse_print.mlir | 10 ++++---- .../ir/test/sharding_rule_verification.mlir | 12 ++++++++-- shardy/dialect/sdy/ir/verifiers.cc | 21 ++++++++++------ .../propagation/op_sharding_rule_builder.cc | 5 +++- .../propagation/op_sharding_rule_builder.h | 8 +++++-- shardy/integrations/c/attributes.cc | 14 ++++++++++- shardy/integrations/c/attributes.h | 7 ++++++ shardy/integrations/python/ir/sdy_module.cc | 15 +++++++++--- 15 files changed, 112 insertions(+), 77 deletions(-) diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td index c4ef24a2..f8c5e43f 100644 --- a/shardy/dialect/sdy/ir/attrs.td +++ b/shardy/dialect/sdy/ir/attrs.td @@ -916,6 +916,9 @@ def Sdy_OpShardingRule : AttrDef { such as the contracting dimensions in a dot operation. `need_replication_factors` contains the indices of factors requiring full replication, such as the sorted dimension in a sort operation. + `permutation_factors` contains the indices of factors requiring + collective-permute if they are sharded, such as the padding dimensions in a + pad operation. `is_custom_rule` describes whether this is a rule defined by a user for a `stablehlo.custom_call` op. 
The partitioner doesn't know how to partition @@ -930,7 +933,8 @@ def Sdy_OpShardingRule : AttrDef { operands/results). - Rank of each `TensorMappingAttr` matches the rank of the corresponding tensor type. - - For each group of factors (`reduction_factors`, `need_replication_factors`): + - For each group of factors (`reduction_factors`, + `need_replication_factors`, `permutation_factors`): * Elements must be in range [0, `$factor_sizes`]. * No duplicate factor indices within each group and across groups. }]; @@ -946,6 +950,8 @@ def Sdy_OpShardingRule : AttrDef { "factors requiring reduction">:$reduction_factors, OptionalArrayRefParameter<"int64_t", "factors requiring full replication">:$need_replication_factors, + OptionalArrayRefParameter<"int64_t", + "factors corresponding to multiple sizes">:$permutation_factors, DefaultValuedParameter<"bool", "false", "whether the rule is for a stablehlo.custom_call">:$is_custom_rule ); @@ -956,8 +962,9 @@ def Sdy_OpShardingRule : AttrDef { `` `->` `` `(`$result_mappings`)` `` custom($factor_sizes) `` - custom($reduction_factors) `` - custom($need_replication_factors) `` + custom($reduction_factors, "\"reduction\"") `` + custom($need_replication_factors, "\"need_replication\"") `` + custom($permutation_factors, "\"permutation\"") `` custom($is_custom_rule) `>` }]; @@ -967,10 +974,11 @@ def Sdy_OpShardingRule : AttrDef { "ArrayRef":$operand_mappings, "ArrayRef":$result_mappings, "ArrayRef":$reduction_factors, - "ArrayRef":$need_replication_factors), [{ + "ArrayRef":$need_replication_factors, + "ArrayRef":$permutation_factors), [{ return $_get($_ctxt, factor_sizes, operand_mappings, result_mappings, reduction_factors, need_replication_factors, - /*is_custom_rule=*/false); + permutation_factors, /*is_custom_rule=*/false); }]> ]; @@ -1001,13 +1009,17 @@ def Sdy_OpShardingRule : AttrDef { // Returns true if the `factorIndex` is a factor requiring full replication. 
bool isNeedReplicationFactor(int64_t factorIndex) const; + // Returns true if the `factorIndex` is a permutation factor. + bool isPermutationFactor(int64_t factorIndex) const; + // Returns true if the `factorIndex` is a factor in all non-scalar tensors. bool isFactorInAllNonScalarTensors(int64_t factorIndex) const; // Returns true if the `factorIndex` is a batching factor, which satisfies: // 1. It is not a reduction factor. // 2. It is not a need replication factor. - // 3. It is used in all non-scalar tensors. + // 3. It is not a permutation factor. + // 4. It is used in all non-scalar tensors. bool isBatchingFactor(int64_t factorIndex) const; // Returns a vector of tensor indices that are non-scalar, of all operand diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc index a912576e..9ce0723b 100644 --- a/shardy/dialect/sdy/ir/dialect.cc +++ b/shardy/dialect/sdy/ir/dialect.cc @@ -954,6 +954,10 @@ bool OpShardingRuleAttr::isNeedReplicationFactor(int64_t factorIndex) const { return llvm::is_contained(getNeedReplicationFactors(), factorIndex); } +bool OpShardingRuleAttr::isPermutationFactor(int64_t factorIndex) const { + return llvm::is_contained(getPermutationFactors(), factorIndex); +} + bool OpShardingRuleAttr::isFactorInAllNonScalarTensors( int64_t factorIndex) const { for (const TensorMappingAttr& tensorMapping : @@ -970,9 +974,11 @@ bool OpShardingRuleAttr::isFactorInAllNonScalarTensors( return true; } +// TODO(b/394881597). Adding a method to return the factor type given the index. 
bool OpShardingRuleAttr::isBatchingFactor(int64_t factorIndex) const { return !isReductionFactor(factorIndex) && !isNeedReplicationFactor(factorIndex) && + !isPermutationFactor(factorIndex) && isFactorInAllNonScalarTensors(factorIndex); } diff --git a/shardy/dialect/sdy/ir/dialect_test.cc b/shardy/dialect/sdy/ir/dialect_test.cc index ffbf8de2..35b443b2 100644 --- a/shardy/dialect/sdy/ir/dialect_test.cc +++ b/shardy/dialect/sdy/ir/dialect_test.cc @@ -80,10 +80,11 @@ class DialectTest : public ::testing::Test { ArrayRef resultMappings, ArrayRef reductionFactors = {}, ArrayRef needReplicationFactors = {}, - bool isCustomRule = false) { + ArrayRef permutationFactors = {}, bool isCustomRule = false) { return OpShardingRuleAttr::get(&context, factorSizes, operandMappings, resultMappings, reductionFactors, - needReplicationFactors, isCustomRule); + needReplicationFactors, permutationFactors, + isCustomRule); } MLIRContext context; @@ -608,6 +609,7 @@ TEST_F(DialectTest, OpShardingRuleAttrElementWiseOperation) { auto verifyBatchingFactor = [&](int64_t factorIndex) { EXPECT_FALSE(rule.isReductionFactor(factorIndex)); EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex)); + EXPECT_FALSE(rule.isPermutationFactor(factorIndex)); EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(factorIndex)); EXPECT_TRUE(rule.isBatchingFactor(factorIndex)); }; @@ -636,12 +638,14 @@ TEST_F(DialectTest, OpShardingRuleAttrDotGeneralOperation) { // Verify the first factor is a batching factor. 
EXPECT_FALSE(rule.isReductionFactor(0)); EXPECT_FALSE(rule.isNeedReplicationFactor(0)); + EXPECT_FALSE(rule.isPermutationFactor(0)); EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(0)); EXPECT_TRUE(rule.isBatchingFactor(0)); auto verifyNonContractingDimension = [&](int64_t factorIndex) { EXPECT_FALSE(rule.isReductionFactor(factorIndex)); EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex)); + EXPECT_FALSE(rule.isPermutationFactor(factorIndex)); EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(factorIndex)); EXPECT_FALSE(rule.isBatchingFactor(factorIndex)); }; @@ -651,6 +655,7 @@ TEST_F(DialectTest, OpShardingRuleAttrDotGeneralOperation) { // Verify the contracting dimension is a reduction factor. EXPECT_TRUE(rule.isReductionFactor(3)); EXPECT_FALSE(rule.isNeedReplicationFactor(3)); + EXPECT_FALSE(rule.isPermutationFactor(3)); EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(3)); EXPECT_FALSE(rule.isBatchingFactor(3)); } @@ -675,12 +680,14 @@ TEST_F(DialectTest, OpShardingRuleAttrDynamicSlice) { // Verify the first factor is a batching factor. EXPECT_FALSE(rule.isReductionFactor(0)); EXPECT_FALSE(rule.isNeedReplicationFactor(0)); + EXPECT_FALSE(rule.isPermutationFactor(0)); EXPECT_TRUE(rule.isFactorInAllNonScalarTensors(0)); EXPECT_TRUE(rule.isBatchingFactor(0)); auto verifyNonBatchingFactor = [&](int64_t factorIndex) { EXPECT_FALSE(rule.isReductionFactor(factorIndex)); EXPECT_FALSE(rule.isNeedReplicationFactor(factorIndex)); + EXPECT_FALSE(rule.isPermutationFactor(factorIndex)); EXPECT_FALSE(rule.isFactorInAllNonScalarTensors(factorIndex)); EXPECT_FALSE(rule.isBatchingFactor(factorIndex)); }; diff --git a/shardy/dialect/sdy/ir/parsers.cc b/shardy/dialect/sdy/ir/parsers.cc index 82df127a..624ce495 100644 --- a/shardy/dialect/sdy/ir/parsers.cc +++ b/shardy/dialect/sdy/ir/parsers.cc @@ -289,8 +289,6 @@ ParseResult parseFactorSizes(AsmParser& parser, return success(); } -namespace { - // Parses factor sizes. In a OpShardingRule, you could have `, type={k, i}`. 
// `k` is index 2, while `i` is index 0. Thus factors would be set to [2, 0]. ParseResult parseFactorsWithType(AsmParser& parser, @@ -320,19 +318,6 @@ ParseResult parseFactorsWithType(AsmParser& parser, return success(); } -} // namespace - -ParseResult parseReductionFactors(AsmParser& parser, - SmallVector& reductionFactors) { - return parseFactorsWithType(parser, reductionFactors, "reduction"); -} - -ParseResult parseNeedReplicationFactors( - AsmParser& parser, SmallVector& needReplicationFactors) { - return parseFactorsWithType(parser, needReplicationFactors, - "need_replication"); -} - ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule) { isCustomRule = false; if (!parser.parseOptionalComma()) { diff --git a/shardy/dialect/sdy/ir/parsers.h b/shardy/dialect/sdy/ir/parsers.h index 24581cc2..fccc8495 100644 --- a/shardy/dialect/sdy/ir/parsers.h +++ b/shardy/dialect/sdy/ir/parsers.h @@ -38,15 +38,10 @@ ParseResult parseMeshOrRef(AsmParser& parser, Attribute& meshOrRef); ParseResult parseFactorSizes(AsmParser& parser, SmallVector& factorSizes); -// Parses the reduction factors of an OpShardingRule. We expect to parse -// `reduction={i, k}` into a vector [0, 2]. -ParseResult parseReductionFactors(AsmParser& parser, - SmallVector& reductionFactors); - -// Parses the factors needing replication of an OpShardingRule. We expect to -// parse `need_replication={i, k}` into a vector [0, 2]. -ParseResult parseNeedReplicationFactors( - AsmParser& parser, SmallVector& needReplicationFactors); +// Parses a list of `factors` of `type` in an OpShardingRule. We expect to parse +// `type={i, k}` into a vector [0, 2]. 
+ParseResult parseFactorsWithType(AsmParser& parser, + SmallVector& factors, StringRef type); ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule); diff --git a/shardy/dialect/sdy/ir/printers.cc b/shardy/dialect/sdy/ir/printers.cc index 58f1668b..538b29e0 100644 --- a/shardy/dialect/sdy/ir/printers.cc +++ b/shardy/dialect/sdy/ir/printers.cc @@ -74,8 +74,6 @@ void printFactorSizes(AsmPrinter& printer, ArrayRef factorSizes) { printer << "}"; } -namespace { - void printFactorsWithType(AsmPrinter& printer, ArrayRef factors, StringRef type) { if (factors.empty()) { @@ -88,19 +86,6 @@ void printFactorsWithType(AsmPrinter& printer, ArrayRef factors, printer << "}"; } -} // namespace - -void printReductionFactors(AsmPrinter& printer, - ArrayRef reductionFactors) { - return printFactorsWithType(printer, reductionFactors, "reduction"); -} - -void printNeedReplicationFactors(AsmPrinter& printer, - ArrayRef needReplicationFactors) { - return printFactorsWithType(printer, needReplicationFactors, - "need_replication"); -} - void printIsCustomRule(AsmPrinter& printer, bool isCustomRule) { if (isCustomRule) { printer << ", custom"; diff --git a/shardy/dialect/sdy/ir/printers.h b/shardy/dialect/sdy/ir/printers.h index 6a390e21..62af9301 100644 --- a/shardy/dialect/sdy/ir/printers.h +++ b/shardy/dialect/sdy/ir/printers.h @@ -38,15 +38,10 @@ void printMeshOrRef(AsmPrinter& printer, Attribute meshOrRef); // printed as `{i=6, j=2, k=4}`. void printFactorSizes(AsmPrinter& printer, ArrayRef factorSizes); -// Prints the reduction factors of an OpShardingRule. Given a vector [0, 2], we -// print `reduction={i, k}`. -void printReductionFactors(AsmPrinter& printer, - ArrayRef reductionFactors); - -// Prints the factors needing replication of an OpShardingRule. Given a vector -// [0, 2], we print `need_replication={i, k}`. -void printNeedReplicationFactors(AsmPrinter& printer, - ArrayRef needReplicationFactors); +// Prints the `factors` of `type` in an OpShardingRule. 
Given a vector [0, 2], +// we print `type={i, k}`. +void printFactorsWithType(AsmPrinter& printer, ArrayRef factors, + StringRef type); void printIsCustomRule(AsmPrinter& printer, bool isCustomRule); diff --git a/shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir b/shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir index 7b4b9410..bc7dbf77 100644 --- a/shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir +++ b/shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir @@ -47,9 +47,9 @@ func.func @custom_call_custom_rule(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32 func.return %0: tensor<16x32xf32> } -// CHECK-LABEL: func @reduction_and_need_replication_factors -func.func @reduction_and_need_replication_factors(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> { - // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>} - %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> - func.return %0: tensor<2x5x7xf32> +// CHECK-LABEL: func @special_factors +func.func @special_factors(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x11x7xf32> { + // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l} permutation={k}, custom>} + %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l} permutation={k}, custom>} : (tensor<2x3x5x7xf32>) -> tensor<2x11x7xf32> + func.return %0: tensor<2x11x7xf32> } diff --git a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir index 1efb7654..5cfac632 100644 --- 
a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir +++ b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir @@ -136,8 +136,16 @@ func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8x // ----- -func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> { - // expected-error@+1 {{reduction and need_replication factors must be disjoint}} +func.func @a_factor_in_two_special_factor_sets(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> { + // expected-error@+1 {{a factor can only be in one of the special factor sets}} %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} reduction={j} need_replication={j}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32> func.return %0: tensor<2x8xf32> } + +// ----- + +func.func @a_factor_in_three_special_factor_sets(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> { + // expected-error@+1 {{a factor can only be in one of the special factor sets}} + %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} reduction={j} need_replication={j} permutation={j}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32> + func.return %0: tensor<2x8xf32> +} diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc index b6da0951..86867ae9 100644 --- a/shardy/dialect/sdy/ir/verifiers.cc +++ b/shardy/dialect/sdy/ir/verifiers.cc @@ -610,6 +610,7 @@ LogicalResult verifyOpShardingRuleAttr(OpShardingRuleAttr shardingRule, ArrayRef reductionFactors = shardingRule.getReductionFactors(); ArrayRef needReplicationFactors = shardingRule.getNeedReplicationFactors(); + ArrayRef permutationFactors = shardingRule.getPermutationFactors(); if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(), reductionFactors))) { @@ -619,15 +620,21 @@ LogicalResult verifyOpShardingRuleAttr(OpShardingRuleAttr shardingRule, needReplicationFactors))) { return 
failure(); } + if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(), + permutationFactors))) { + return failure(); + } - SmallVector intersection; - std::set_intersection(reductionFactors.begin(), reductionFactors.end(), - needReplicationFactors.begin(), - needReplicationFactors.end(), - std::back_inserter(intersection)); - if (!intersection.empty()) { + SmallDenseSet specialFactors; + specialFactors.insert(reductionFactors.begin(), reductionFactors.end()); + specialFactors.insert(needReplicationFactors.begin(), + needReplicationFactors.end()); + specialFactors.insert(permutationFactors.begin(), permutationFactors.end()); + if (specialFactors.size() != reductionFactors.size() + + needReplicationFactors.size() + + permutationFactors.size()) { return op->emitOpError( - "reduction and need_replication factors must be disjoint"); + "a factor can only be in one of the special factor sets"); } return success(); diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc index 47f0dcb4..0ca697dc 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc +++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc @@ -117,7 +117,7 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() { auto result = OpShardingRuleAttr::get( context, factorSizes, operandMappingAttrs, resultMappingAttrs, - reductionFactors, needReplicationFactors); + reductionFactors, needReplicationFactors, permutationFactors); // Erase all added factors, to return the builder to its original state before // calling this method. 
@@ -153,6 +153,9 @@ void OpShardingRuleBuilder::updateFactorType(FactorType factorType, case FactorType::kNeedReplication: needReplicationFactors.push_back(factorIndex); return; + case FactorType::kPermutation: + permutationFactors.push_back(factorIndex); + return; case FactorType::kDefault: return; } diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h index caae661d..3bd9e6c8 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h +++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h @@ -50,6 +50,10 @@ enum class FactorType { // If we have sharding along a dimension that needs replication, the // partitioner will make this dimension replicated. kNeedReplication, + + // If we have sharding along a dimension that needs permutation, the + // partitioner will add collective-permute operations. + kPermutation, }; // The factor mappings that compose a dimension of a tensor. 
@@ -139,8 +143,8 @@ class OpShardingRuleBuilder { SmallVector operandMappings; SmallVector resultMappings; - SmallVector reductionFactors; - SmallVector needReplicationFactors; + SmallVector reductionFactors, needReplicationFactors, + permutationFactors; }; // Creates an identity mapping for an op with `numOperands` operands and diff --git a/shardy/integrations/c/attributes.cc b/shardy/integrations/c/attributes.cc index da256d97..105c8621 100644 --- a/shardy/integrations/c/attributes.cc +++ b/shardy/integrations/c/attributes.cc @@ -315,6 +315,7 @@ MlirAttribute sdyOpShardingRuleAttrGet( intptr_t nResultMappings, const MlirAttribute* resultMappings, intptr_t nReductionFactors, const int64_t* reductionFactors, intptr_t nNeedReplicationFactors, const int64_t* needReplicationFactors, + intptr_t nPermutationFactors, const int64_t* permutationFactors, bool isCustomRule) { return wrap(sdy::OpShardingRuleAttr::get( unwrap(ctx), mlir::ArrayRef(factorSizes, nFactorSizes), @@ -322,7 +323,7 @@ MlirAttribute sdyOpShardingRuleAttrGet( unwrapAttrs(resultMappings, nResultMappings), mlir::ArrayRef(reductionFactors, nReductionFactors), mlir::ArrayRef(needReplicationFactors, nNeedReplicationFactors), - isCustomRule)); + mlir::ArrayRef(permutationFactors, nPermutationFactors), isCustomRule)); } bool sdyOpShardingRuleAttrGetIsCustom(MlirAttribute attr) { @@ -380,6 +381,17 @@ int64_t sdyOpShardingRuleAttrGetNeedReplicationFactorsElem(MlirAttribute attr, .getNeedReplicationFactors()[pos]; } +intptr_t sdyOpShardingRuleAttrGetPermutationFactorsSize(MlirAttribute attr) { + return unwrapAttr(attr) + .getPermutationFactors() + .size(); +} + +int64_t sdyOpShardingRuleAttrGetPermutationFactorsElem(MlirAttribute attr, + intptr_t pos) { + return unwrapAttr(attr).getPermutationFactors()[pos]; +} + //===----------------------------------------------------------------------===// // ManualAxesAttr //===----------------------------------------------------------------------===// diff --git 
a/shardy/integrations/c/attributes.h b/shardy/integrations/c/attributes.h index 6a002cd2..d27a7e02 100644 --- a/shardy/integrations/c/attributes.h +++ b/shardy/integrations/c/attributes.h @@ -202,6 +202,7 @@ MLIR_CAPI_EXPORTED MlirAttribute sdyOpShardingRuleAttrGet( intptr_t nResultMappings, const MlirAttribute* resultMappings, intptr_t nReductionFactors, const int64_t* reductionFactors, intptr_t nNeedReplicationFactors, const int64_t* needReplicationFactors, + intptr_t nPermutationFactors, const int64_t* permutationFactors, bool isCustomRule); MLIR_CAPI_EXPORTED bool sdyOpShardingRuleAttrGetIsCustom(MlirAttribute attr); @@ -236,6 +237,12 @@ sdyOpShardingRuleAttrGetNeedReplicationFactorsSize(MlirAttribute attr); MLIR_CAPI_EXPORTED int64_t sdyOpShardingRuleAttrGetNeedReplicationFactorsElem( MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +sdyOpShardingRuleAttrGetPermutationFactorsSize(MlirAttribute attr); + +MLIR_CAPI_EXPORTED int64_t sdyOpShardingRuleAttrGetPermutationFactorsElem( + MlirAttribute attr, intptr_t pos); + //===----------------------------------------------------------------------===// // ManualAxesAttr //===----------------------------------------------------------------------===// diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc index 0d1b44c5..a0ddbb76 100644 --- a/shardy/integrations/python/ir/sdy_module.cc +++ b/shardy/integrations/python/ir/sdy_module.cc @@ -311,7 +311,8 @@ NB_MODULE(_sdy, m) { const std::vector& operandMappings, const std::vector& resultMappings, const std::vector& reductionFactors, - const std::vector& needReplicationFactors, bool isCustom, + const std::vector& needReplicationFactors, + const std::vector& permutationFactors, bool isCustom, MlirContext ctx) { return cls(sdyOpShardingRuleAttrGet( ctx, factorSizes.size(), factorSizes.data(), @@ -319,12 +320,14 @@ NB_MODULE(_sdy, m) { resultMappings.size(), resultMappings.data(), reductionFactors.size(), 
reductionFactors.data(), needReplicationFactors.size(), needReplicationFactors.data(), + permutationFactors.size(), permutationFactors.data(), isCustom)); }, nb::arg("cls"), nb::arg("factor_sizes"), nb::arg("operand_mappings"), nb::arg("result_mappings"), nb::arg("reduction_factors") = std::vector(), nb::arg("need_replication_factors") = std::vector(), + nb::arg("permutation_factors") = std::vector(), nb::arg("is_custom") = false, nb::arg("context").none() = nb::none(), "Creates a OpShardingRuleAttr with the factor sizes and mappings for " "operands and results.") @@ -361,11 +364,17 @@ NB_MODULE(_sdy, m) { sdyOpShardingRuleAttrGetReductionFactorsElem); }) .def_property_readonly( - "need_replication_factors", [](MlirAttribute self) { + "need_replication_factors", + [](MlirAttribute self) { return propertyVector( self, sdyOpShardingRuleAttrGetNeedReplicationFactorsSize, sdyOpShardingRuleAttrGetNeedReplicationFactorsElem); - }); + }) + .def_property_readonly("permutation_factors", [](MlirAttribute self) { + return propertyVector( + self, sdyOpShardingRuleAttrGetPermutationFactorsSize, + sdyOpShardingRuleAttrGetPermutationFactorsElem); + }); mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ManualAxesAttr", sdyAttributeIsAManualAxesAttr)