Skip to content

Commit ed4aeb2

Browse files
committed
chore: revisit tls_socket code
Signed-off-by: Roman Gershman <romange@gmail.com>
1 parent 1fe964d commit ed4aeb2

File tree

4 files changed

+94
-103
lines changed

4 files changed

+94
-103
lines changed

util/tls/tls_engine.cc

+20-28
Original file line numberDiff line numberDiff line change
@@ -41,36 +41,28 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat
4141
return nonstd::make_unexpected(error);
4242
}
4343

44-
int want = SSL_want(ssl);
45-
46-
if (want == SSL_NOTHING) {
47-
int ssl_error = SSL_get_error(ssl, result);
48-
int io_err = errno;
49-
50-
switch (ssl_error) {
51-
case SSL_ERROR_ZERO_RETURN:
52-
break;
53-
case SSL_ERROR_SYSCALL:
54-
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
55-
break;
56-
case SSL_ERROR_SSL:
57-
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
58-
break;
59-
default:
60-
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
61-
break;
62-
}
63-
64-
return Engine::EOF_STREAM;
44+
int ssl_error = SSL_get_error(ssl, result);
45+
int io_err = errno;
46+
47+
switch (ssl_error) {
48+
case SSL_ERROR_ZERO_RETURN:
49+
break;
50+
case SSL_ERROR_WANT_READ:
51+
return Engine::NEED_READ_AND_MAYBE_WRITE;
52+
case SSL_ERROR_WANT_WRITE:
53+
VLOG(1) << "SSL_ERROR_WANT_WRITE " << location;
54+
return Engine::NEED_WRITE;
55+
case SSL_ERROR_SYSCALL:
56+
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
57+
break;
58+
case SSL_ERROR_SSL:
59+
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
60+
break;
61+
default:
62+
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
63+
break;
6564
}
6665

67-
if (SSL_WRITING == want)
68-
return Engine::NEED_WRITE;
69-
if (SSL_READING == want)
70-
return Engine::NEED_READ_AND_MAYBE_WRITE;
71-
72-
LOG(ERROR) << "Unsupported want value " << want << ", ssl_error: " << SSL_get_error(ssl, result);
73-
7466
return Engine::EOF_STREAM;
7567
}
7668

