Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fixes in tls and client sockets #295

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source

ABSL_MUST_USE_RESULT virtual AcceptResult Accept() = 0;

ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep) = 0;
ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect = {}) = 0;

ABSL_MUST_USE_RESULT virtual error_code Close() = 0;

Expand Down Expand Up @@ -200,8 +201,8 @@ class LinuxSocketBase : public FiberSocketBase {
// gives me 256M descriptors.
int32_t fd_;

private:
uint32_t timeout_ = UINT32_MAX;
private:
uint32_t timeout_ = UINT32_MAX;
};

void SetNonBlocking(int fd);
Expand Down
6 changes: 4 additions & 2 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ auto EpollSocket::Accept() -> AcceptResult {
return nonstd::make_unexpected(ec);
}

auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
error_code EpollSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) {
CHECK_EQ(fd_, -1);
CHECK(proactor() && proactor()->InMyThread());

Expand All @@ -208,7 +208,9 @@ auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
write_context_ = detail::FiberActive();
absl::Cleanup clean = [this]() { write_context_ = nullptr; };

// RegisterEvents(GetProactor()->ev_loop_fd(), fd, arm_index_ + 1024);
if (on_pre_connect) {
on_pre_connect(fd);
}

DVSOCK(2) << "Connecting";

Expand Down
3 changes: 2 additions & 1 deletion util/fibers/epoll_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class EpollSocket : public LinuxSocketBase {

ABSL_MUST_USE_RESULT AcceptResult Accept() final;

ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect) final;
ABSL_MUST_USE_RESULT error_code Close() final;

// Really need here expected.
Expand Down
11 changes: 6 additions & 5 deletions util/fibers/uring_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ auto UringSocket::Accept() -> AcceptResult {
return fs;
}

auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
auto UringSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) -> error_code {
CHECK_EQ(fd_, -1);
CHECK(proactor() && proactor()->InMyThread());

Expand All @@ -163,12 +163,13 @@ auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
// TODO: support direct descriptors. For now client sockets always use regular linux fds.
fd_ = fd << kFdShift;

IoResult io_res;
ep.data();
if (on_pre_connect) {
on_pre_connect(fd);
}

FiberCall fc(proactor, timeout());
fc->PrepConnect(fd, (const sockaddr*)ep.data(), ep.size());
io_res = fc.Get();
IoResult io_res = fc.Get();

if (io_res < 0) { // In that case connect returns -errno.
ec = error_code(-io_res, system_category());
Expand Down Expand Up @@ -333,7 +334,7 @@ io::Result<size_t> UringSocket::Recv(const io::MutableBytes& mb, int flags) {
Proactor* p = GetProactor();
DCHECK(ProactorBase::me() == p);

VSOCK(2) << "Recv [" << fd << "] " << flags;
VSOCK(2) << "Recv [" << fd << "], flags: " << flags;
ssize_t res;
while (true) {
FiberCall fc(p, timeout());
Expand Down
5 changes: 3 additions & 2 deletions util/fibers/uring_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class UringSocket : public LinuxSocketBase {

ABSL_MUST_USE_RESULT AcceptResult Accept() final;

ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect) final;
ABSL_MUST_USE_RESULT error_code Close() final;

io::Result<size_t> WriteSome(const iovec* v, uint32_t len) override;
Expand Down Expand Up @@ -75,7 +76,7 @@ class UringSocket : public LinuxSocketBase {

struct ErrorCbRefWrapper {
uint32_t error_cb_id = 0;
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
std::function<void(uint32_t)> cb;

static ErrorCbRefWrapper* New(std::function<void(uint32_t)> cb) {
Expand Down
15 changes: 10 additions & 5 deletions util/http/http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@ std::error_code Client::Reconnect() {
return berr;

FiberSocketBase* sock = proactor_->CreateSocket();
if (on_connect_cb_) {
on_connect_cb_(sock->native_handle());
}

socket_.reset(sock);
FiberSocketBase::endpoint_type ep{address, port_};
return socket_->Connect(ep);
auto on_connect = [this](int fd) {
if (on_connect_cb_) {
on_connect_cb_(fd);
}
};
return socket_->Connect(ep, std::move(on_connect));
}

#if 0
Expand Down Expand Up @@ -181,7 +184,9 @@ std::error_code TlsClient::Connect(string_view host, string_view service, SSL_CT
// verify server cert using server hostname
SSL_dane_enable(ssl_handle, host);
ec = tls_socket->Connect(FiberSocketBase::endpoint_type{});
if (!ec) {
if (ec) {
std::ignore = tls_socket->Close();
} else {
socket_.reset(tls_socket.release());
}
}
Expand Down
15 changes: 10 additions & 5 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) {
// SSL_set0_[rw]bio take ownership of the passed reference,
// so if we call both with the same BIO, we need the refcount to be 2.
BIO_up_ref(int_bio);

SSL_set0_rbio(ssl_, int_bio);
SSL_set0_wbio(ssl_, int_bio);

// Debugging traces.
// SSL_set_msg_callback(ssl_, SSL_trace);
// SSL_set_msg_callback_arg(ssl_, BIO_new_fp(stdout,0));
}

Engine::~Engine() {
Expand All @@ -111,21 +116,21 @@ Engine::~Engine() {
}


auto Engine::FetchOutputBuf() -> BufResult {
auto Engine::FetchOutputBuf() -> Buffer {
char* buf = nullptr;

int res = BIO_nread(external_bio_, &buf, INT_MAX);
if (res < 0) {
unsigned long error = ::ERR_get_error();
return nonstd::make_unexpected(error);
LOG(DFATAL) << "Unexpected result " << res << " " << error;

return Buffer{};
}

return Buffer(reinterpret_cast<const uint8_t*>(buf), res);
}

// TODO: to consider replacing BufResult with Buffer since
// it seems BIO_C_NREAD0 should not return negative values when used properly.
auto Engine::PeekOutputBuf() -> BufResult {
auto Engine::PeekOutputBuf() -> Buffer {
char* buf = nullptr;

long res = BIO_ctrl(external_bio_, BIO_C_NREAD0, 0, &buf);
Expand Down
5 changes: 2 additions & 3 deletions util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Engine {
// write. In any case for non-error OpResult a caller must check OutputPending and write the
// output buffer to the appropriate channel.
using OpResult = io::Result<int, unsigned long>;
using BufResult = io::Result<Buffer, unsigned long>;

// Construct a new engine for the specified context.
explicit Engine(SSL_CTX* context);
Expand Down Expand Up @@ -67,11 +66,11 @@ class Engine {
//! Returns output (read) buffer. This operation is destructive, i.e. after calling
//! this function the buffer is being consumed.
//! See OutputPending() for checking if there is a output buffer to consume.
BufResult FetchOutputBuf();
Buffer FetchOutputBuf();

//! Returns output buffer which is the read buffer of tls engine.
//! This operation is not destructive.
BufResult PeekOutputBuf();
Buffer PeekOutputBuf();

//! Tells the engine that sz bytes were consumed from the output buffer.
//! sz should be not greater than the buffer size from the last PeekOutputBuf() call.
Expand Down
13 changes: 6 additions & 7 deletions util/tls/tls_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,17 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb
if (opts.drain_output)
src->FetchOutputBuf();
else {
auto buf_result = src->PeekOutputBuf();
CHECK(buf_result);
VLOG(1) << opts.name << " wrote " << buf_result->size() << " bytes";
CHECK(!buf_result->empty());
auto buffer = src->PeekOutputBuf();
VLOG(1) << opts.name << " wrote " << buffer.size() << " bytes";
CHECK(!buffer.empty());

if (opts.mutate_indx) {
uint8_t* mem = const_cast<uint8_t*>(buf_result->data());
mem[opts.mutate_indx % buf_result->size()] = opts.mutate_val;
uint8_t* mem = const_cast<uint8_t*>(buffer.data());
mem[opts.mutate_indx % buffer.size()] = opts.mutate_val;
opts.mutate_indx = 0;
}

auto write_result = dest->WriteBuf(*buf_result);
auto write_result = dest->WriteBuf(buffer);
if (!write_result) {
return write_result.error();
}
Expand Down
98 changes: 68 additions & 30 deletions util/tls/tls_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ auto TlsSocket::Accept() -> AcceptResult {
return make_unexpected(make_error_code(errc::connection_reset));
}
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand All @@ -145,19 +145,53 @@ auto TlsSocket::Accept() -> AcceptResult {
return nullptr;
}

auto TlsSocket::Connect(const endpoint_type& endpoint) -> error_code {
error_code TlsSocket::Connect(const endpoint_type& endpoint,
std::function<void(int)> on_pre_connect) {
DCHECK(engine_);
auto io_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!io_result.has_value()) {
return std::error_code(io_result.error(), std::system_category());
Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}

// If the socket is already open, we should not call connect on it
if (IsOpen()) {
return {};
if (!IsOpen()) {
error_code ec = next_sock_->Connect(endpoint, std::move(on_pre_connect));
if (ec)
return ec;
}

// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
int op_val = *op_result;
error_code ec;

// it should guide us to write and then read.
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
while (op_val < 0) {
if (op_val == Engine::EOF_STREAM) {
return make_error_code(errc::connection_reset);
}

if (op_val == Engine::NEED_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;

ec = HandleSocketRead();
if (ec)
return ec;
}
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}
op_val = *op_result;
}

return next_sock_->Connect(endpoint);
return ec;
}

auto TlsSocket::Close() -> error_code {
Expand Down Expand Up @@ -249,7 +283,7 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand Down Expand Up @@ -341,7 +375,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand Down Expand Up @@ -381,28 +415,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
return error_code{};
}

auto buf_result = engine_->PeekOutputBuf();
CHECK(buf_result);

if (!buf_result->empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(*buf_result);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
return write_result.error();
}
CHECK_GT(*write_result, 0u);
engine_->ConsumeOutputBuf(*write_result);
}

return error_code{};
return HandleSocketWrite();
}

auto TlsSocket::HandleRead() -> error_code {
auto TlsSocket::HandleSocketRead() -> error_code {
if (state_ & READ_IN_PROGRESS) {
// We need to Yield because otherwise we might end up in an infinite loop.
// See also comments in MaybeSendOutput.
Expand All @@ -423,6 +439,28 @@ auto TlsSocket::HandleRead() -> error_code {
return error_code{};
}

error_code TlsSocket::HandleSocketWrite() {
Engine::Buffer buffer = engine_->PeekOutputBuf();

while (!buffer.empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
return write_result.error();
}
CHECK_GT(*write_result, 0u);
engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);
}

return error_code{};
}

TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
return next_sock_->LocalEndpoint();
}
Expand Down
6 changes: 4 additions & 2 deletions util/tls/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TlsSocket final : public FiberSocketBase {

// The endpoint should not really pass here, it is to keep
// the interface with FiberSocketBase.
error_code Connect(const endpoint_type&) final;
error_code Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect = {}) final;

error_code Close() final;

Expand Down Expand Up @@ -92,7 +92,9 @@ class TlsSocket final : public FiberSocketBase {
error_code MaybeSendOutput();

/// Read encrypted data from the network socket and feed it into the TLS engine.
error_code HandleRead();
error_code HandleSocketRead();

error_code HandleSocketWrite();

std::unique_ptr<FiberSocketBase> next_sock_;
std::unique_ptr<Engine> engine_;
Expand Down
Loading