Enhance rendezvous docs and fix race condition (#1956)
This PR removes the use of `shared_ptr`s, which complicated the logic of
`rendezvous`. Instead of relying on the use count of the `shared_ptr`, a
separate variable now tracks how many processes still need to read the result
(see the sketch below). There is no longer any need to worry about the scope
of the `shared_ptr`s or about cleaning up resources before exiting.
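As a rough sketch of the change (hypothetical struct names and plain `int`
data instead of the real `Tensor` and `ProcessId` types, not the actual
StableHLO code):

```cpp
#include <cstddef>
#include <map>
#include <memory>

// Before (sketch): "everyone has read the result" was inferred from
// shared_ptr::use_count(), so correctness depended on exactly when each
// caller's copy of the pointer went out of scope.
struct StateBefore {
  std::map<int, int> values;                   // contributions by process id
  std::shared_ptr<std::map<int, int>> result;  // readers tracked implicitly
};

// After (sketch): an explicit counter tracks how many processes still need to
// read the result, and the last reader takes it by value.
struct StateAfter {
  std::map<int, int> values;  // contributions by process id
  std::map<int, int> result;  // published once all contributions are in
  std::size_t useCount = 0;   // contributors that have not yet read the result
};
```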

Thanks @mlevesquedion for sharing code snippets offline to help simplify
this!
ghpvnist authored Feb 1, 2024
1 parent aeeca3c commit 10dc10d
Showing 5 changed files with 67 additions and 102 deletions.
12 changes: 6 additions & 6 deletions stablehlo/reference/Ops.cpp
@@ -870,7 +870,7 @@ Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
           process->getId().replicaId, process->getId().partitionId));

   auto rendezvousResult =
-      *process->rendezvous(*processGroup, channelId, operand);
+      process->rendezvous(*processGroup, channelId, operand);
   SmallVector<Tensor> groupOperands(llvm::map_range(
       *processGroup,
       [&](const ProcessId &id) { return rendezvousResult.lookup(id); }));
@@ -901,8 +901,8 @@ Tensor evalAllReduceOp(const Tensor &operand,
           "Failed to find process group with process_id: (%d, %d)",
           process->getId().replicaId, process->getId().partitionId));

-  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
-                           ->getSortedTensors();
+  auto groupOperands =
+      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();

   Tensor result(resultType);
   for (auto resultIt = result.index_begin(); resultIt != result.index_end();
@@ -942,8 +942,8 @@ Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension,
           "Failed to find process group with process_id: (%d, %d)",
           process->getId().replicaId, process->getId().partitionId));

-  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
-                           ->getSortedTensors();
+  auto groupOperands =
+      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();

   SmallVector<Tensor> scatteredParts;
   for (const auto &groupOperand : groupOperands) {
@@ -1093,7 +1093,7 @@ Tensor evalCollectivePermuteOp(
     if (from != process->getId() && to != process->getId()) continue;

     auto rendezvousResult =
-        *process->rendezvous(processGroup, channelId, operand);
+        process->rendezvous(processGroup, channelId, operand);
     if (to != process->getId()) continue;
     result = rendezvousResult.lookup(from);
   }
7 changes: 4 additions & 3 deletions stablehlo/reference/Process.cpp
@@ -1,4 +1,4 @@
-/* Copyright 2023 The StableHLO Authors.
+/* Copyright 2023-2024 The StableHLO Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -50,8 +50,9 @@ SmallVector<Tensor> Process::recv(ChannelId channelId) {
   return grid_->recv(channelId, getId());
 }

-std::shared_ptr<RendezvousResult const> Process::rendezvous(
-    ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
+RendezvousResult Process::rendezvous(ProcessGroup processGroup,
+                                     ChannelId channelId,
+                                     const Tensor &operand) {
   return grid_->rendezvous(processGroup, channelId, getId(), operand);
 }
7 changes: 3 additions & 4 deletions stablehlo/reference/Process.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The StableHLO Authors.
+/* Copyright 2023-2024 The StableHLO Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -60,9 +60,8 @@ class Process {
   SmallVector<Tensor> recv(ChannelId channelId);

   /// See `ProcessGrid::rendezvous`.
-  std::shared_ptr<RendezvousResult const> rendezvous(ProcessGroup processGroup,
-                                                     ChannelId channelId,
-                                                     const Tensor &operand);
+  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
+                              const Tensor &operand);

   /// See `ProcessGrid::send`.
   void send(ArrayRef<Tensor> inputs, ChannelId channelId);
73 changes: 18 additions & 55 deletions stablehlo/reference/ProcessGrid.cpp
@@ -1,4 +1,4 @@
-/* Copyright 2023 The StableHLO Authors.
+/* Copyright 2023-2024 The StableHLO Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -220,72 +220,35 @@ SmallVector<Tensor> ProcessGrid::recv(ChannelId channelId,
   return result;
 }

-std::shared_ptr<RendezvousResult const> ProcessGrid::rendezvous(
-    ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
-    const Tensor &operand) {
-  std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
+RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,
+                                         ChannelId channelId,
+                                         ProcessId processId,
+                                         const Tensor &operand) {
   // Process wait/notify logic below doesn't work for single process.
   if (processGroup.size() == 1)
-    return std::make_shared<RendezvousResult>(
-        RendezvousResult({std::pair{processId, operand}}));
+    return RendezvousResult({std::pair{processId, operand}});

+  std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
   auto &state = channels_[channelKey];

   std::unique_lock<std::mutex> lock(state.mutex);
   state.values[processId] = operand;
+  state.useCount++;

-  if (state.values.size() == processGroup.size()) {
-    // If values are full, that means all other processes are currently waiting.
-    // The last process to contribute moves the values into the result
-    // then waits for each process to return a copy of the result before
-    // cleaning up the state variable for future computations in this process
-    // grid.
-    state.result = std::make_shared<RendezvousResult>(state.values);
-    state.values.clear();
-    channelConditions_[channelKey].notify_one();
-
-    // The last process to contribute waits until the rest of the processes have
-    // read the values.
+  // After each process contributes, wait for the last process to notify.
+  if (state.values.size() < processGroup.size()) {
     if (!channelConditions_[channelKey].wait_for(
-            lock, std::chrono::seconds(3), [&] {
-              return state.result.use_count() >=
-                     static_cast<int64_t>(processGroup.size());
-            }))
-      llvm::report_fatal_error(
-          "rendezvous timed out: not all processes have contributed yet");
-
-    if (state.result.use_count() > static_cast<int64_t>(processGroup.size()))
-      llvm::report_fatal_error(
-          "Each process should have only one shared access to the result.");
-
-    // The last process to contribute takes the result from the state to allow
-    // the process that contributed last to exit the function.
-    auto result = std::move(state.result);
-    channelConditions_[channelKey].notify_one();
-    return result;
+            lock, std::chrono::seconds(3),
+            [&] { return state.values.size() == processGroup.size(); }))
+      llvm::report_fatal_error("rendezvous timed out");
+  } else {
+    state.result = std::move(state.values);
+    channelConditions_[channelKey].notify_all();
   }

-  // Wait for all processes to contribute values.
-  if (!channelConditions_[channelKey].wait_for(
-          lock, std::chrono::seconds(3),
-          [&] { return state.result != nullptr; }))
-    llvm::report_fatal_error(
-        "rendezvous timed out: not all process has received the results yet");
-
-  // Copy result from the state before notifying.
-  auto result = state.result;
-  channelConditions_[channelKey].notify_one();
+  state.useCount--;

-  // Wait for the remaining processes to have retrieved the result. In other
-  // words, wait until the last process to contribute exit the function.
-  if (!channelConditions_[channelKey].wait_for(
-          lock, std::chrono::seconds(3),
-          [&] { return state.result == nullptr; }))
-    llvm::report_fatal_error(
-        "rendezvous timed out: not all process has received the results yet");
-
-  channelConditions_[channelKey].notify_one();
-  return result;
+  return state.useCount > 0 ? state.result : std::move(state.result);
 }

 void ProcessGrid::send(ArrayRef<Tensor> inputs, ChannelId channelId,
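To follow the new protocol outside the diff, here is a minimal self-contained
sketch of the same counting-barrier pattern. The names (`Rendezvous`, `join`)
are hypothetical, `std::thread` and `int` stand in for StableHLO processes and
`Tensor`s, it models a single rendezvous round, and it omits the 3-second
timeout, so it illustrates the technique rather than reproducing the actual
implementation:

```cpp
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

// Counting barrier in the style of the new ProcessGrid::rendezvous.
struct Rendezvous {
  std::mutex mutex;
  std::condition_variable cv;
  std::map<int, int> values;  // contributions keyed by participant id
  std::map<int, int> result;  // published once everyone has contributed
  std::size_t useCount = 0;   // participants that still need to read

  std::map<int, int> join(int id, int value, std::size_t groupSize) {
    std::unique_lock<std::mutex> lock(mutex);
    values[id] = value;
    useCount++;
    if (values.size() == groupSize) {
      // The last contributor publishes the result and wakes the waiters.
      result = std::move(values);
      values.clear();  // a moved-from map is valid but unspecified; empty it
      cv.notify_all();
    } else {
      // Everyone else waits until the result has been published.
      cv.wait(lock, [&] { return values.empty(); });
    }
    useCount--;
    if (useCount > 0) return result;  // earlier readers take a copy
    // The last reader moves the result out, leaving the state clean.
    auto out = std::move(result);
    result.clear();
    return out;
  }
};

int main() {
  Rendezvous rendezvous;
  constexpr std::size_t kGroupSize = 4;
  std::vector<std::thread> group;
  for (int id = 0; id < static_cast<int>(kGroupSize); ++id)
    group.emplace_back([&rendezvous, id] {
      auto result = rendezvous.join(id, id * 10, kGroupSize);
      std::printf("participant %d saw %zu contributions\n", id, result.size());
    });
  for (auto &t : group) t.join();
}
```

Every thread blocks until the map is published; whichever thread decrements
`useCount` to zero last moves the result out, so the shared state ends up
empty again without any reference counting.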
70 changes: 36 additions & 34 deletions stablehlo/reference/ProcessGrid.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The StableHLO Authors.
+/* Copyright 2023-2024 The StableHLO Authors.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -33,11 +33,36 @@ namespace mlir {
 namespace stablehlo {

 struct ProcessId;
-class RendezvousResult;

+/// Represents a result of a `ProcessGrid::rendezvous` where multiple processes
+/// synchronize at a barrier and contribute a Tensor each.
+/// This class is pretty much a map from ProcessId to Tensor, with the
+/// map-like API.
+class RendezvousResult {
+ public:
+  RendezvousResult() = default;
+  RendezvousResult(std::map<ProcessId, Tensor> const &result);
+
+  /// Iterates through the (ProcessId, Tensor) map entries and returns a vector
+  /// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
+  /// lexicographical order.
+  SmallVector<Tensor> getSortedTensors() const;
+
+  /// Inserts `tensor` into the map using the key `processId`.
+  void insert(ProcessId processId, Tensor tensor);
+
+  /// Iterates through the map and returns the value associated with the key
+  /// `processId`. If the key is not found, returns an empty `Tensor`.
+  Tensor lookup(ProcessId processId) const;
+
+ private:
+  /// Internal map representation of the result of `ProcessGrid::rendezvous`.
+  std::map<ProcessId, Tensor> result_;
+};
+
 namespace detail {

-/// Internal storate used in `rendezvous` to manage concurrent access to the
+/// Internal storage used in `rendezvous` to manage concurrent access to the
 /// shared resource. Processes contribute their data to `values` concurrently.
 /// Once all processes have added their data, the data in `values` is moved to
 /// `result` that multiple processes can concurrently read from.
@@ -49,8 +74,12 @@ struct RendezvousState {
   /// Internal storage used to store data contributed by the processes.
   std::map<ProcessId, Tensor> values;

-  /// Shared pointer to the result of `rendezvous`.
-  std::shared_ptr<RendezvousResult> result;
+  /// Internal state management counter which counts the number of processes
+  /// that have contributed so far.
+  size_t useCount;
+
+  /// Stores the result of `rendezvous`.
+  RendezvousResult result;
 };

struct SendRecvState {
@@ -164,31 +193,6 @@ class ProcessGroups : public SmallVector<ProcessGroup> {
   std::optional<ProcessGroup> findGroup(ProcessId processId);
 };

-/// Represents a result of a `ProcessGrid::rendezvous` where multiple processes
-/// synchronize at a barrier and contribute a Tensor each.
-/// This class is pretty much a map from ProcessId to Tensor, with the
-/// map-like API.
-class RendezvousResult {
- public:
-  RendezvousResult(std::map<ProcessId, Tensor> const &result);
-
-  /// Iterates through the (ProcessId, Tensor) map entires and returns a vector
-  /// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
-  /// lexicographical order.
-  SmallVector<Tensor> getSortedTensors() const;
-
-  /// Inserts `tensor` into the map using the key `processId`.
-  void insert(ProcessId processId, Tensor tensor);
-
-  /// Iterates through the map and returns the value associated with the key
-  /// `processId`. If key is not found, return an empty `Tensor`.
-  Tensor lookup(ProcessId processId) const;
-
- private:
-  /// Internal map representation of the result of `ProcessGrid::rendezvous`.
-  std::map<ProcessId, Tensor> result_;
-};
-
 /// StableHLO process grid.
 class ProcessGrid {
  public:
@@ -245,10 +245,8 @@ class ProcessGrid {
   /// tensors are accumulated in a `RendezvousResult` that is returned to all
   /// callers once the barrier has been reached by all StableHLO processes.
-  std::shared_ptr<RendezvousResult const> rendezvous(ProcessGroup processGroup,
-                                                     ChannelId channelId,
-                                                     ProcessId processId,
-                                                     const Tensor &operand);
+  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
+                              ProcessId processId, const Tensor &operand);

   /// Sends `inputs` to a channel with `channelId`.
   /// The channel with `channelId` is emptied before the receiving process can
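The `RendezvousResult` class moved to the top of ProcessGrid.h is essentially
a `std::map<ProcessId, Tensor>` with a sorted-extraction helper. A standalone
analogue under stated assumptions (a hypothetical `Result` class, with `int`
standing in for `ProcessId` and `std::string` for `Tensor`) sketches the same
map-like API:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

class Result {
 public:
  void insert(int processId, std::string tensor) {
    result_[processId] = std::move(tensor);
  }

  // Like RendezvousResult::lookup: the value for `processId`, or an empty
  // "tensor" if the key is absent.
  std::string lookup(int processId) const {
    auto it = result_.find(processId);
    return it == result_.end() ? std::string() : it->second;
  }

  // Like getSortedTensors: std::map iterates keys in sorted order, so
  // collecting the values yields them sorted by id.
  std::vector<std::string> getSortedTensors() const {
    std::vector<std::string> sorted;
    for (const auto &entry : result_) sorted.push_back(entry.second);
    return sorted;
  }

 private:
  std::map<int, std::string> result_;
};

int main() {
  Result result;
  result.insert(2, "b");
  result.insert(1, "a");
  assert(result.lookup(1) == "a");
  assert(result.lookup(3).empty());
  assert((result.getSortedTensors() == std::vector<std::string>{"a", "b"}));
}
```

Because `std::map` keeps its keys ordered, collecting values in iteration
order is what gives `getSortedTensors` its lexicographic
(replicaId, partitionId) ordering in the real class.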
