Skip to content

Commit

Permalink
Stop using some tsl aliases to absl types in XLA.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722807259
  • Loading branch information
klucke authored and tensorflower-gardener committed Feb 3, 2025
1 parent 866f42f commit d195dd3
Show file tree
Hide file tree
Showing 18 changed files with 61 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ namespace profiler {
using tensorflow::ProfileOptions;
using tsl::mutex;
using tsl::mutex_lock;
using tsl::Status;
using tsl::profiler::Annotation;
using tsl::profiler::AnnotationStack;
using tsl::profiler::FindOrAddMutablePlaneWithName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ using tensorflow::ProfileOptions;
using tsl::mutex;
using tsl::mutex_lock;
// using tsl::OkStatus;
using tsl::Status;
using tsl::profiler::Annotation;
using tsl::profiler::AnnotationStack;
using tsl::profiler::FindOrAddMutablePlaneWithName;
Expand Down
6 changes: 2 additions & 4 deletions third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ namespace {
using tensorflow::ProfileOptions;
using tensorflow::profiler::XPlane;
using tensorflow::profiler::XSpace;
using tsl::OkStatus; // TENSORFLOW_STATUS_OK
using tsl::Status; // TENSORFLOW_STATUS_OK
using tsl::profiler::ProfilerInterface;

class ProfilerStatusHelper {
Expand All @@ -65,9 +63,9 @@ class ProfilerStatusHelper {
TF_Status* const c_status) {
if (stream_executor::tpu::ProfilerApiFn()->TpuStatus_CodeFn(c_status) ==
TSL_OK) {
return ::tsl::OkStatus();
return absl::OkStatus();
} else {
return tsl::Status( // TENSORFLOW_STATUS_OK
return absl::Status( // TENSORFLOW_STATUS_OK
absl::StatusCode(
stream_executor::tpu::ProfilerApiFn()->TpuStatus_CodeFn(
c_status)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ using ::int64_t;
using ::tsl::int16;
using ::tsl::int32;
using ::tsl::int8;
using ::tsl::StatusOr; // TENSORFLOW_STATUS_OK
using ::tsl::uint16;
using ::tsl::uint32;
using ::tsl::uint64;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.

#include <gtest/gtest.h>
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "xla/backends/gpu/codegen/triton/support.h"
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand Down
5 changes: 2 additions & 3 deletions third_party/xla/xla/stream_executor/tpu/tsl_status_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include "xla/stream_executor/tpu/c_api_decl.h"
#include "xla/tsl/c/tsl_status.h"
#include "xla/tsl/c/tsl_status_helper.h"
#include "tsl/platform/status.h"

class TslStatusHelper {
public:
Expand All @@ -32,9 +31,9 @@ class TslStatusHelper {
TF_Status* const c_status) { // TENSORFLOW_STATUS_OK
absl::StatusCode code = tsl::StatusCodeFromTSLCode(TSL_GetCode(c_status));
if (code == absl::StatusCode::kOk) {
return tsl::OkStatus();
return absl::OkStatus();
}
return tsl::Status(code, TSL_Message(c_status)); // TENSORFLOW_STATUS_OK
return absl::Status(code, TSL_Message(c_status)); // TENSORFLOW_STATUS_OK
}

bool ok() const {
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tsl/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ cc_library(
deps = [
":tsl_status_internal",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:string_view",
],
)

Expand Down
8 changes: 4 additions & 4 deletions third_party/xla/xla/tsl/c/tsl_status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ limitations under the License.

#include <string>

#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/c/tsl_status_internal.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/status.h"

using ::tsl::Status;
using ::tsl::error::Code;
using ::tsl::errors::IOError;

TSL_Status* TSL_NewStatus() { return new TSL_Status; }
Expand All @@ -35,7 +35,7 @@ void TSL_SetStatus(TSL_Status* s, TSL_Code code, const char* msg) {
return;
}
s->status =
Status(static_cast<absl::StatusCode>(code), absl::string_view(msg));
absl::Status(static_cast<absl::StatusCode>(code), absl::string_view(msg));
}

void TSL_SetPayload(TSL_Status* s, const char* key, const char* value) {
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/tsl/lib/io/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ tsl_cc_test(
"//xla/tsl/platform:env_impl",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:test",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:strcat",
],
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/match.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/lib/io/random_inputstream.h"
#include "xla/tsl/lib/io/zlib_compression_options.h"
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/tsl/platform/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <vector>

#include "absl/functional/any_invocable.h"
#include "absl/strings/ascii.h"
#include "xla/tsl/platform/env_time.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/file_system.h"
Expand Down
77 changes: 38 additions & 39 deletions third_party/xla/xla/tsl/platform/errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/macros.h"
#include "xla/tsl/platform/status.h"
#include "tsl/platform/str_util.h"
#include "tsl/platform/strcat.h"

namespace tsl {
Expand Down Expand Up @@ -146,16 +145,16 @@ inline absl::Status CreateWithUpdatedMessage(const absl::Status& status,
}

#else
inline ::absl::Status Create(
absl::StatusCode code, ::tsl::StringPiece message,
inline absl::Status Create(
absl::StatusCode code, absl::string_view message,
const std::unordered_map<std::string, std::string>& payloads) {
Status status(code, message);
InsertPayloads(status, payloads);
return status;
}
// Returns a new Status, replacing its message with the given.
inline ::tsl::Status CreateWithUpdatedMessage(const ::tsl::Status& status,
::tsl::StringPiece message) {
inline absl::Status CreateWithUpdatedMessage(const absl::Status& status,
absl::string_view message) {
return Create(static_cast<absl::StatusCode>(status.code()), message,
GetPayloads(status));
}
Expand All @@ -173,18 +172,18 @@ void AppendToMessage(absl::Status* status, Args... args) {
}

// For propagating errors when calling a function.
#define TF_RETURN_IF_ERROR(...) \
do { \
::absl::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
MAYBE_ADD_SOURCE_LOCATION(_status) \
return _status; \
} \
#define TF_RETURN_IF_ERROR(...) \
do { \
absl::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
MAYBE_ADD_SOURCE_LOCATION(_status) \
return _status; \
} \
} while (0)

#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \
do { \
::tsl::Status _status = (expr); \
absl::Status _status = (expr); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
::tsl::errors::AppendToMessage(&_status, __VA_ARGS__); \
return _status; \
Expand Down Expand Up @@ -222,7 +221,7 @@ absl::Status InvalidArgument(Args... args) {
#if defined(PLATFORM_GOOGLE)
// Specialized overloads to capture source location for up to three arguments.
template <typename Arg1, typename Arg2, typename Arg3, typename Arg4>
::absl::Status InvalidArgument(
absl::Status InvalidArgument(
Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
Expand All @@ -234,7 +233,7 @@ ::absl::Status InvalidArgument(
loc);
}
template <typename Arg1, typename Arg2, typename Arg3>
::absl::Status InvalidArgument(
absl::Status InvalidArgument(
Arg1 arg1, Arg2 arg2, Arg3 arg3,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
Expand All @@ -245,7 +244,7 @@ ::absl::Status InvalidArgument(
loc);
}
template <typename Arg1, typename Arg2>
::absl::Status InvalidArgument(
absl::Status InvalidArgument(
Arg1 arg1, Arg2 arg2,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
Expand All @@ -255,15 +254,15 @@ ::absl::Status InvalidArgument(
loc);
}
template <typename Arg1>
::absl::Status InvalidArgument(
absl::Status InvalidArgument(
Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)),
loc);
}
template <typename... Args>
::absl::Status InvalidArgumentWithPayloads(
absl::Status InvalidArgumentWithPayloads(
const absl::string_view& message,
const std::unordered_map<std::string, std::string>& payloads,
absl::SourceLocation loc = absl::SourceLocation::current()) {
Expand All @@ -272,29 +271,29 @@ ::absl::Status InvalidArgumentWithPayloads(
}
#else
template <typename Arg1, typename Arg2, typename Arg3>
::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) {
return ::absl::Status(
absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1),
::tsl::errors::internal::PrepareForStrCat(arg2),
::tsl::errors::internal::PrepareForStrCat(arg3)));
}
template <typename Arg1, typename Arg2>
::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) {
return ::absl::Status(
absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1),
::tsl::errors::internal::PrepareForStrCat(arg2)));
}
template <typename Arg1>
::absl::Status InvalidArgument(Arg1 arg1) {
return ::absl::Status(
absl::Status InvalidArgument(Arg1 arg1) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)));
}
template <typename... Args>
::absl::Status InvalidArgumentWithPayloads(
const ::tsl::StringPiece& message,
absl::Status InvalidArgumentWithPayloads(
const absl::string_view& message,
const std::unordered_map<std::string, std::string>& payloads) {
return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads);
}
Expand All @@ -310,7 +309,7 @@ absl::Status NotFound(Args... args) {
#if defined(PLATFORM_GOOGLE)
// Specialized overloads to capture source location for up to three arguments.
template <typename Arg1, typename Arg2, typename Arg3>
::absl::Status NotFound(
absl::Status NotFound(
Arg1 arg1, Arg2 arg2, Arg3 arg3,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
Expand All @@ -321,7 +320,7 @@ ::absl::Status NotFound(
loc);
}
template <typename Arg1, typename Arg2>
::absl::Status NotFound(
absl::Status NotFound(
Arg1 arg1, Arg2 arg2,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
Expand All @@ -331,45 +330,45 @@ ::absl::Status NotFound(
loc);
}
template <typename Arg1>
::absl::Status NotFound(
absl::Status NotFound(
Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) {
return absl::Status(
absl::StatusCode::kNotFound,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)),
loc);
}
template <typename... Args>
::absl::Status NotFoundWithPayloads(
absl::Status NotFoundWithPayloads(
const absl::string_view& message,
const std::unordered_map<std::string, std::string>& payloads,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc);
}
#else
template <typename Arg1, typename Arg2, typename Arg3>
::absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) {
return ::absl::Status(
absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) {
return absl::Status(
absl::StatusCode::kNotFound,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1),
::tsl::errors::internal::PrepareForStrCat(arg2),
::tsl::errors::internal::PrepareForStrCat(arg3)));
}
template <typename Arg1, typename Arg2>
::absl::Status NotFound(Arg1 arg1, Arg2 arg2) {
return ::absl::Status(
absl::Status NotFound(Arg1 arg1, Arg2 arg2) {
return absl::Status(
absl::StatusCode::kNotFound,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1),
::tsl::errors::internal::PrepareForStrCat(arg2)));
}
template <typename Arg1>
::absl::Status NotFound(Arg1 arg1) {
return ::absl::Status(
absl::Status NotFound(Arg1 arg1) {
return absl::Status(
absl::StatusCode::kNotFound,
::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)));
}
template <typename... Args>
::absl::Status NotFoundWithPayloads(
const ::tsl::StringPiece& message,
absl::Status NotFoundWithPayloads(
const absl::string_view& message,
const std::unordered_map<std::string, std::string>& payloads) {
return errors::Create(absl::StatusCode::kNotFound, message, payloads);
}
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/tsl/profiler/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/profiler/lib:profiler_session",
"@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc",
],
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/tsl/profiler/utils/session_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/strings/str_split.h"
#include "xla/tsl/platform/errors.h"
#include "tsl/profiler/lib/profiler_session.h"
#include "tsl/profiler/protobuf/profiler_options.pb.h"
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/tsl/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ cc_library(
deps = [
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"@local_tsl//tsl/platform:str_util",
"@local_tsl//tsl/platform:stringpiece",
],
)
Expand All @@ -279,6 +280,7 @@ tsl_cc_test(
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_benchmark",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:strcat",
],
)
Expand Down
Loading

0 comments on commit d195dd3

Please sign in to comment.