util/tls/tls_engine.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ class Engine {
1717
enum HandshakeType { CLIENT = 1, SERVER = 2 };
1818
enum OpCode {
1919
EOF_STREAM = -1,
20+
21+
// We use BIO buffers, therefore any SSL operation can end up writing to the internal BIO
22+
// and result in success, even though the data has not been flushed to the underlying socket.
23+
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
24+
// As a result, we must flush the output buffer (if OutputPending() > 0) before we do any
25+
// Socket reads. We could flush after each SSL operation but that would result in fragmented
26+
// Socket writes which we want to avoid.
2027
NEED_READ_AND_MAYBE_WRITE = -2,
2128
NEED_WRITE = -3,
2229
};
@@ -89,7 +96,7 @@ class Engine {
8996
void CommitInput(unsigned sz);
9097

9198
// Returns size of pending data that needs to be flushed out from SSL to I/O.
92-
// See https://www.openssl.org/docs/man1.1.0/man3/BIO_new_bio_pair.html
99+
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
93100
// Specifically, warning that says: "An application must not rely on the error value of
94101
// SSL_operation() but must assure that the write buffer is always flushed first".
95102
size_t OutputPending() const {

util/tls/tls_socket.cc

+65-74
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ auto TlsSocket::Shutdown(int how) -> error_code {
9191
Engine::OpResult op_result = engine_->Shutdown();
9292
if (op_result) {
9393
// engine_ could send notification messages to the peer.
94-
MaybeSendOutput();
94+
std::ignore = MaybeSendOutput();
9595
}
9696

9797
// In any case we should also shutdown the underlying TCP socket without relying on the
@@ -132,14 +132,10 @@ auto TlsSocket::Accept() -> AcceptResult {
132132
if (op_val >= 0) { // Shutdown or empty read/write may return 0.
133133
break;
134134
}
135-
if (op_val == Engine::EOF_STREAM) {
136-
return make_unexpected(make_error_code(errc::connection_reset));
137-
}
138-
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
139-
ec = HandleSocketRead();
140-
if (ec)
141-
return make_unexpected(ec);
142-
}
135+
136+
ec = HandleOp(op_val);
137+
if (ec)
138+
return make_unexpected(ec);
143139
}
144140

145141
return nullptr;
@@ -162,36 +158,26 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint,
162158

163159
// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
164160
int op_val = *op_result;
165-
error_code ec;
166161

167162
// it should guide us to write and then read.
168163
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
169164
while (op_val < 0) {
170-
if (op_val == Engine::EOF_STREAM) {
171-
return make_error_code(errc::connection_reset);
172-
}
165+
error_code ec = HandleOp(op_val);
166+
if (ec)
167+
return ec;
173168

174-
if (op_val == Engine::NEED_WRITE) {
175-
ec = HandleSocketWrite();
176-
if (ec)
177-
return ec;
178-
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
179-
ec = HandleSocketWrite();
180-
if (ec)
181-
return ec;
182-
183-
ec = HandleSocketRead();
184-
if (ec)
185-
return ec;
186-
}
187169
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
188170
if (!op_result) {
189171
return std::error_code(op_result.error(), std::system_category());
190172
}
191173
op_val = *op_result;
192174
}
193175

194-
return ec;
176+
const auto* cipher = SSL_get_current_cipher(engine_->native_handle());
177+
VLOG(1) << "SSL handshake success, chosen " << SSL_CIPHER_get_name(cipher) << "/"
178+
<< SSL_CIPHER_get_version(cipher);
179+
180+
return {};
195181
}
196182

197183
auto TlsSocket::Close() -> error_code {
@@ -245,11 +231,6 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
245231
return make_unexpected(SSL2Error(op_result.error()));
246232
}
247233

248-
error_code ec = MaybeSendOutput();
249-
if (ec) {
250-
return make_unexpected(ec);
251-
}
252-
253234
int op_val = *op_result;
254235
if (spin_count.Check(op_val <= 0)) {
255236
// Once every 30 seconds.
@@ -267,26 +248,18 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
267248
++io;
268249
--io_len;
269250
if (io_len == 0)
270-
break;
251+
break; // Finished reading everything.
271252
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(io->iov_base), io->iov_len};
272253
}
273-
continue; // We read everything we asked for - lets retry.
254+
// We read everything we asked for but there are still buffers left to fill.
255+
continue;
274256
}
275257
break;
276258
}
277259

278-
if (read_total) // if we read something lets return it before we handle other states.
279-
break;
280-
281-
if (op_val == Engine::EOF_STREAM) {
282-
return make_unexpected(make_error_code(errc::connection_reset));
283-
}
284-
285-
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
286-
ec = HandleSocketRead();
287-
if (ec)
288-
return make_unexpected(ec);
289-
}
260+
error_code ec = HandleOp(op_val);
261+
if (ec)
262+
return make_unexpected(ec);
290263
}
291264
return read_total;
292265
}
@@ -307,12 +280,12 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
307280
// Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16.
308281
// IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes.
309282
constexpr size_t kBufferSize = 1392;
310-
io::Result<size_t> ec;
283+
io::Result<size_t> res;
311284
size_t total_sent = 0;
312285

313286
while (len) {
314287
if (ptr->iov_len > kBufferSize || len == 1) {
315-
ec = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
288+
res = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
316289
ptr++;
317290
len--;
318291
} else {
@@ -324,18 +297,18 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
324297
ptr++;
325298
len--;
326299
}
327-
ec = SendBuffer({scratch, buffered_size});
300+
res = SendBuffer({scratch, buffered_size});
328301
}
329-
if (!ec.has_value()) {
330-
return ec;
331-
} else {
332-
total_sent += ec.value();
302+
if (!res) {
303+
return res;
333304
}
305+
total_sent += *res;
334306
}
335307
return total_sent;
336308
}
337309

338310
io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
311+
// Sending buffer into ssl.
339312
DCHECK(engine_);
340313
DCHECK_GT(buf.size(), 0u);
341314

@@ -348,17 +321,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
348321
return make_unexpected(SSL2Error(op_result.error()));
349322
}
350323

