Skip to content

Commit

Permalink
[xla:gpu] Move set/get module annotations to runtime3
Browse files Browse the repository at this point in the history
Move APIs required for thunks runtime out of gpu/runtime folder.

PiperOrigin-RevId: 605020738
  • Loading branch information
ezhulenev authored and copybara-github committed Feb 7, 2024
1 parent 549e95a commit 70b1513
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 17 deletions.
8 changes: 2 additions & 6 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
Expand All @@ -47,7 +46,7 @@ limitations under the License.
#include "xla/service/gpu/nccl_clique.h"
#include "xla/service/gpu/nccl_clique_key.h"
#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h"
#include "xla/service/gpu/runtime/tracing.h"
#include "xla/service/gpu/runtime3/annotation.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/gpu/thunk.h"
#include "xla/service/hlo_parser.h"
Expand Down Expand Up @@ -937,10 +936,7 @@ absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime(
CheckCompatibilityWithServiceExecutableRunOptions(run_options));

ScopedAnnotation annotation([&] { return module_annotations_.top_level; });
absl::Cleanup annotations_cleanup =
[previous = SetCurrentModuleAnnotations(&module_annotations_)] {
SetCurrentModuleAnnotations(previous);
};
ScopedModuleAnnotations module_annotations(&module_annotations_);

ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1;

Expand Down
10 changes: 3 additions & 7 deletions xla/service/gpu/runtime/tracing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/gpu/runtime/tracing.h"

#include <cstdint>
#include <string>
#include <string_view>

Expand Down Expand Up @@ -53,13 +54,12 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry) {
//===----------------------------------------------------------------------===//

namespace {
thread_local const ModuleAnnotations* current_annotations{};
thread_local std::string_view current_tracing_scope = {};
} // namespace

static absl::StatusOr<int64_t> ActivityStart(runtime::HloTrace annotation) {
current_tracing_scope = annotation.hlo_op;
if (current_annotations) {
if (auto* current_annotations = GetCurrentModuleAnnotations()) {
// We know which HloModule we belong to, and may have pre-prepared
// annotation structs ready to use
const auto it = current_annotations->kernels.find(annotation.hlo_op);
Expand Down Expand Up @@ -94,11 +94,6 @@ void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry) {
registry.Register("xla.trace.activity_end", End);
}

const ModuleAnnotations* SetCurrentModuleAnnotations(
const ModuleAnnotations* annotations) {
return std::exchange(current_annotations, annotations);
}

static void AppendTracingScopeAndModuleAnnotations(
std::string* diagnostic, bool append_annotation_stack) {
// Append the current trace which should help identifying original HLO
Expand All @@ -108,6 +103,7 @@ static void AppendTracingScopeAndModuleAnnotations(
"; current tracing scope: ", current_tracing_scope);
}

auto* current_annotations = GetCurrentModuleAnnotations();
if (!append_annotation_stack || current_annotations == nullptr) {
return;
}
Expand Down
4 changes: 0 additions & 4 deletions xla/service/gpu/runtime/tracing.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ limitations under the License.
#include "xla/runtime/custom_call_registry.h"
#include "xla/runtime/diagnostics.h"
#include "xla/runtime/type_id.h"
#include "xla/service/gpu/runtime3/annotation.h"

namespace xla {
namespace gpu {
Expand All @@ -30,9 +29,6 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry);

void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry);

const ModuleAnnotations* SetCurrentModuleAnnotations(
const ModuleAnnotations* annotations);

// Appends to `diagnostic_engine` a handler that appends all emitted errors to
// the `diagnostic` string. If `append_annotation_stack` is true, it will append
// current profiler annotation stack to the diagnostic message (annotation used
Expand Down
21 changes: 21 additions & 0 deletions xla/service/gpu/runtime3/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <optional>
#include <string>
#include <string_view>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
Expand Down Expand Up @@ -233,4 +234,24 @@ ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} {
}
}

//===----------------------------------------------------------------------===//
// Scoped RAII helper to set and restore thread local module annotations
//===----------------------------------------------------------------------===//

namespace {
thread_local const ModuleAnnotations* current_annotations = nullptr;
} // namespace

ScopedModuleAnnotations::ScopedModuleAnnotations(
const ModuleAnnotations* annotations)
: restore_(std::exchange(current_annotations, annotations)) {}

ScopedModuleAnnotations::~ScopedModuleAnnotations() {
std::exchange(current_annotations, restore_);
}

const ModuleAnnotations* GetCurrentModuleAnnotations() {
return current_annotations;
}

} // namespace xla::gpu
15 changes: 15 additions & 0 deletions xla/service/gpu/runtime3/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ struct ModuleAnnotations {
absl::flat_hash_map<std::string_view, KernelAnnotation> kernels;
};

//===----------------------------------------------------------------------===//
// Scoped RAII helper to set and restore thread local module annotations
//===----------------------------------------------------------------------===//

class ScopedModuleAnnotations {
public:
explicit ScopedModuleAnnotations(const ModuleAnnotations* annotations);
~ScopedModuleAnnotations();

private:
const ModuleAnnotations* restore_;
};

const ModuleAnnotations* GetCurrentModuleAnnotations();

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_RUNTIME3_ANNOTATION_H_

0 comments on commit 70b1513

Please sign in to comment.