Skip to content

Commit bd38683

Browse files
authored
chore: refactor gcp code (#294)
1 parent 52d95e1 commit bd38683

8 files changed

+326
-116
lines changed

util/cloud/gcp/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
add_library(gcp_lib gcs.cc)
1+
add_library(gcp_lib gcs.cc gcs_file.cc gcp_utils.cc)
22

33
cxx_link(gcp_lib http_client_lib strings_lib TRDP::rapidjson)

util/cloud/gcp/gcp_creds_provider.h

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2024, Roman Gershman. All rights reserved.
2+
// See LICENSE for licensing terms.
3+
4+
#pragma once
5+
6+
#include "base/RWSpinLock.h"
7+
8+
namespace util {
9+
10+
namespace fb2 {
11+
class ProactorBase;
12+
} // namespace fb2
13+
14+
namespace cloud {
15+
16+
class GCPCredsProvider {
17+
GCPCredsProvider(const GCPCredsProvider&) = delete;
18+
GCPCredsProvider& operator=(const GCPCredsProvider&) = delete;
19+
20+
public:
21+
GCPCredsProvider() = default;
22+
23+
std::error_code Init(unsigned connect_ms, fb2::ProactorBase* pb);
24+
25+
const std::string& project_id() const {
26+
return project_id_;
27+
}
28+
29+
const std::string& client_id() const {
30+
return client_id_;
31+
}
32+
33+
// Thread-safe method to access the token.
34+
std::string access_token() const {
35+
folly::RWSpinLock::ReadHolder lock(lock_);
36+
return access_token_;
37+
}
38+
39+
time_t expire_time() const {
40+
return expire_time_.load(std::memory_order_acquire);
41+
}
42+
43+
// Thread-safe method issues refresh of the token.
44+
// Right now will do the refresh unconditonally.
45+
// TODO: to use expire_time_ to skip the refresh if expire time is far away.
46+
std::error_code RefreshToken(fb2::ProactorBase* pb);
47+
48+
private:
49+
bool use_instance_metadata_ = false;
50+
unsigned connect_ms_ = 0;
51+
52+
fb2::ProactorBase* pb_ = nullptr;
53+
std::string account_id_;
54+
std::string project_id_;
55+
56+
std::string client_id_, client_secret_, refresh_token_;
57+
58+
mutable folly::RWSpinLock lock_; // protects access_token_
59+
std::string access_token_;
60+
std::atomic<time_t> expire_time_ = 0; // seconds since epoch
61+
};
62+
63+
} // namespace cloud
64+
} // namespace util

util/cloud/gcp/gcp_utils.cc

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright 2024, Roman Gershman. All rights reserved.
2+
// See LICENSE for licensing terms.
3+
4+
#include "util/cloud/gcp/gcp_utils.h"
5+
6+
#include <absl/strings/str_cat.h>
7+
8+
#include <boost/beast/http/string_body.hpp>
9+
10+
#include "base/logging.h"
11+
#include "util/cloud/gcp/gcp_creds_provider.h"
12+
13+
#define RETURN_UNEXPECTED(x) \
14+
do { \
15+
auto ec = (x); \
16+
if (ec) \
17+
return nonstd::make_unexpected(ec); \
18+
} while (false)
19+
20+
namespace util::cloud {
21+
using namespace std;
22+
namespace h2 = boost::beast::http;
23+
24+
namespace {
25+
26+
bool IsUnauthorized(const h2::header<false, h2::fields>& resp) {
27+
if (resp.result() != h2::status::unauthorized) {
28+
return false;
29+
}
30+
auto it = resp.find("WWW-Authenticate");
31+
32+
return it != resp.end();
33+
}
34+
35+
inline bool DoesServerPushback(h2::status st) {
36+
return st == h2::status::too_many_requests ||
37+
h2::to_status_class(st) == h2::status_class::server_error;
38+
}
39+
40+
} // namespace
41+
42+
const char GCP_API_DOMAIN[] = "www.googleapis.com";
43+
44+
string AuthHeader(string_view access_token) {
45+
return absl::StrCat("Bearer ", access_token);
46+
}
47+
48+
EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url,
49+
const string_view access_token) {
50+
EmptyRequest req{req_verb, boost::beast::string_view{url.data(), url.size()}, 11};
51+
req.set(h2::field::host, GCP_API_DOMAIN);
52+
req.set(h2::field::authorization, AuthHeader(access_token));
53+
req.keep_alive(true);
54+
55+
return req;
56+
}
57+
58+
RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider)
59+
: num_iterations_(num_iterations), provider_(provider) {
60+
}
61+
62+
auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<HeaderParserPtr> {
63+
error_code ec;
64+
for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh.
65+
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();
66+
67+
RETURN_UNEXPECTED(client->Send(*req));
68+
HeaderParserPtr parser(new h2::response_parser<h2::empty_body>());
69+
RETURN_UNEXPECTED(client->ReadHeader(parser.get()));
70+
{
71+
const auto& msg = parser->get();
72+
VLOG(1) << "RespHeader" << i << ": " << msg;
73+
74+
if (!parser->keep_alive()) {
75+
LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header";
76+
}
77+
78+
// Partial content can appear because of the previous reconnect.
79+
if (msg.result() == h2::status::ok || msg.result() == h2::status::partial_content) {
80+
return parser;
81+
}
82+
}
83+
// We have some kind of error, possibly with body that needs to be drained.
84+
h2::response_parser<h2::string_body> drainer(std::move(*parser));
85+
RETURN_UNEXPECTED(client->Recv(&drainer));
86+
const auto& msg = drainer.get();
87+
88+
if (DoesServerPushback(msg.result())) {
89+
LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg;
90+
91+
ThisFiber::SleepFor(1s);
92+
i = 0; // Can potentially deadlock
93+
continue;
94+
}
95+
96+
if (IsUnauthorized(msg)) {
97+
RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor()));
98+
req->set(h2::field::authorization, AuthHeader(provider_->access_token()));
99+
100+
continue;
101+
}
102+
103+
if (msg.result() == h2::status::forbidden) {
104+
return nonstd::make_unexpected(make_error_code(errc::operation_not_permitted));
105+
}
106+
107+
ec = make_error_code(errc::bad_message);
108+
LOG(DFATAL) << "Unexpected response " << msg << "\n" << msg.body() << "\n";
109+
}
110+
111+
return nonstd::make_unexpected(ec);
112+
}
113+
114+
} // namespace util::cloud

