Skip to content

Commit 5c80d30

Browse files
authored
chore: Implement dev-credential support for GCS (#291)
In addition, implement basic ListBuckets method for GCS. Signed-off-by: Roman Gershman <roman@dragonflydb.io>
1 parent be3c777 commit 5c80d30

File tree

5 files changed

+352
-10
lines changed

5 files changed

+352
-10
lines changed

examples/gcs_demo.cc

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,22 @@ using namespace util;
1414
using absl::GetFlag;
1515

1616
ABSL_FLAG(string, bucket, "", "");
17-
ABSL_FLAG(string, access_token, "", "");
1817
ABSL_FLAG(uint32_t, connect_ms, 2000, "");
1918
ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring");
2019

2120

2221
void Run(SSL_CTX* ctx) {
2322
fb2::ProactorBase* pb = fb2::ProactorBase::me();
24-
cloud::GCS gcs(ctx, pb);
25-
error_code ec = gcs.Connect(GetFlag(FLAGS_connect_ms));
23+
cloud::GCPCredsProvider provider;
24+
unsigned connect_ms = GetFlag(FLAGS_connect_ms);
25+
error_code ec = provider.Init(connect_ms, pb);
26+
CHECK(!ec) << "Could not load credentials " << ec.message();
27+
28+
cloud::GCS gcs(&provider, ctx, pb);
29+
ec = gcs.Connect(connect_ms);
2630
CHECK(!ec) << "Could not connect " << ec;
2731
auto res = gcs.ListBuckets();
28-
CHECK(res) << res.error();
32+
CHECK(res) << res.error().message();
2933
for (auto v : *res) {
3034
CONSOLE_INFO << v;
3135
}

util/cloud/gcp/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
add_library(gcp_lib gcs.cc)
22

3-
cxx_link(gcp_lib http_client_lib)
3+
cxx_link(gcp_lib http_client_lib TRDP::rapidjson)

util/cloud/gcp/gcs.cc

+287-3
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,288 @@
33

44
#include "util/cloud/gcp/gcs.h"
55

6+
#include <absl/strings/str_cat.h>
7+
#include <absl/strings/str_split.h>
8+
#include <rapidjson/document.h>
9+
10+
#include <boost/beast/http/empty_body.hpp>
11+
#include <boost/beast/http/string_body.hpp>
12+
13+
#include "base/logging.h"
14+
#include "io/file.h"
15+
#include "io/file_util.h"
16+
#include "io/line_reader.h"
17+
18+
using namespace std;
19+
namespace h2 = boost::beast::http;
20+
namespace rj = rapidjson;
21+
622
namespace util {
723
namespace cloud {
824

9-
namespace {
25+
namespace {
1026
constexpr char kDomain[] = "www.googleapis.com";
27+
28+
using EmptyRequest = h2::request<h2::empty_body>;
29+
30+
auto Unexpected(std::errc code) {
31+
return nonstd::make_unexpected(make_error_code(code));
32+
}
33+
34+
string AuthHeader(string_view access_token) {
35+
return absl::StrCat("Bearer ", access_token);
36+
}
37+
38+
EmptyRequest PrepareRequest(h2::verb req_verb, boost::beast::string_view url,
39+
const string_view access_token) {
40+
EmptyRequest req(req_verb, url, 11);
41+
req.set(h2::field::host, kDomain);
42+
req.set(h2::field::authorization, AuthHeader(access_token));
43+
req.keep_alive(true);
44+
45+
return req;
46+
}
47+
48+
bool IsUnauthorized(const h2::header<false, h2::fields>& resp) {
49+
if (resp.result() != h2::status::unauthorized) {
50+
return false;
51+
}
52+
auto it = resp.find("WWW-Authenticate");
53+
54+
return it != resp.end();
55+
}
56+
57+
io::Result<string> ExpandFile(string_view path) {
58+
io::Result<io::StatShortVec> res = io::StatFiles(path);
59+
60+
if (!res) {
61+
return nonstd::make_unexpected(res.error());
62+
}
63+
64+
if (res->empty()) {
65+
VLOG(1) << "Could not find " << path;
66+
return Unexpected(errc::no_such_file_or_directory);
67+
}
68+
return res->front().name;
69+
}
70+
71+
std::error_code LoadGCPConfig(string* account_id, string* project_id) {
72+
io::Result<string> path = ExpandFile("~/.config/gcloud/configurations/config_default");
73+
if (!path) {
74+
return path.error();
75+
}
76+
77+
io::Result<string> config = io::ReadFileToString(*path);
78+
if (!config) {
79+
return config.error();
80+
}
81+
82+
io::BytesSource bs(*config);
83+
io::LineReader reader(&bs, DO_NOT_TAKE_OWNERSHIP, 11);
84+
string scratch;
85+
string_view line;
86+
while (reader.Next(&line, &scratch)) {
87+
vector<string_view> vals = absl::StrSplit(line, "=");
88+
if (vals.size() != 2)
89+
continue;
90+
for (auto& v : vals) {
91+
v = absl::StripAsciiWhitespace(v);
92+
}
93+
if (vals[0] == "account") {
94+
*account_id = string(vals[1]);
95+
} else if (vals[0] == "project") {
96+
*project_id = string(vals[1]);
97+
}
98+
}
99+
100+
return {};
101+
}
102+
103+
std::error_code ParseADC(string_view adc_file, string* client_id, string* client_secret,
104+
string* refresh_token) {
105+
io::Result<string> adc = io::ReadFileToString(adc_file);
106+
if (!adc) {
107+
return adc.error();
108+
}
109+
110+
rj::Document adc_doc;
111+
constexpr unsigned kFlags = rj::kParseTrailingCommasFlag | rj::kParseCommentsFlag;
112+
adc_doc.ParseInsitu<kFlags>(&adc->front());
113+
114+
if (adc_doc.HasParseError()) {
115+
return make_error_code(errc::protocol_error);
116+
}
117+
118+
for (auto it = adc_doc.MemberBegin(); it != adc_doc.MemberEnd(); ++it) {
119+
if (it->name == "client_id") {
120+
*client_id = it->value.GetString();
121+
} else if (it->name == "client_secret") {
122+
*client_secret = it->value.GetString();
123+
} else if (it->name == "refresh_token") {
124+
*refresh_token = it->value.GetString();
125+
}
126+
}
127+
128+
return {};
129+
}
130+
131+
// token, expire_in (seconds)
132+
using TokenTtl = pair<string, unsigned>;
133+
134+
io::Result<TokenTtl> ParseTokenResponse(std::string&& response) {
135+
VLOG(1) << "Refresh Token response: " << response;
136+
137+
rj::Document doc;
138+
constexpr unsigned kFlags = rj::kParseTrailingCommasFlag | rj::kParseCommentsFlag;
139+
doc.ParseInsitu<kFlags>(&response.front());
140+
141+
if (doc.HasParseError()) {
142+
return Unexpected(errc::bad_message);
143+
}
144+
145+
TokenTtl result;
146+
auto it = doc.FindMember("token_type");
147+
if (it == doc.MemberEnd() || string_view{it->value.GetString()} != "Bearer"sv) {
148+
return Unexpected(errc::bad_message);
149+
}
150+
151+
it = doc.FindMember("access_token");
152+
if (it == doc.MemberEnd()) {
153+
return Unexpected(errc::bad_message);
154+
}
155+
result.first = it->value.GetString();
156+
it = doc.FindMember("expires_in");
157+
if (it == doc.MemberEnd() || !it->value.IsUint()) {
158+
return Unexpected(errc::bad_message);
159+
}
160+
result.second = it->value.GetUint();
161+
162+
return result;
163+
}
164+
165+
template <typename RespBody>
166+
error_code SendWithToken(GCPCredsProvider* provider, http::Client* client, EmptyRequest* req, h2::response<RespBody>* resp) {
167+
for (unsigned i = 0; i < 2; ++i) { // Iterate for possible token refresh.
168+
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();
169+
170+
error_code ec = client->Send(*req, resp);
171+
if (ec) {
172+
return ec;
173+
}
174+
VLOG(1) << "HttpResp" << i << ": " << *resp;
175+
176+
if (resp->result() == h2::status::ok) {
177+
break;
178+
};
179+
180+
if (IsUnauthorized(*resp)) {
181+
ec = provider->RefreshToken(client->proactor());
182+
if (ec) {
183+
return ec;
184+
}
185+
186+
*resp = {};
187+
req->set(h2::field::authorization, AuthHeader(provider->access_token()));
188+
189+
continue;
190+
}
191+
LOG(FATAL) << "Unexpected response " << *resp;
192+
}
193+
return {};
194+
}
195+
11196
} // namespace
12197

198+
error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) {
199+
CHECK_GT(connect_ms, 0u);
200+
201+
io::Result<string> root_path = ExpandFile("~/.config/gcloud");
202+
if (!root_path) {
203+
return root_path.error();
204+
}
205+
206+
bool is_cloud_env = false;
207+
string gce_file = absl::StrCat(*root_path, "/gce");
208+
209+
VLOG(1) << "Reading from " << gce_file;
13210

14-
GCS::GCS(SSL_CTX* ssl_cntx, fb2::ProactorBase* pb) {
211+
io::Result<string> gce_file_str = io::ReadFileToString(gce_file);
212+
213+
if (gce_file_str && *gce_file_str == "True") {
214+
is_cloud_env = true;
215+
}
216+
217+
if (is_cloud_env) {
218+
use_instance_metadata_ = true;
219+
LOG(FATAL) << "TBD: do not support reading from instance metadata";
220+
} else {
221+
error_code ec = LoadGCPConfig(&account_id_, &project_id_);
222+
if (ec)
223+
return ec;
224+
if (account_id_.empty() || project_id_.empty()) {
225+
LOG(WARNING) << "gcloud config file is not valid";
226+
return make_error_code(errc::not_supported);
227+
}
228+
string adc_file = absl::StrCat(*root_path, "/legacy_credentials/", account_id_, "/adc.json");
229+
VLOG(1) << "ADC file: " << adc_file;
230+
ec = ParseADC(adc_file, &client_id_, &client_secret_, &refresh_token_);
231+
if (ec)
232+
return ec;
233+
if (client_id_.empty() || client_secret_.empty() || refresh_token_.empty()) {
234+
LOG(WARNING) << "Bad ADC file " << adc_file;
235+
return make_error_code(errc::bad_message);
236+
}
237+
}
238+
239+
// At this point we should have all the data to get an access token.
240+
connect_ms_ = connect_ms;
241+
return RefreshToken(pb);
242+
}
243+
244+
error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) {
245+
constexpr char kDomain[] = "oauth2.googleapis.com";
246+
247+
http::TlsClient https_client(pb);
248+
https_client.set_connect_timeout_ms(connect_ms_);
249+
SSL_CTX* context = http::TlsClient::CreateSslContext();
250+
error_code ec = https_client.Connect(kDomain, "443", context);
251+
http::TlsClient::FreeContext(context);
252+
253+
if (ec)
254+
return ec;
255+
h2::request<h2::string_body> req{h2::verb::post, "/token", 11};
256+
req.set(h2::field::host, kDomain);
257+
req.set(h2::field::content_type, "application/x-www-form-urlencoded");
258+
259+
string& body = req.body();
260+
body = absl::StrCat("grant_type=refresh_token&client_secret=", client_secret_,
261+
"&refresh_token=", refresh_token_);
262+
absl::StrAppend(&body, "&client_id=", client_id_);
263+
req.prepare_payload();
264+
VLOG(1) << "Req: " << req;
265+
266+
h2::response<h2::string_body> resp;
267+
ec = https_client.Send(req, &resp);
268+
if (ec)
269+
return ec;
270+
if (resp.result() != h2::status::ok) {
271+
LOG(WARNING) << "Http error: " << string(resp.reason()) << ", Body: ", resp.body();
272+
return make_error_code(errc::permission_denied);
273+
}
274+
275+
io::Result<TokenTtl> token = ParseTokenResponse(std::move(resp.body()));
276+
if (!token)
277+
return token.error();
278+
279+
folly::RWSpinLock::WriteHolder lock(lock_);
280+
access_token_ = token->first;
281+
expire_time_.store(time(nullptr) + token->second, std::memory_order_release);
282+
283+
return {};
284+
}
285+
286+
GCS::GCS(GCPCredsProvider* provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb)
287+
: creds_provider_(*provider), ssl_ctx_(ssl_cntx) {
15288
client_.reset(new http::TlsClient(pb));
16289
}
17290

@@ -21,10 +294,21 @@ GCS::~GCS() {
21294
std::error_code GCS::Connect(unsigned msec) {
22295
client_->set_connect_timeout_ms(msec);
23296

24-
return client_->Connect(kDomain, "443");
297+
return client_->Connect(kDomain, "443", ssl_ctx_);
25298
}
26299

27300
auto GCS::ListBuckets() -> ListBucketResult {
301+
string url = absl::StrCat("/storage/v1/b?project=", creds_provider_.project_id());
302+
absl::StrAppend(&url, "&fields=items,nextPageToken");
303+
304+
auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token());
305+
306+
rj::Document doc;
307+
h2::response<h2::string_body> resp_msg;
308+
error_code ec = SendWithToken(&creds_provider_, client_.get(), &http_req, &resp_msg);
309+
if (ec)
310+
return nonstd::make_unexpected(ec);
311+
VLOG(2) << "ListResponse: " << resp_msg.body();
28312
return {};
29313
}
30314

0 commit comments

Comments
 (0)