From 70b1513faa0f79a3aa8e33a9369e9c6948a1ce51 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 7 Feb 2024 10:09:22 -0800 Subject: [PATCH] [xla:gpu] Move set/get module annotations to runtime3 Move APIs required for thunks runtime out of gpu/runtime folder. PiperOrigin-RevId: 605020738 --- xla/service/gpu/gpu_executable.cc | 8 ++------ xla/service/gpu/runtime/tracing.cc | 10 +++------- xla/service/gpu/runtime/tracing.h | 4 ---- xla/service/gpu/runtime3/annotation.cc | 21 +++++++++++++++++++++ xla/service/gpu/runtime3/annotation.h | 15 +++++++++++++++ 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 554ebe0fc3f00..e78885cb8fc7f 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include -#include "absl/cleanup/cleanup.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" @@ -47,7 +46,7 @@ limitations under the License. #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/tracing.h" +#include "xla/service/gpu/runtime3/annotation.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_parser.h" @@ -937,10 +936,7 @@ absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime( CheckCompatibilityWithServiceExecutableRunOptions(run_options)); ScopedAnnotation annotation([&] { return module_annotations_.top_level; }); - absl::Cleanup annotations_cleanup = - [previous = SetCurrentModuleAnnotations(&module_annotations_)] { - SetCurrentModuleAnnotations(previous); - }; + ScopedModuleAnnotations module_annotations(&module_annotations_); ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1; diff --git a/xla/service/gpu/runtime/tracing.cc b/xla/service/gpu/runtime/tracing.cc index 38069b4f27ace..3a954a584e4cc 100644 --- a/xla/service/gpu/runtime/tracing.cc +++ b/xla/service/gpu/runtime/tracing.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/runtime/tracing.h" +#include #include #include @@ -53,13 +54,12 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry) { //===----------------------------------------------------------------------===// namespace { -thread_local const ModuleAnnotations* current_annotations{}; thread_local std::string_view current_tracing_scope = {}; } // namespace static absl::StatusOr ActivityStart(runtime::HloTrace annotation) { current_tracing_scope = annotation.hlo_op; - if (current_annotations) { + if (auto* current_annotations = GetCurrentModuleAnnotations()) { // We know which HloModule we belong to, and may have pre-prepared // annotation structs ready to use const auto it = current_annotations->kernels.find(annotation.hlo_op); @@ -94,11 +94,6 @@ void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry) { registry.Register("xla.trace.activity_end", End); } -const ModuleAnnotations* SetCurrentModuleAnnotations( - const ModuleAnnotations* annotations) { - return std::exchange(current_annotations, annotations); -} - static void AppendTracingScopeAndModuleAnnotations( std::string* diagnostic, bool append_annotation_stack) { // Append the current trace which should help identifying original HLO @@ -108,6 +103,7 @@ static void AppendTracingScopeAndModuleAnnotations( "; current tracing scope: ", current_tracing_scope); } + auto* current_annotations = GetCurrentModuleAnnotations(); if (!append_annotation_stack || current_annotations == nullptr) { return; } diff --git a/xla/service/gpu/runtime/tracing.h b/xla/service/gpu/runtime/tracing.h index c34b8832cb367..2fe9ad4722698 100644 --- a/xla/service/gpu/runtime/tracing.h +++ b/xla/service/gpu/runtime/tracing.h @@ -21,7 +21,6 @@ limitations under the License. #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/diagnostics.h" #include "xla/runtime/type_id.h" -#include "xla/service/gpu/runtime3/annotation.h" namespace xla { namespace gpu { @@ -30,9 +29,6 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry); void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry); -const ModuleAnnotations* SetCurrentModuleAnnotations( - const ModuleAnnotations* annotations); - // Appends to `diagnostic_engine` a handler that appends all emitted errors to // the `diagnostic` string. If `append_annotation_stack` is true, it will append // current profiler annotation stack to the diagnostic message (annotation used diff --git a/xla/service/gpu/runtime3/annotation.cc b/xla/service/gpu/runtime3/annotation.cc index 6446c5dd3935f..9763a04b1d5f7 100644 --- a/xla/service/gpu/runtime3/annotation.cc +++ b/xla/service/gpu/runtime3/annotation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -233,4 +234,24 @@ ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} { } } +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +namespace { +thread_local const ModuleAnnotations* current_annotations = nullptr; +} // namespace + +ScopedModuleAnnotations::ScopedModuleAnnotations( + const ModuleAnnotations* annotations) + : restore_(std::exchange(current_annotations, annotations)) {} + +ScopedModuleAnnotations::~ScopedModuleAnnotations() { + std::exchange(current_annotations, restore_); +} + +const ModuleAnnotations* GetCurrentModuleAnnotations() { + return current_annotations; +} + } // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/annotation.h b/xla/service/gpu/runtime3/annotation.h index b081342854c6d..c518be2daffd8 100644 --- a/xla/service/gpu/runtime3/annotation.h +++ b/xla/service/gpu/runtime3/annotation.h @@ -65,6 +65,21 @@ struct ModuleAnnotations { absl::flat_hash_map kernels; }; +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +class ScopedModuleAnnotations { + public: + explicit ScopedModuleAnnotations(const ModuleAnnotations* annotations); + ~ScopedModuleAnnotations(); + + private: + const ModuleAnnotations* restore_; +}; + +const ModuleAnnotations* GetCurrentModuleAnnotations(); + } // namespace xla::gpu #endif // XLA_SERVICE_GPU_RUNTIME3_ANNOTATION_H_