Skip to content

Commit

Permalink
#sdy #debug refactor how we figure out the new axes that were introdu…
Browse files Browse the repository at this point in the history
…ced on a value if a reshape merged split axes into a single merged one.

PiperOrigin-RevId: 716150387
  • Loading branch information
bartchr808 authored and copybara-github committed Feb 10, 2025
1 parent 1280d8e commit bddd47e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <optional>
#include <string>

Expand Down Expand Up @@ -108,68 +107,49 @@ FactorsToEdgeMap createSourceMap(
}
};

MLIRContext* context = mesh.getContext();
ArrayRef<int64_t> factorSizes = shardingRule.getFactorSizes();
auto visitValue = [&](const TensorFactorShardings& oldValue,
const TensorFactorShardings& newValue,
EdgeNodeType valueType, int64_t valueIndex,
TensorMappingAttr tensorMapping,
llvm::SmallVector<AxisToEdgeMap>& valueSourceMap) {
DenseSet<AxisRefAttr> oldAxes;
for (const auto& [_, oldFactorSharding] : oldValue.factorIndexToSharding) {
oldAxes.insert(oldFactorSharding.axisRefs.begin(),
oldFactorSharding.axisRefs.end());
}
for (const auto& [oldFactorSharding, newFactorSharding] : llvm::zip_equal(
oldValue.factorIndexToSharding, newValue.factorIndexToSharding)) {
const auto& newAxisRefs = newFactorSharding.second.axisRefs;
if (newAxisRefs == oldFactorSharding.second.axisRefs) {
continue;
}
SmallVector<AxisRefAttr> newlyIntroducedAxes;
// If multiple sub axes can be merged due to a dimension sharding having
// multiple factors, each sharded on a sub axis, make sure we only save
// the merged one. This can happen during an `(A, B) -> (AB,)` reshape.
TensorShardingAttr tensorSharding = newValue.createTensorShardingAttr(
context, tensorMapping, factorSizes, "", mesh);
for (DimensionShardingAttr dimSharding :
tensorSharding.getDimShardings()) {
llvm::copy_if(
dimSharding.getAxes(), std::back_inserter(newlyIntroducedAxes),
[&](const AxisRefAttr& axisRef) {
// Don't add any axes that were already in the
// old sharding. We just want new axes.
if (oldAxes.contains(axisRef)) {
return false;
}
// We need to avoid any axes that already existed in the old
// sharding, but aren't in the new projection as they conflicted.
// E.g. for a contracting dim matmul, if both the LHS/RHS are
// sharded on the same axis on their respective non-contracting
// dims, the dimension sharding will contain the conflicting axes,
// but the factor sharding will not. And we don't want this axis
// as it isn't a newly introduced axis.
for (AxisRefAttr newAxisRef : newAxisRefs) {
if (newAxisRef.prefixOf(axisRef)) {
return true;
}
}
return false;
});
}
// This factor sharding has changed, let's find who changed it.
if (std::optional<int64_t> operandSource =
findNewAxisRefMatch(newAxisRefs, oldFactorSharding.first,
oldShardingProjection.getOperands())) {
saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
EdgeNode{EdgeNodeType::OPERAND, *operandSource},
EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
} else if (std::optional<int64_t> resultSource =
findNewAxisRefMatch(newAxisRefs, oldFactorSharding.first,
oldShardingProjection.getResults())) {
saveEdges(newlyIntroducedAxes, oldFactorSharding.second.axisRefs,
EdgeNode{EdgeNodeType::RESULT, *resultSource},
EdgeNode{valueType, valueIndex}, valueSourceMap[valueIndex]);
for (DimMappingAttr dimMapping : tensorMapping.getDimMappings()) {
AxisRefAttr previousAxis;
for (int64_t factorIndex : dimMapping.getFactorIndices()) {
const FactorSharding& oldFactorSharding =
oldValue.factorIndexToSharding.at(factorIndex);
const FactorSharding& newFactorSharding =
newValue.factorIndexToSharding.at(factorIndex);
if (oldFactorSharding.axisRefs == newFactorSharding.axisRefs) {
// No new axes introduced.
continue;
}
// This factor sharding has changed, let's find who changed it.
//
// But first merge any sub-axes.
AxisRefAttr lastNewAxis = newFactorSharding.axisRefs.back();
if (previousAxis && previousAxis.canMerge(lastNewAxis)) {
valueSourceMap[valueIndex].erase(previousAxis);
lastNewAxis = previousAxis.merge(lastNewAxis, mesh);
}
previousAxis = newFactorSharding.axisRefs.back();
SmallVector<AxisRefAttr> newlyIntroducedAxes =
newFactorSharding.axisRefs;
newlyIntroducedAxes.back() = lastNewAxis;
if (std::optional<int64_t> operandSource =
findNewAxisRefMatch(newFactorSharding.axisRefs, factorIndex,
oldShardingProjection.getOperands())) {
saveEdges(newlyIntroducedAxes, oldFactorSharding.axisRefs,
EdgeNode{EdgeNodeType::OPERAND, *operandSource},
EdgeNode{valueType, valueIndex},
valueSourceMap[valueIndex]);
} else if (std::optional<int64_t> resultSource = findNewAxisRefMatch(
newFactorSharding.axisRefs, factorIndex,
oldShardingProjection.getResults())) {
saveEdges(newlyIntroducedAxes, oldFactorSharding.axisRefs,
EdgeNode{EdgeNodeType::RESULT, *resultSource},
EdgeNode{valueType, valueIndex},
valueSourceMap[valueIndex]);
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,23 @@ func.func @sub_axes_merge_after_propagation_step(
%0 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
return %0 : tensor<16xf32>
}

// Reshape of a tensor<16xf32> argument into a tensor<4x4xf32> result, where
// the argument is sharded on axes "a" and "c" and the result already carries
// the sub-axis "c":(2)2 on its second dimension before propagation.
//
// The expectations below require that, after propagation, the result (and the
// reshape op itself) is sharded as [{"a", "c":(1)2}, {"c":(2)4}], and that the
// recorded origins list "a", "c:(1)2" and "c:(2)4" as coming from input 0
// while the pre-existing "c:(2)2" stays attributed to the result itself
// ("self" on the func result, "output: 0" on the reshape op).
// NOTE(review): "c:(2)4" presumably is the pre-existing "c:(2)2" merged with
// a further sub-axis propagated from the input -- confirm against the mesh
// definition at the top of this test file, which is outside this view.
// CHECK-LABEL: already_split_sub_axis_result_reshape
// CHECK-SAME: %arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c", ?}]>,
// CHECK-SAME: sdy.sharding_origins = {a = "self", c = "self"}}
// CHECK-SAME: ) -> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c":(1)2, ?}, {"c":(2)4, ?}]>,
// CHECK-SAME: sdy.sharding_origins = {a = "input: 0",
// CHECK-SAME: "c:(1)2" = "input: 0",
// CHECK-SAME: "c:(2)2" = "self",
// CHECK-SAME: "c:(2)4" = "input: 0"}})
func.func @already_split_sub_axis_result_reshape(
%arg0: tensor<16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "c", ?}]>})
-> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"c":(2)2, ?}]>}) {
// CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 {
// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c":(1)2, ?}, {"c":(2)4, ?}]>]>,
// CHECK-SAME: sdy.sharding_origins = [{a = "input: 0", "c:(1)2" = "input: 0",
// CHECK-SAME: "c:(2)2" = "output: 0", "c:(2)4" = "input: 0"}]
// CHECK-SAME: } : (tensor<16xf32>) -> tensor<4x4xf32>
%0 = stablehlo.reshape %arg0 : (tensor<16xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}

0 comments on commit bddd47e

Please sign in to comment.