#sdy cleanup sdy.all_reduce
PiperOrigin-RevId: 721093677
Google-ML-Automation authored and copybara-github committed Jan 29, 2025
1 parent 83f5d87 commit ca1ff47
Showing 4 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions shardy/dialect/sdy/ir/attrs.td
@@ -796,8 +796,8 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {

     // Returns true if axes of all dimensions are the same.
     bool areDimAxesEqual(TensorShardingAttr otherSharding) const {
-      auto left = getDimShardings();
-      auto right = otherSharding.getDimShardings();
+      ArrayRef<DimensionShardingAttr> left = getDimShardings();
+      ArrayRef<DimensionShardingAttr> right = otherSharding.getDimShardings();
       return left.size() == right.size() &&
              llvm::all_of(llvm::zip_equal(left, right),
                           [](auto&& pair) {
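Replacing auto with the explicit ArrayRef<DimensionShardingAttr> spelling documents the element type at the call site without changing behavior. Note that the leading left.size() == right.size() check is what keeps the llvm::zip_equal call safe, since zip_equal requires ranges of equal length. A minimal standalone sketch of the same size-guard-then-elementwise pattern, using plain standard-library stand-ins (the DimSharding alias below is an illustrative assumption, not Shardy's type):

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Illustrative stand-in for DimensionShardingAttr: the axis names of one dim.
using DimSharding = std::vector<std::string>;

// Mirrors areDimAxesEqual: same rank, and the same axes in every dimension.
bool areDimAxesEqual(const std::vector<DimSharding>& left,
                     const std::vector<DimSharding>& right) {
  // Guard first: zipping ranges of unequal length would be an error, which
  // is why the original short-circuits on size before llvm::zip_equal.
  if (left.size() != right.size()) return false;
  for (std::size_t i = 0; i < left.size(); ++i) {
    if (left[i] != right[i]) return false;
  }
  return true;
}
```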
8 changes: 4 additions & 4 deletions shardy/dialect/sdy/ir/test/collective_parse_print.mlir
@@ -227,8 +227,8 @@ func.func @all_reduce_many_axes(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.shar
   return %0 : tensor<16x2xf32>
 }

-// CHECK-LABEL: func @all_reduce_split_axis
-func.func @all_reduce_split_axis(%arg0 : tensor<16x32xf32> {sdy.sharding=#sdy.sharding<@mesh7, [{"y"}, {"x": (2)4}]>}) -> tensor<16x32xf32> {
+// CHECK-LABEL: func @all_reduce_sub_axis
+func.func @all_reduce_sub_axis(%arg0 : tensor<16x32xf32> {sdy.sharding=#sdy.sharding<@mesh7, [{"y"}, {"x": (2)4}]>}) -> tensor<16x32xf32> {
   // CHECK-NEXT: sdy.all_reduce {"x":(1)2} %arg0 out_sharding=<@mesh7, [{"y"}, {"x":(2)4}]> : tensor<16x32xf32>
   %0 = sdy.all_reduce {"x":(1)2} %arg0 out_sharding=<@mesh7, [{"y"}, {"x":(2)4}]> : tensor<16x32xf32>
   return %0 : tensor<16x32xf32>
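The rename reflects Shardy's terminology: a reference like "x":(2)4 in the shardings above is a sub-axis of "x" with pre-size 2 and size 4, not a "split axis". As a rough standalone sketch, here is how a device's coordinate along such a sub-axis can be derived, assuming a major-to-minor factorization of the full axis and assuming @mesh7 (not shown in this excerpt) defines "x"=8:

```cpp
#include <cstdio>

// Coordinate along sub-axis "x":(preSize)size for a device at coordinate i
// along the full axis "x" of size fullSize. The full axis is assumed to
// factor major-to-minor as preSize x size x suffix,
// with suffix = fullSize / (preSize * size).
int subAxisCoord(int i, int fullSize, int preSize, int size) {
  int suffix = fullSize / (preSize * size);
  return (i / suffix) % size;
}

int main() {
  // With "x"=8: "x":(1)2 is the major factor of 2, "x":(2)4 the next 4.
  for (int i = 0; i < 8; ++i) {
    std::printf("i=%d  x:(1)2 -> %d  x:(2)4 -> %d\n", i,
                subAxisCoord(i, 8, 1, 2), subAxisCoord(i, 8, 2, 4));
  }
  return 0;
}
```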
@@ -241,8 +241,8 @@ func.func @all_reduce_split_axis_y(%arg0 : tensor<16x32xf32> {sdy.
   return %0 : tensor<16x32xf32>
 }

-// CHECK-LABEL: func @all_reduce_output_is_explicitely_replicated
-func.func @all_reduce_output_is_explicitely_replicated(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh2, [{}, {"x", "y"}]>}) -> tensor<16x2xf32> {
+// CHECK-LABEL: func @all_reduce_output_is_explicitly_replicated
+func.func @all_reduce_output_is_explicitly_replicated(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh2, [{}, {"x", "y"}]>}) -> tensor<16x2xf32> {
   // CHECK-NEXT: sdy.all_reduce {} %arg0 out_sharding=<@mesh2, [{}, {"x", "y"}], replicated={"z"}> : tensor<16x2xf32>
   %0 = sdy.all_reduce {} %arg0 out_sharding=<@mesh2, [{}, {"x", "y"}], replicated={"z"}> : tensor<16x2xf32>
   return %0 : tensor<16x2xf32>
14 changes: 7 additions & 7 deletions shardy/dialect/sdy/ir/test/collective_verification.mlir
@@ -472,7 +472,7 @@ func.func @all_reduce_mismatch_output_mesh(%arg0 : tensor<16x2xf32> {sdy.shardin
 sdy.mesh @mesh= <["x"=2, "y"=8, "z"=2]>

 func.func @all_reduce_overlapping_part_axis(%arg0 : tensor<16x32xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x32xf32> {
-  // expected-error@+1 {{'sdy.all_reduce' op reduction axis overlaps with operand sharding: "y":(2)2}}
+  // expected-error@+1 {{'sdy.all_reduce' op reduction axis "y":(2)2 overlaps with operand sharding}}
   %0 = sdy.all_reduce {"y":(2)2} %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x32xf32>
   return %0 : tensor<16x32xf32>
 }
@@ -481,15 +481,15 @@ func.func @all_reduce_overlapping_part_axis(%arg0 : tensor<16x32xf32> {sdy.shard
 sdy.mesh @mesh= <["x"=2, "y"=8, "y2"=8, "z"=2]>

 func.func @all_reduce_overlapping_axis_minor(%arg0 : tensor<16x32xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y", "y2"}, {"x"}]>}) -> tensor<16x32xf32> {
-  // expected-error@+1 {{'sdy.all_reduce' op reduction axis overlaps with operand sharding: "y2"}}
+  // expected-error@+1 {{'sdy.all_reduce' op reduction axis "y2" overlaps with operand sharding}}
   %0 = sdy.all_reduce {"y2"} %arg0 out_sharding=<@mesh, [{"y", "y2"}, {"x"}]> : tensor<16x32xf32>
   return %0 : tensor<16x32xf32>
 }
 // -----
 sdy.mesh @mesh= <["x"=2, "y"=8, "y2"=8, "z"=2]>

 func.func @all_reduce_overlapping_axis_major(%arg0 : tensor<16x32xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y", "y2"}, {"x"}]>}) -> tensor<16x32xf32> {
-  // expected-error@+1 {{'sdy.all_reduce' op reduction axis overlaps with operand sharding: "y"}}
+  // expected-error@+1 {{'sdy.all_reduce' op reduction axis "y" overlaps with operand sharding}}
   %0 = sdy.all_reduce {"y"} %arg0 out_sharding=<@mesh, [{"y", "y2"}, {"x"}]> : tensor<16x32xf32>
   return %0 : tensor<16x32xf32>
 }
@@ -519,8 +519,8 @@ func.func @all_reduce_on_operand_without_sharding(%arg0 : tensor<16x2xf32>) -> t
 sdy.mesh @mesh = <["x"=4, "y"=2]>

 func.func @all_reduce_reduction_axes_can_be_merged(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x2xf32> {
-  // expected-error @+1 {{'sdy.all_reduce' op operand and result sharding have different axes}}
-  %0 = sdy.all_reduce {"x":(1)2, "x":(2)2} %arg0 out_sharding=<@mesh, [{"y"}, {}]> : tensor<16x2xf32>
+  // expected-error @+1 {{'sdy.all_reduce' op two consecutive sub-axes can be merged: "x":(1)2, "x":(2)2}}
+  %0 = sdy.all_reduce {"x":(1)2, "x":(2)2} %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x2xf32>
   return %0 : tensor<16x2xf32>
 }

@@ -529,8 +529,8 @@ func.func @all_reduce_reduction_axes_can_be_merged(%arg0 : tensor<16x2xf32> {sdy
 sdy.mesh @mesh = <["x"=2, "y"=2]>

 func.func @all_reduce_duplicate_reduction_axes_across_dims(%arg0 : tensor<16x2xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<16x2xf32> {
-  // expected-error @+1 {{'sdy.all_reduce' op operand and result sharding have different axes}}
-  %0 = sdy.all_reduce {"x", "x"} %arg0 out_sharding=<@mesh, [{}, {}]> : tensor<16x2xf32>
+  // expected-error @+1 {{'sdy.all_reduce' op duplicate axis ref: "x"}}
+  %0 = sdy.all_reduce {"x", "x"} %arg0 out_sharding=<@mesh, [{"y"}, {"x"}]> : tensor<16x2xf32>
   return %0 : tensor<16x2xf32>
 }

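The two rewritten cases above now exercise diagnostics that fire before the generic axis-mismatch check: adjacent sub-axes that should be written as one, and the same axis listed twice. A standalone sketch of what those two checks can look like; the AxisRef struct and its field names are illustrative assumptions, not Shardy's API:

```cpp
#include <string>

// Illustrative (sub-)axis reference "name":(preSize)size; a full axis of
// size n is the sub-axis with preSize == 1 and size == n.
struct AxisRef {
  std::string name;
  int preSize;  // product of the factor sizes to the major side
  int size;     // size of this sub-axis
};

// "x":(1)2 then "x":(2)2 cover adjacent factor intervals [1,2) and [2,4)
// of "x", so the pair should be written as the single sub-axis "x":(1)4,
// i.e. the full axis "x" in a mesh where "x"=4.
bool canMerge(const AxisRef& a, const AxisRef& b) {
  return a.name == b.name && a.preSize * a.size == b.preSize;
}

// {"x", "x"} lists the exact same reference twice.
bool isDuplicate(const AxisRef& a, const AxisRef& b) {
  return a.name == b.name && a.preSize == b.preSize && a.size == b.size;
}
```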
5 changes: 3 additions & 2 deletions shardy/dialect/sdy/ir/verifiers.cc
@@ -1383,8 +1383,9 @@ LogicalResult AllReduceOp::verify() {
     if (operandSharding.anyOfAxisRef([reductionAxisRef](AxisRefAttr axisRef) {
           return axisRef.overlaps(reductionAxisRef);
         })) {
-      return emitOpError("reduction axis overlaps with operand sharding: ")
-             << reductionAxisRef.toString();
+      return emitOpError("reduction axis ")
+             << reductionAxisRef.toString()
+             << " overlaps with operand sharding";
     }
   }

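For the diagnostic above, two axis references conflict when they name the same mesh axis and their factor ranges intersect. A standalone sketch of such an overlap test, under the same illustrative sub-axis encoding as the earlier sketches (not Shardy's actual AxisRefAttr::overlaps implementation):

```cpp
#include <algorithm>
#include <string>

// Illustrative (sub-)axis reference "name":(preSize)size; a full axis of
// size n is the sub-axis with preSize == 1 and size == n.
struct AxisRef {
  std::string name;
  int preSize;
  int size;
};

// Two references overlap iff they name the same axis and their factor
// intervals [preSize, preSize*size) intersect. E.g. with "y"=8, reducing
// "y":(2)2 (interval [2,4)) overlaps an operand sharded on the full "y"
// (interval [1,8)), which is the case the first verification test rejects.
bool overlaps(const AxisRef& a, const AxisRef& b) {
  if (a.name != b.name) return false;
  return std::max(a.preSize, b.preSize) <
         std::min(a.preSize * a.size, b.preSize * b.size);
}
```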