util/cloud/gcp/gcp_utils.h

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2024, Roman Gershman. All rights reserved.
2+
// See LICENSE for licensing terms.
3+
4+
#pragma once
5+
6+
#include <boost/beast/http/empty_body.hpp>
7+
#include <memory>
8+
9+
#include "io/io.h"
10+
#include "util/http/http_client.h"
11+
12+
namespace util::cloud {
13+
class GCPCredsProvider;
14+
15+
extern const char GCP_API_DOMAIN[];
16+
17+
using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;
18+
19+
EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url,
20+
const std::string_view access_token);
21+
22+
std::string AuthHeader(std::string_view access_token);
23+
24+
class RobustSender {
25+
public:
26+
using HeaderParserPtr =
27+
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>>;
28+
29+
RobustSender(unsigned num_iterations, GCPCredsProvider* provider);
30+
31+
io::Result<HeaderParserPtr> Send(http::Client* client, EmptyRequest* req);
32+
33+
private:
34+
unsigned num_iterations_;
35+
GCPCredsProvider* provider_;
36+
};
37+
38+
} // namespace util::cloud

util/cloud/gcp/gcs.cc

+10-62
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "io/file_util.h"
1616
#include "io/line_reader.h"
1717
#include "strings/escaping.h"
18+
#include "util/cloud/gcp/gcp_utils.h"
1819

