Skip to content

Commit

Permalink
Add interpreter for CollectiveBroadcastOp
Browse files (browse the repository at this point in the history)
This mirrors the upstream PR: openxla/stablehlo#1983

PiperOrigin-RevId: 604745617
  • Loading branch information
ghpvnist authored and copybara-github committed Feb 6, 2024
1 parent 8135795 commit f229220
Showing 1 changed file with 324 additions and 0 deletions.
324 changes: 324 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number | Diff line number | Diff line change
Expand Up @@ -163,6 +163,18 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt

#-------------------------------------------------------------------------------
# Directory setup
diff --ruN a/stablehlo/docs/status.md b/stablehlo/docs/status.md
--- stablehlo/docs/status.md
+++ stablehlo/docs/status.md
@@ -61,7 +61,7 @@
| ceil | yes | yes | yes | yes | yes |
| cholesky | yes | yes | yes | yes | revisit |
| clamp | yes | revisit | yes | yes | yes |
-| collective_broadcast | yes | revisit | yes | no | no |
+| collective_broadcast | yes | revisit | yes | no | yes |
| collective_permute | yes | revisit | yes | no | yes |
| compare | yes | yes | yes | yes | yes |
| complex | yes | yes | yes | yes | yes |
diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt
--- stablehlo/stablehlo/CMakeLists.txt
+++ stablehlo/stablehlo/CMakeLists.txt
Expand Down Expand Up @@ -2548,4 +2560,316 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/reference/Ops.cpp
--- stablehlo/stablehlo/reference/Ops.cpp
+++ stablehlo/stablehlo/reference/Ops.cpp
@@ -328,6 +328,25 @@
auto operand = scope.findTensor(clzOp.getOperand());
auto result = evalClzOp(operand, clzOp.getType());
scope.add(clzOp.getResult(), result);
+ } else if (auto collectiveBroadcastOp =
+ dyn_cast<CollectiveBroadcastOp>(op)) {
+ auto operand = scope.findTensor(collectiveBroadcastOp.getOperand());
+
+ auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups();
+ auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
+ SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
+ auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
+ for (auto &replicaGroup : replicaGroups)
+ for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
+ replicaGroup.push_back(*replicaGroupsIt);
+
+ ChannelId channelId = 0;
+ if (auto channelHandle = collectiveBroadcastOp.getChannelHandle())
+ channelId = channelHandle->getHandle();
+
+ auto result =
+ evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process);
+ scope.add(collectiveBroadcastOp.getResult(), result);
} else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) {
auto operand = scope.findTensor(collectivePermuteOp.getOperand());

@@ -1074,6 +1093,28 @@
return result;
}

+Tensor evalCollectiveBroadcastOp(  // Reference-interpreter evaluation of collective_broadcast for a single process.
+ const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,  // operand: this process's input; replicaGroups: rows of ids taken from the op's replica_groups attribute.
+ ChannelId channelId, Process *process) {  // channelId picks the grouping strategy below; process may be null.
+ if (!process)  // Without a Process handle there is nothing to exchange tensors with, so abort.
+ llvm::report_fatal_error(
+ "collective_broadcast is only supported when run via "
+ "interpreter.run_parallel");
+
+ ProcessGroups processGroups;
+ if (channelId <= 0) processGroups = process->crossReplica(replicaGroups);  // Non-positive channel id: cross-replica grouping.
+ if (channelId > 0) processGroups = process->crossPartition(replicaGroups);  // Positive channel id: cross-partition grouping.
+
+ auto processGroup = processGroups.findGroup(process->getId());  // The group that lists this process, if any.
+ if (processGroup)  // Group member: return the entry keyed by the first id listed in the group.
+ return process->rendezvous(*processGroup, channelId, operand)  // NOTE(review): presumably collects every member's operand; confirm in Process.
+ .lookup((*processGroup)[0]);
+
+ return evalBroadcastInDimOp(  // Not in any group: a 0-valued scalar of the operand's element type, broadcast to the operand's type.
+ makeScalar(convert(operand.getElementType(), 0.0)), {},
+ operand.getType());
+}
+
Tensor evalCollectivePermuteOp(
const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
ChannelId channelId, Process *process) {
diff --ruN a/stablehlo/stablehlo/reference/Ops.h b/stablehlo/stablehlo/reference/Ops.h
--- stablehlo/stablehlo/reference/Ops.h
+++ stablehlo/stablehlo/reference/Ops.h
@@ -62,6 +62,9 @@
Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max,
ShapedType resultType);
Tensor evalClzOp(const Tensor &operand, ShapedType resultType);
+Tensor evalCollectiveBroadcastOp(  // Evaluates collective_broadcast for one process; aborts when `process` is null (see Ops.cpp).
+ const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
+ ChannelId channelId, Process *process);
Tensor evalCollectivePermuteOp(
const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
ChannelId channelId, Process *process);
diff --ruN a/stablehlo/stablehlo/reference/ProcessGrid.cpp b/stablehlo/stablehlo/reference/ProcessGrid.cpp
--- stablehlo/stablehlo/reference/ProcessGrid.cpp
+++ stablehlo/stablehlo/reference/ProcessGrid.cpp
@@ -49,8 +49,8 @@

