From bddd47e1e1ecd307bc53cf1a748b7c9a8b51d470 Mon Sep 17 00:00:00 2001
From: Bart Chrzaszcz
Date: Thu, 16 Jan 2025 02:33:55 -0800
Subject: [PATCH] #sdy #debug refactor how we figure out the new axes that were
 introduced on a value if a reshape merged split axes into a single merged
 one.

PiperOrigin-RevId: 716150387
---
 .../propagation/debugging/source_sharding.cc  | 96 ++++++++-----------
 .../debugging/test/sharding_origins.mlir      | 20 ++++
 2 files changed, 58 insertions(+), 58 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
index 6461e0fc..2519953d 100644
--- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include
 #include
 #include
-#include
 #include
 #include
 
@@ -108,68 +107,49 @@ FactorsToEdgeMap createSourceMap(
     }
   };
 
-  MLIRContext* context = mesh.getContext();
-  ArrayRef<int64_t> factorSizes = shardingRule.getFactorSizes();
   auto visitValue = [&](const TensorFactorShardings& oldValue,
                         const TensorFactorShardings& newValue,
                         EdgeNodeType valueType, int64_t valueIndex,
                         TensorMappingAttr tensorMapping,
                         llvm::SmallVector& valueSourceMap) {
-    DenseSet<AxisRefAttr> oldAxes;
-    for (const auto& [_, oldFactorSharding] : oldValue.factorIndexToSharding) {
-      oldAxes.insert(oldFactorSharding.axisRefs.begin(),
-                     oldFactorSharding.axisRefs.end());
-    }
-    for (const auto& [oldFactorSharding, newFactorSharding] : llvm::zip_equal(
-             oldValue.factorIndexToSharding, newValue.factorIndexToSharding)) {
-      const auto& newAxisRefs = newFactorSharding.second.axisRefs;
-      if (newAxisRefs == oldFactorSharding.second.axisRefs) {
-        continue;
-      }
-      SmallVector<AxisRefAttr> newlyIntroducedAxes;
-      // If multiple sub axes can be merged due to a dimension sharding having
-      // multiple factors, each sharded on a sub axis, make sure we only save
-      // the merged one. This can happen during an `(A, B) -> (AB,)` reshape.
-      TensorShardingAttr tensorSharding = newValue.createTensorShardingAttr(
-          context, tensorMapping, factorSizes, "", mesh);
-      for (DimensionShardingAttr dimSharding :
-           tensorSharding.getDimShardings()) {
-        llvm::copy_if(
-            dimSharding.getAxes(), std::back_inserter(newlyIntroducedAxes),
-            [&](const AxisRefAttr& axisRef) {
-              // Don't add any axes that were already in the
-              // old sharding. We just want new axes.
-              if (oldAxes.contains(axisRef)) {
-                return false;
-              }
-              // We need to avoid any axes that already existed in the old
-              // sharding, but aren't in the new projection as the conflicted.
-              // E.g. for a contracting dim matmul, if both the LHS/RHS are
-              // sharded on the same axis on their respective non-contracting
-              // dims, the dimension sharding will contain the conflicting axes,
-              // but the factor sharding will not. And we don't want this axis
-              // as it isn't a newly introduced axis.
-              for (AxisRefAttr newAxisRef : newAxisRefs) {
-                if (newAxisRef.prefixOf(axisRef)) {
-                  return true;
-                }
-              }
-              return false;
-            });
-      }
-      // This factor sharding has changed, let's find who changed it.
-      if (std::optional<int64_t> operandSource =
-              findNewAxisRefMatch(newAxisRefs, oldFactorSharding.first,
-                                  oldShardingProjection.getOperands())) {
-        saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
-                  EdgeNode{EdgeNodeType::OPERAND, *operandSource},
-                  EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
-      } else if (std::optional<int64_t> resultSource =
-                     findNewAxisRefMatch(newAxisRefs, oldFactorSharding.first,
-                                         oldShardingProjection.getResults())) {
-        saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
-                  EdgeNode{EdgeNodeType::RESULT, *resultSource},
-                  EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
+    for (DimMappingAttr dimMapping : tensorMapping.getDimMappings()) {
+      AxisRefAttr previousAxis;
+      for (int64_t factorIndex : dimMapping.getFactorIndices()) {
+        const FactorSharding& oldFactorSharding =
+            oldValue.factorIndexToSharding.at(factorIndex);
+        const FactorSharding& newFactorSharding =
+            newValue.factorIndexToSharding.at(factorIndex);
+        if (oldFactorSharding.axisRefs == newFactorSharding.axisRefs) {
+          // No new axes introduced.
+          continue;
+        }
+        // This factor sharding has changed, let's find who changed it.
+        //
+        // But first merge any sub-axes.
+        AxisRefAttr lastNewAxis = newFactorSharding.axisRefs.back();
+        if (previousAxis && previousAxis.canMerge(lastNewAxis)) {
+          valueSourceMap[valueIndex].erase(previousAxis);
+          lastNewAxis = previousAxis.merge(lastNewAxis, mesh);
+        }
+        previousAxis = newFactorSharding.axisRefs.back();
+        SmallVector<AxisRefAttr> newlyIntroducedAxes =
+            newFactorSharding.axisRefs;
+        newlyIntroducedAxes.back() = lastNewAxis;
+        if (std::optional<int64_t> operandSource =
+                findNewAxisRefMatch(newFactorSharding.axisRefs, factorIndex,
+                                    oldShardingProjection.getOperands())) {
+          saveEdges(newlyIntroducedAxes, oldFactorSharding.axisRefs,
+                    EdgeNode{EdgeNodeType::OPERAND, *operandSource},
+                    EdgeNode{valueType, valueIndex},
+                    valueSourceMap[valueIndex]);
+        } else if (std::optional<int64_t> resultSource = findNewAxisRefMatch(
+                       newFactorSharding.axisRefs, factorIndex,
+                       oldShardingProjection.getResults())) {
+          saveEdges(newlyIntroducedAxes, oldFactorSharding.axisRefs,
+                    EdgeNode{EdgeNodeType::RESULT, *resultSource},
+                    EdgeNode{valueType, valueIndex},
+                    valueSourceMap[valueIndex]);
+        }
       }
     }
   };
diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir b/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir
index 9935aed6..3c49ee73 100644
--- a/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/test/sharding_origins.mlir
@@ -421,3 +421,23 @@ func.func @sub_axes_merge_after_propagation_step(
   %0 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
   return %0 : tensor<16xf32>
 }
+
+// CHECK-LABEL: already_split_sub_axis_result_reshape
+// CHECK-SAME:    %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c", ?}]>,
+// CHECK-SAME:                           sdy.sharding_origins = {a = "self", c = "self"}}
+// CHECK-SAME:    ) -> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c":(1)2, ?}, {"c":(2)4, ?}]>,
+// CHECK-SAME:                           sdy.sharding_origins = {a = "input: 0",
+// CHECK-SAME:                                                   "c:(1)2" = "input: 0",
+// CHECK-SAME:                                                   "c:(2)2" = "self",
+// CHECK-SAME:                                                   "c:(2)4" = "input: 0"}})
+func.func @already_split_sub_axis_result_reshape(
+    %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c", ?}]>})
+    -> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(2)2, ?}]>}) {
+  // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 {
+  // CHECK-SAME:   sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c":(1)2, ?}, {"c":(2)4, ?}]>]>,
+  // CHECK-SAME:   sdy.sharding_origins = [{a = "input: 0", "c:(1)2" = "input: 0",
+  // CHECK-SAME:                            "c:(2)2" = "output: 0", "c:(2)4" = "input: 0"}]
+  // CHECK-SAME:   } : (tensor<16xf32>) -> tensor<4x4xf32>
+  %0 = stablehlo.reshape %arg0 : (tensor<16xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
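The sub-axis handling in the new loop (`previousAxis.canMerge(lastNewAxis)`, then `previousAxis.merge(lastNewAxis, mesh)` and erasing the old key from `valueSourceMap`) relies on Shardy's sub-axis notation `"axis":(preSize)size`: two sub-axes of the same full axis can merge exactly when the first one ends where the second one begins. Below is a minimal standalone sketch of that arithmetic only; `SubAxis`, `CanMerge`, and `Merge` are illustrative names, not the real `AxisRefAttr` API, and the closing remark assumes an axis "c" of size 8, which the excerpt does not show (@mesh is defined earlier in the test file).

#include <cassert>
#include <cstdint>
#include <iostream>

// A slice of a full mesh axis: "c":(preSize)size covers `size` elements of the
// axis, skipping a prefix whose sizes multiply to `preSize`.
struct SubAxis {
  int64_t preSize;
  int64_t size;
};

// Mergeable when `rhs` starts exactly where `lhs` ends.
bool CanMerge(const SubAxis& lhs, const SubAxis& rhs) {
  return lhs.preSize * lhs.size == rhs.preSize;
}

// The merged sub-axis keeps `lhs`'s offset and covers both slices.
SubAxis Merge(const SubAxis& lhs, const SubAxis& rhs) {
  assert(CanMerge(lhs, rhs));
  return SubAxis{lhs.preSize, lhs.size * rhs.size};
}

int main() {
  // The two factors of the reshaped dimension carry "c":(1)2 and "c":(2)4.
  SubAxis c_1_2{/*preSize=*/1, /*size=*/2};
  SubAxis c_2_4{/*preSize=*/2, /*size=*/4};
  if (CanMerge(c_1_2, c_2_4)) {
    SubAxis merged = Merge(c_1_2, c_2_4);
    std::cout << "merged: (" << merged.preSize << ")" << merged.size << "\n";
  }
  return 0;
}

This prints `merged: (1)8`, which for an axis "c" of size 8 is the full axis "c"; that is why the origin map on the 1-D operand ends up keyed on the merged "c" rather than on the two per-factor sub-axes.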