Skip to content

Commit eb166e1

Browse files
committed
chore: fixes in tls and client sockets
Signed-off-by: Roman Gershman <romange@gmail.com>
1 parent bd38683 commit eb166e1

11 files changed

+119
-65
lines changed

util/fiber_socket_base.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source
3939

4040
ABSL_MUST_USE_RESULT virtual AcceptResult Accept() = 0;
4141

42-
ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep) = 0;
42+
ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep,
43+
std::function<void(int)> on_pre_connect = {}) = 0;
4344

4445
ABSL_MUST_USE_RESULT virtual error_code Close() = 0;
4546

@@ -200,8 +201,8 @@ class LinuxSocketBase : public FiberSocketBase {
200201
// gives me 256M descriptors.
201202
int32_t fd_;
202203

203-
private:
204-
uint32_t timeout_ = UINT32_MAX;
204+
private:
205+
uint32_t timeout_ = UINT32_MAX;
205206
};
206207

207208
void SetNonBlocking(int fd);

util/fibers/epoll_socket.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ auto EpollSocket::Accept() -> AcceptResult {
189189
return nonstd::make_unexpected(ec);
190190
}
191191

192-
auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
192+
error_code EpollSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) {
193193
CHECK_EQ(fd_, -1);
194194
CHECK(proactor() && proactor()->InMyThread());
195195

@@ -208,7 +208,9 @@ auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
208208
write_context_ = detail::FiberActive();
209209
absl::Cleanup clean = [this]() { write_context_ = nullptr; };
210210

211-
// RegisterEvents(GetProactor()->ev_loop_fd(), fd, arm_index_ + 1024);
211+
if (on_pre_connect) {
212+
on_pre_connect(fd);
213+
}
212214

213215
DVSOCK(2) << "Connecting";
214216

util/fibers/epoll_socket.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class EpollSocket : public LinuxSocketBase {
2020

2121
ABSL_MUST_USE_RESULT AcceptResult Accept() final;
2222

23-
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
23+
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
24+
std::function<void(int)> on_pre_connect) final;
2425
ABSL_MUST_USE_RESULT error_code Close() final;
2526

2627
// Really need here expected.

util/fibers/uring_socket.cc

+6-5
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ auto UringSocket::Accept() -> AcceptResult {
146146
return fs;
147147
}
148148

149-
auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
149+
auto UringSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) -> error_code {
150150
CHECK_EQ(fd_, -1);
151151
CHECK(proactor() && proactor()->InMyThread());
152152

@@ -163,12 +163,13 @@ auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
163163
// TODO: support direct descriptors. For now client sockets always use regular linux fds.
164164
fd_ = fd << kFdShift;
165165

166-
IoResult io_res;
167-
ep.data();
166+
if (on_pre_connect) {
167+
on_pre_connect(fd);
168+
}
168169

169170
FiberCall fc(proactor, timeout());
170171
fc->PrepConnect(fd, (const sockaddr*)ep.data(), ep.size());
171-
io_res = fc.Get();
172+
IoResult io_res = fc.Get();
172173

173174
if (io_res < 0) { // In that case connect returns -errno.
174175
ec = error_code(-io_res, system_category());
@@ -333,7 +334,7 @@ io::Result<size_t> UringSocket::Recv(const io::MutableBytes& mb, int flags) {
333334
Proactor* p = GetProactor();
334335
DCHECK(ProactorBase::me() == p);
335336

336-
VSOCK(2) << "Recv [" << fd << "] " << flags;
337+
VSOCK(2) << "Recv [" << fd << "], flags: " << flags;
337338
ssize_t res;
338339
while (true) {
339340
FiberCall fc(p, timeout());

util/fibers/uring_socket.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class UringSocket : public LinuxSocketBase {
3232

3333
ABSL_MUST_USE_RESULT AcceptResult Accept() final;
3434

35-
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
35+
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
36+
std::function<void(int)> on_pre_connect) final;
3637
ABSL_MUST_USE_RESULT error_code Close() final;
3738

3839
io::Result<size_t> WriteSome(const iovec* v, uint32_t len) override;
@@ -75,7 +76,7 @@ class UringSocket : public LinuxSocketBase {
7576

7677
struct ErrorCbRefWrapper {
7778
uint32_t error_cb_id = 0;
78-
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
79+
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
7980
std::function<void(uint32_t)> cb;
8081

8182
static ErrorCbRefWrapper* New(std::function<void(uint32_t)> cb) {

util/http/http_client.cc

+10-5
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,15 @@ std::error_code Client::Reconnect() {
8282
return berr;
8383

8484
FiberSocketBase* sock = proactor_->CreateSocket();
85-
if (on_connect_cb_) {
86-
on_connect_cb_(sock->native_handle());
87-
}
85+
8886
socket_.reset(sock);
8987
FiberSocketBase::endpoint_type ep{address, port_};
90-
return socket_->Connect(ep);
88+
auto on_connect = [this](int fd) {
89+
if (on_connect_cb_) {
90+
on_connect_cb_(fd);
91+
}
92+
};
93+
return socket_->Connect(ep, std::move(on_connect));
9194
}
9295

9396
#if 0
@@ -181,7 +184,9 @@ std::error_code TlsClient::Connect(string_view host, string_view service, SSL_CT
181184
// verify server cert using server hostname
182185
SSL_dane_enable(ssl_handle, host);
183186
ec = tls_socket->Connect(FiberSocketBase::endpoint_type{});
184-
if (!ec) {
187+
if (ec) {
188+
std::ignore = tls_socket->Close();
189+
} else {
185190
socket_.reset(tls_socket.release());
186191
}
187192
}

util/tls/tls_engine.cc

+10-5
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,13 @@ Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) {
9797
// SSL_set0_[rw]bio take ownership of the passed reference,
9898
// so if we call both with the same BIO, we need the refcount to be 2.
9999
BIO_up_ref(int_bio);
100+
100101
SSL_set0_rbio(ssl_, int_bio);
101102
SSL_set0_wbio(ssl_, int_bio);
103+
104+
// Debugging traces.
105+
// SSL_set_msg_callback(ssl_, SSL_trace);
106+
// SSL_set_msg_callback_arg(ssl_, BIO_new_fp(stdout,0));
102107
}
103108

104109
Engine::~Engine() {
@@ -111,21 +116,21 @@ Engine::~Engine() {
111116
}
112117

113118

114-
auto Engine::FetchOutputBuf() -> BufResult {
119+
auto Engine::FetchOutputBuf() -> Buffer {
115120
char* buf = nullptr;
116121

117122
int res = BIO_nread(external_bio_, &buf, INT_MAX);
118123
if (res < 0) {
119124
unsigned long error = ::ERR_get_error();
120-
return nonstd::make_unexpected(error);
125+
LOG(DFATAL) << "Unexpected result " << res << " " << error;
126+
127+
return Buffer{};
121128
}
122129

123130
return Buffer(reinterpret_cast<const uint8_t*>(buf), res);
124131
}
125132

126-
// TODO: to consider replacing BufResult with Buffer since
127-
// it seems BIO_C_NREAD0 should not return negative values when used properly.
128-
auto Engine::PeekOutputBuf() -> BufResult {
133+
auto Engine::PeekOutputBuf() -> Buffer {
129134
char* buf = nullptr;
130135

131136
long res = BIO_ctrl(external_bio_, BIO_C_NREAD0, 0, &buf);

util/tls/tls_engine.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class Engine {
3333
// write. In any case for non-error OpResult a caller must check OutputPending and write the
3434
// output buffer to the appropriate channel.
3535
using OpResult = io::Result<int, unsigned long>;
36-
using BufResult = io::Result<Buffer, unsigned long>;
3736

3837
// Construct a new engine for the specified context.
3938
explicit Engine(SSL_CTX* context);
@@ -67,11 +66,11 @@ class Engine {
6766
//! Returns output (read) buffer. This operation is destructive, i.e. after calling
6867
//! this function the buffer is being consumed.
6968
//! See OutputPending() for checking if there is a output buffer to consume.
70-
BufResult FetchOutputBuf();
69+
Buffer FetchOutputBuf();
7170

7271
//! Returns output buffer which is the read buffer of tls engine.
7372
//! This operation is not destructive.
74-
BufResult PeekOutputBuf();
73+
Buffer PeekOutputBuf();
7574

7675
//! Tells the engine that sz bytes were consumed from the output buffer.
7776
//! sz should be not greater than the buffer size from the last PeekOutputBuf() call.

util/tls/tls_engine_test.cc

+6-7
Original file line numberDiff line numberDiff line change
@@ -143,18 +143,17 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb
143143
if (opts.drain_output)
144144
src->FetchOutputBuf();
145145
else {
146-
auto buf_result = src->PeekOutputBuf();
147-
CHECK(buf_result);
148-
VLOG(1) << opts.name << " wrote " << buf_result->size() << " bytes";
149-
CHECK(!buf_result->empty());
146+
auto buffer = src->PeekOutputBuf();
147+
VLOG(1) << opts.name << " wrote " << buffer.size() << " bytes";
148+
CHECK(!buffer.empty());
150149

151150
if (opts.mutate_indx) {
152-
uint8_t* mem = const_cast<uint8_t*>(buf_result->data());
153-
mem[opts.mutate_indx % buf_result->size()] = opts.mutate_val;
151+
uint8_t* mem = const_cast<uint8_t*>(buffer.data());
152+
mem[opts.mutate_indx % buffer.size()] = opts.mutate_val;
154153
opts.mutate_indx = 0;
155154
}
156155

157-
auto write_result = dest->WriteBuf(*buf_result);
156+
auto write_result = dest->WriteBuf(buffer);
158157
if (!write_result) {
159158
return write_result.error();
160159
}

util/tls/tls_socket.cc

+68-30
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ auto TlsSocket::Accept() -> AcceptResult {
136136
return make_unexpected(make_error_code(errc::connection_reset));
137137
}
138138
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
139-
ec = HandleRead();
139+
ec = HandleSocketRead();
140140
if (ec)
141141
return make_unexpected(ec);
142142
}
@@ -145,19 +145,53 @@ auto TlsSocket::Accept() -> AcceptResult {
145145
return nullptr;
146146
}
147147

148-
auto TlsSocket::Connect(const endpoint_type& endpoint) -> error_code {
148+
error_code TlsSocket::Connect(const endpoint_type& endpoint,
149+
std::function<void(int)> on_pre_connect) {
149150
DCHECK(engine_);
150-
auto io_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
151-
if (!io_result.has_value()) {
152-
return std::error_code(io_result.error(), std::system_category());
151+
Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
152+
if (!op_result) {
153+
return std::error_code(op_result.error(), std::system_category());
153154
}
154155

155156
// If the socket is already open, we should not call connect on it
156-
if (IsOpen()) {
157-
return {};
157+
if (!IsOpen()) {
158+
error_code ec = next_sock_->Connect(endpoint, std::move(on_pre_connect));
159+
if (ec)
160+
return ec;
161+
}
162+
163+
// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
164+
int op_val = *op_result;
165+
error_code ec;
166+
167+
// it should guide us to write and then read.
168+
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
169+
while (op_val < 0) {
170+
if (op_val == Engine::EOF_STREAM) {
171+
return make_error_code(errc::connection_reset);
172+
}
173+
174+
if (op_val == Engine::NEED_WRITE) {
175+
ec = HandleSocketWrite();
176+
if (ec)
177+
return ec;
178+
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
179+
ec = HandleSocketWrite();
180+
if (ec)
181+
return ec;
182+
183+
ec = HandleSocketRead();
184+
if (ec)
185+
return ec;
186+
}
187+
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
188+
if (!op_result) {
189+
return std::error_code(op_result.error(), std::system_category());
190+
}
191+
op_val = *op_result;
158192
}
159193

160-
return next_sock_->Connect(endpoint);
194+
return ec;
161195
}
162196

163197
auto TlsSocket::Close() -> error_code {
@@ -249,7 +283,7 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
249283
}
250284

251285
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
252-
ec = HandleRead();
286+
ec = HandleSocketRead();
253287
if (ec)
254288
return make_unexpected(ec);
255289
}
@@ -341,7 +375,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
341375
}
342376

343377
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
344-
ec = HandleRead();
378+
ec = HandleSocketRead();
345379
if (ec)
346380
return make_unexpected(ec);
347381
}
@@ -381,28 +415,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
381415
return error_code{};
382416
}
383417

384-
auto buf_result = engine_->PeekOutputBuf();
385-
CHECK(buf_result);
386-
387-
if (!buf_result->empty()) {
388-
// we do not allow concurrent writes from multiple fibers.
389-
state_ |= WRITE_IN_PROGRESS;
390-
io::Result<size_t> write_result = next_sock_->WriteSome(*buf_result);
391-
392-
// Safe to clear here since the code below is atomic fiber-wise.
393-
state_ &= ~WRITE_IN_PROGRESS;
394-
DCHECK(engine_);
395-
if (!write_result) {
396-
return write_result.error();
397-
}
398-
CHECK_GT(*write_result, 0u);
399-
engine_->ConsumeOutputBuf(*write_result);
400-
}
401-
402-
return error_code{};
418+
return HandleSocketWrite();
403419
}
404420

405-
auto TlsSocket::HandleRead() -> error_code {
421+
auto TlsSocket::HandleSocketRead() -> error_code {
406422
if (state_ & READ_IN_PROGRESS) {
407423
// We need to Yield because otherwise we might end up in an infinite loop.
408424
// See also comments in MaybeSendOutput.
@@ -423,6 +439,28 @@ auto TlsSocket::HandleRead() -> error_code {
423439
return error_code{};
424440
}
425441

442+
error_code TlsSocket::HandleSocketWrite() {
443+
Engine::Buffer buffer = engine_->PeekOutputBuf();
444+
445+
while (!buffer.empty()) {
446+
// we do not allow concurrent writes from multiple fibers.
447+
state_ |= WRITE_IN_PROGRESS;
448+
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);
449+
450+
// Safe to clear here since the code below is atomic fiber-wise.
451+
state_ &= ~WRITE_IN_PROGRESS;
452+
DCHECK(engine_);
453+
if (!write_result) {
454+
return write_result.error();
455+
}
456+
CHECK_GT(*write_result, 0u);
457+
engine_->ConsumeOutputBuf(*write_result);
458+
buffer.remove_prefix(*write_result);
459+
}
460+
461+
return error_code{};
462+
}
463+
426464
TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
427465
return next_sock_->LocalEndpoint();
428466
}

util/tls/tls_socket.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class TlsSocket final : public FiberSocketBase {
3737

3838
// The endpoint should not really pass here, it is to keep
3939
// the interface with FiberSocketBase.
40-
error_code Connect(const endpoint_type&) final;
40+
error_code Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect = {}) final;
4141

4242
error_code Close() final;
4343

@@ -92,7 +92,9 @@ class TlsSocket final : public FiberSocketBase {
9292
error_code MaybeSendOutput();
9393

9494
/// Read encrypted data from the network socket and feed it into the TLS engine.
95-
error_code HandleRead();
95+
error_code HandleSocketRead();
96+
97+
error_code HandleSocketWrite();
9698

9799
std::unique_ptr<FiberSocketBase> next_sock_;
98100
std::unique_ptr<Engine> engine_;

0 commit comments

Comments
 (0)