std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
for (auto processGroup : *this)
- for (auto id : processGroup)
- if (id == processId) return processGroup;
+ if (llvm::find(processGroup, processId) != processGroup.end())
+ return processGroup;

return std::nullopt;
}
diff --ruN a/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir b/stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
--- stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
+++ stablehlo/stablehlo/tests/interpret/collective_broadcast.mlir
@@ -0,0 +1,223 @@
+// RUN: stablehlo-translate --interpret -split-input-file %s
+
+module @cross_replica {  // Group [2, 1] with handle 0, run over a 4x1 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>  // handle <= 0 takes the crossReplica path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast], [@collective_broadcast],
+ [@collective_broadcast], [@collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 0 is not in the group: zeros.
+ check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>  // Gets %operand2's value; 2 is the first id in the group.
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>  // The source keeps its own value.
+ check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 3 is not in the group: zeros.
+ func.return
+ }
+}
+
+// -----
+
+module @cross_replica_multiple_output {  // Group [2, 1, 0] with handle 0, run over a 4x1 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>  // handle <= 0 takes the crossReplica path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast], [@collective_broadcast],
+ [@collective_broadcast], [@collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>  // Ids 0, 1 and 2 all get %operand2's value; 2 is the first id in the group.
+ check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 3 is not in the group: zeros.
+ func.return
+ }
+}
+
+// -----
+
+module @cross_replica_single_replica {  // Group [0] with handle 0, run over a 1x4 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[0]]> : tensor<1x1xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>  // handle <= 0 takes the crossReplica path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast, @collective_broadcast,
+ @collective_broadcast, @collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>  // Every result equals the matching operand: nothing crosses between programs.
+ check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+ func.return
+ }
+}
+
+// -----
+
+module @cross_replica_multiple_partitions {  // Group [1, 0] with handle 0, run over a 2x2 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>  // handle <= 0 takes the crossReplica path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast, @collective_broadcast],
+ [@collective_broadcast, @collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>  // Takes %operand2's value. NOTE(review): presumably the second programs row is replica 1; confirm.
+ check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64>  // Takes %operand3's value.
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>  // The sources keep their own values.
+ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+ func.return
+ }
+}
+
+// -----
+
+module @cross_partition {  // Group [2, 1] with handle 1, run over a 1x4 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>  // handle > 0 takes the crossPartition path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast, @collective_broadcast,
+ @collective_broadcast, @collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 0 is not in the group: zeros.
+ check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>  // Gets %operand2's value; 2 is the first id in the group.
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>  // The source keeps its own value.
+ check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 3 is not in the group: zeros.
+ func.return
+ }
+}
+
+// -----
+
+module @cross_partition_multiple_output {  // Group [2, 1, 0] with handle 1, run over a 1x4 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>  // handle > 0 takes the crossPartition path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast, @collective_broadcast,
+ @collective_broadcast, @collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>  // Ids 0, 1 and 2 all get %operand2's value; 2 is the first id in the group.
+ check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>  // Id 3 is not in the group: zeros.
+ func.return
+ }
+}
+
+// -----
+
+module @cross_partition_single_partition {  // Group [0] with handle 1, run over a 4x1 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[0]]> : tensor<1x1xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>  // handle > 0 takes the crossPartition path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast], [@collective_broadcast],
+ [@collective_broadcast], [@collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>  // Every result equals the matching operand: nothing crosses between programs.
+ check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
+ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
+ func.return
+ }
+}
+
+// -----
+
+module @cross_partition_multiple_replicas {  // Group [1, 0] with handle 1, run over a 2x2 programs grid.
+ func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
+ %result = "stablehlo.collective_broadcast"(%operand) {
+ replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>  // handle > 0 takes the crossPartition path in the interpreter.
+ } : (tensor<1x2xi64>) -> tensor<1x2xi64>
+ return %result : tensor<1x2xi64>
+ }
+ func.func @main() {
+ %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
+ %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
+ %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
+ %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
+ %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
+ programs=[[@collective_broadcast, @collective_broadcast],
+ [@collective_broadcast, @collective_broadcast]]
+ } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
+ (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
+ check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64>  // Takes %operand1's value. NOTE(review): presumably the second entry in each programs row is partition 1; confirm.
+ check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>  // The source keeps its own value.
+ check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64>  // Takes %operand3's value.
+ check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>  // The source keeps its own value.
+ func.return
+ }
+}

0 comments on commit f229220

Please sign in to comment.