1920
using namespace std;
2021
namespace h2 = boost::beast::http;
@@ -24,9 +25,6 @@ namespace util {
2425
namespace cloud {
2526

2627
namespace {
27-
constexpr char kDomain[] = "www.googleapis.com";
28-
29-
using EmptyRequest = h2::request<h2::empty_body>;
3028

3129
auto Unexpected(std::errc code) {
3230
return nonstd::make_unexpected(make_error_code(code));
@@ -46,28 +44,6 @@ auto Unexpected(std::errc code) {
4644
return ec; \
4745
} while (false)
4846

49-
string AuthHeader(string_view access_token) {
50-
return absl::StrCat("Bearer ", access_token);
51-
}
52-
53-
EmptyRequest PrepareRequest(h2::verb req_verb, boost::beast::string_view url,
54-
const string_view access_token) {
55-
EmptyRequest req(req_verb, url, 11);
56-
req.set(h2::field::host, kDomain);
57-
req.set(h2::field::authorization, AuthHeader(access_token));
58-
req.keep_alive(true);
59-
60-
return req;
61-
}
62-
63-
bool IsUnauthorized(const h2::header<false, h2::fields>& resp) {
64-
if (resp.result() != h2::status::unauthorized) {
65-
return false;
66-
}
67-
auto it = resp.find("WWW-Authenticate");
68-
69-
return it != resp.end();
70-
}
7147

7248
io::Result<string> ExpandFile(string_view path) {
7349
io::Result<io::StatShortVec> res = io::StatFiles(path);
@@ -177,36 +153,6 @@ io::Result<TokenTtl> ParseTokenResponse(std::string&& response) {
177153
return result;
178154
}
179155

180-
using EmptyParserPtr = std::unique_ptr<h2::response_parser<h2::empty_body>>;
181-
io::Result<EmptyParserPtr> SendWithToken(GCPCredsProvider* provider, http::Client* client,
182-
EmptyRequest* req) {
183-
error_code ec;
184-
for (unsigned i = 0; i < 2; ++i) { // Iterate for possible token refresh.
185-
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();
186-
187-
RETURN_UNEXPECTED(client->Send(*req));
188-
EmptyParserPtr parser(new h2::response_parser<h2::empty_body>());
189-
RETURN_UNEXPECTED(client->ReadHeader(parser.get()));
190-
191-
VLOG(1) << "RespHeader" << i << ": " << parser->get();
192-
193-
if (parser->get().result() == h2::status::ok) {
194-
return parser;
195-
};
196-
197-
if (IsUnauthorized(parser->get())) {
198-
RETURN_UNEXPECTED(provider->RefreshToken(client->proactor()));
199-
req->set(h2::field::authorization, AuthHeader(provider->access_token()));
200-
201-
continue;
202-
}
203-
ec = make_error_code(errc::bad_message);
204-
LOG(DFATAL) << "Unexpected response " << parser.get();
205-
}
206-
207-
return nonstd::make_unexpected(ec);
208-
}
209-
210156
#define FETCH_ARRAY_MEMBER(val) \
211157
if (!(val).IsArray()) \
212158
return make_error_code(errc::bad_message); \
@@ -310,7 +256,7 @@ GCS::~GCS() {
310256
std::error_code GCS::Connect(unsigned msec) {
311257
client_->set_connect_timeout_ms(msec);
312258

313-
return client_->Connect(kDomain, "443", ssl_ctx_);
259+
return client_->Connect(GCP_API_DOMAIN, "443", ssl_ctx_);
314260
}
315261

316262
error_code GCS::ListBuckets(ListBucketCb cb) {
@@ -321,12 +267,13 @@ error_code GCS::ListBuckets(ListBucketCb cb) {
321267

322268
rj::Document doc;
323269

270+
RobustSender sender(2, &creds_provider_);
271+
324272
while (true) {
325-
io::Result<EmptyParserPtr> parse_res =
326-
SendWithToken(&creds_provider_, client_.get(), &http_req);
273+
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(client_.get(), &http_req);
327274
if (!parse_res)
328275
return parse_res.error();
329-
EmptyParserPtr empty_parser = std::move(*parse_res);
276+
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);
330277
h2::response_parser<h2::string_body> resp(std::move(*empty_parser));
331278
RETURN_ERROR(client_->Recv(&resp));
332279

@@ -376,12 +323,13 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive,
376323
auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token());
377324

378325
rj::Document doc;
326+
RobustSender sender(2, &creds_provider_);
379327
while (true) {
380-
io::Result<EmptyParserPtr> parse_res =
381-
SendWithToken(&creds_provider_, client_.get(), &http_req);
328+
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(client_.get(), &http_req);
382329
if (!parse_res)
383330
return parse_res.error();
384-
EmptyParserPtr empty_parser = std::move(*parse_res);
331+
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);
332+
385333
h2::response_parser<h2::string_body> resp(std::move(*empty_parser));
386334
RETURN_ERROR(client_->Recv(&resp));
387335

0 commit comments

Comments
 (0)