351-
error_code ec = MaybeSendOutput();
352-
if (ec) {
353-
return make_unexpected(ec);
354-
}
355-
356324
int op_val = *op_result;
357-
if (spin_count.Check(op_val <= 0)) {
358-
// Once every 30 seconds.
359-
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
360-
<< " Spins: " << spin_count.Spins();
361-
}
362325

363326
if (op_val > 0) {
364327
send_total += op_val;
@@ -370,15 +333,15 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
370333
}
371334
}
372335

373-
if (op_val == Engine::EOF_STREAM) {
374-
return make_unexpected(make_error_code(errc::connection_reset));
336+
if (spin_count.Check(op_val <= 0)) {
337+
// Once every 30 seconds.
338+
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
339+
<< " Spins: " << spin_count.Spins();
375340
}
376341

377-
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
378-
ec = HandleSocketRead();
379-
if (ec)
380-
return make_unexpected(ec);
381-
}
342+
error_code ec = HandleOp(op_val);
343+
if (ec)
344+
return make_unexpected(ec);
382345
}
383346

384347
return send_total;
@@ -395,6 +358,9 @@ SSL* TlsSocket::ssl_handle() {
395358
}
396359

397360
auto TlsSocket::MaybeSendOutput() -> error_code {
361+
if (engine_->OutputPending() == 0)
362+
return {};
363+
398364
// This function is present in both read and write paths.
399365
// meaning that both of them can be called concurrently from different fibers and then
400366
// race over flushing the output buffer. We use state_ to prevent that.
@@ -419,6 +385,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
419385
}
420386

421387
auto TlsSocket::HandleSocketRead() -> error_code {
388+
error_code ec = MaybeSendOutput();
389+
if (ec)
390+
return ec;
391+
422392
if (state_ & READ_IN_PROGRESS) {
423393
// We need to Yield because otherwise we might end up in an infinite loop.
424394
// See also comments in MaybeSendOutput.
@@ -434,33 +404,54 @@ auto TlsSocket::HandleSocketRead() -> error_code {
434404
return esz.error();
435405
}
436406

407+
DVLOG(1) << "TlsSocket:Read " << *esz << " bytes";
408+
437409
engine_->CommitInput(*esz);
438410

439411
return error_code{};
440412
}
441413

442414
error_code TlsSocket::HandleSocketWrite() {
443415
Engine::Buffer buffer = engine_->PeekOutputBuf();
416+
DCHECK(!buffer.empty());
417+
418+
if (buffer.empty())
419+
return {};
444420

421+
// we do not allow concurrent writes from multiple fibers.
422+
state_ |= WRITE_IN_PROGRESS;
445423
while (!buffer.empty()) {
446-
// we do not allow concurrent writes from multiple fibers.
447-
state_ |= WRITE_IN_PROGRESS;
448424
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);
449425

450-
// Safe to clear here since the code below is atomic fiber-wise.
451-
state_ &= ~WRITE_IN_PROGRESS;
452426
DCHECK(engine_);
453427
if (!write_result) {
428+
state_ &= ~WRITE_IN_PROGRESS;
429+
454430
return write_result.error();
455431
}
456432
CHECK_GT(*write_result, 0u);
457433
engine_->ConsumeOutputBuf(*write_result);
458434
buffer.remove_prefix(*write_result);
459435
}
436+
DCHECK_EQ(engine_->OutputPending(), 0u);
437+
438+
state_ &= ~WRITE_IN_PROGRESS;
460439

461440
return error_code{};
462441
}
463442

443+
error_code TlsSocket::HandleOp(int op_val) {
444+
switch (op_val) {
445+
case Engine::EOF_STREAM:
446+
return make_error_code(errc::connection_reset);
447+
case Engine::NEED_READ_AND_MAYBE_WRITE:
448+
return HandleSocketRead();
449+
default:
450+
LOG(DFATAL) << "Unsupported " << op_val;
451+
}
452+
return {};
453+
}
454+
464455
TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
465456
return next_sock_->LocalEndpoint();
466457
}

util/tls/tls_socket.h

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class TlsSocket final : public FiberSocketBase {
9595
error_code HandleSocketRead();
9696

9797
error_code HandleSocketWrite();
98+
error_code HandleOp(int op);
9899

99100
std::unique_ptr<FiberSocketBase> next_sock_;
100101
std::unique_ptr<Engine> engine_;

0 commit comments

Comments
 (0)