@@ -91,7 +91,7 @@ auto TlsSocket::Shutdown(int how) -> error_code {
91
91
Engine::OpResult op_result = engine_->Shutdown ();
92
92
if (op_result) {
93
93
// engine_ could send notification messages to the peer.
94
- MaybeSendOutput ();
94
+ std::ignore = MaybeSendOutput ();
95
95
}
96
96
97
97
// In any case we should also shutdown the underlying TCP socket without relying on the
@@ -132,14 +132,10 @@ auto TlsSocket::Accept() -> AcceptResult {
132
132
if (op_val >= 0 ) { // Shutdown or empty read/write may return 0.
133
133
break ;
134
134
}
135
- if (op_val == Engine::EOF_STREAM) {
136
- return make_unexpected (make_error_code (errc::connection_reset));
137
- }
138
- if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
139
- ec = HandleSocketRead ();
140
- if (ec)
141
- return make_unexpected (ec);
142
- }
135
+
136
+ ec = HandleOp (op_val);
137
+ if (ec)
138
+ return make_unexpected (ec);
143
139
}
144
140
145
141
return nullptr ;
@@ -162,36 +158,26 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint,
162
158
163
159
// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
164
160
int op_val = *op_result;
165
- error_code ec;
166
161
167
162
// it should guide us to write and then read.
168
163
DCHECK_EQ (op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
169
164
while (op_val < 0 ) {
170
- if (op_val == Engine::EOF_STREAM) {
171
- return make_error_code (errc::connection_reset);
172
- }
165
+ error_code ec = HandleOp (op_val);
166
+ if (ec)
167
+ return ec;
173
168
174
- if (op_val == Engine::NEED_WRITE) {
175
- ec = HandleSocketWrite ();
176
- if (ec)
177
- return ec;
178
- } else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
179
- ec = HandleSocketWrite ();
180
- if (ec)
181
- return ec;
182
-
183
- ec = HandleSocketRead ();
184
- if (ec)
185
- return ec;
186
- }
187
169
op_result = engine_->Handshake (Engine::HandshakeType::CLIENT);
188
170
if (!op_result) {
189
171
return std::error_code (op_result.error (), std::system_category ());
190
172
}
191
173
op_val = *op_result;
192
174
}
193
175
194
- return ec;
176
+ const auto * cipher = SSL_get_current_cipher (engine_->native_handle ());
177
+ VLOG (1 ) << " SSL handshake success, chosen " << SSL_CIPHER_get_name (cipher) << " /"
178
+ << SSL_CIPHER_get_version (cipher);
179
+
180
+ return {};
195
181
}
196
182
197
183
auto TlsSocket::Close () -> error_code {
@@ -245,11 +231,6 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
245
231
return make_unexpected (SSL2Error (op_result.error ()));
246
232
}
247
233
248
- error_code ec = MaybeSendOutput ();
249
- if (ec) {
250
- return make_unexpected (ec);
251
- }
252
-
253
234
int op_val = *op_result;
254
235
if (spin_count.Check (op_val <= 0 )) {
255
236
// Once every 30 seconds.
@@ -267,26 +248,18 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
267
248
++io;
268
249
--io_len;
269
250
if (io_len == 0 )
270
- break ;
251
+ break ; // Finished reading everything.
271
252
dest = Engine::MutableBuffer{reinterpret_cast <uint8_t *>(io->iov_base ), io->iov_len };
272
253
}
273
- continue ; // We read everything we asked for - lets retry.
254
+ // We read everything we asked for but there are still buffers left to fill.
255
+ continue ;
274
256
}
275
257
break ;
276
258
}
277
259
278
- if (read_total) // if we read something lets return it before we handle other states.
279
- break ;
280
-
281
- if (op_val == Engine::EOF_STREAM) {
282
- return make_unexpected (make_error_code (errc::connection_reset));
283
- }
284
-
285
- if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
286
- ec = HandleSocketRead ();
287
- if (ec)
288
- return make_unexpected (ec);
289
- }
260
+ error_code ec = HandleOp (op_val);
261
+ if (ec)
262
+ return make_unexpected (ec);
290
263
}
291
264
return read_total;
292
265
}
@@ -307,12 +280,12 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
307
280
// Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16.
308
281
// IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes.
309
282
constexpr size_t kBufferSize = 1392 ;
310
- io::Result<size_t > ec ;
283
+ io::Result<size_t > res ;
311
284
size_t total_sent = 0 ;
312
285
313
286
while (len) {
314
287
if (ptr->iov_len > kBufferSize || len == 1 ) {
315
- ec = SendBuffer (Engine::Buffer{reinterpret_cast <uint8_t *>(ptr->iov_base ), ptr->iov_len });
288
+ res = SendBuffer (Engine::Buffer{reinterpret_cast <uint8_t *>(ptr->iov_base ), ptr->iov_len });
316
289
ptr++;
317
290
len--;
318
291
} else {
@@ -324,18 +297,18 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
324
297
ptr++;
325
298
len--;
326
299
}
327
- ec = SendBuffer ({scratch, buffered_size});
300
+ res = SendBuffer ({scratch, buffered_size});
328
301
}
329
- if (!ec.has_value ()) {
330
- return ec;
331
- } else {
332
- total_sent += ec.value ();
302
+ if (!res) {
303
+ return res;
333
304
}
305
+ total_sent += *res;
334
306
}
335
307
return total_sent;
336
308
}
337
309
338
310
io::Result<size_t > TlsSocket::SendBuffer (Engine::Buffer buf) {
311
+ // Sending buffer into ssl.
339
312
DCHECK (engine_);
340
313
DCHECK_GT (buf.size (), 0u );
341
314
@@ -348,17 +321,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
348
321
return make_unexpected (SSL2Error (op_result.error ()));
349
322
}
350
323
351
- error_code ec = MaybeSendOutput ();
352
- if (ec) {
353
- return make_unexpected (ec);
354
- }
355
-
356
324
int op_val = *op_result;
357
- if (spin_count.Check (op_val <= 0 )) {
358
- // Once every 30 seconds.
359
- LOG_EVERY_T (WARNING, 30 ) << " IO loop spin limit reached. Limit: " << spin_count.Limit ()
360
- << " Spins: " << spin_count.Spins ();
361
- }
362
325
363
326
if (op_val > 0 ) {
364
327
send_total += op_val;
@@ -370,15 +333,15 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
370
333
}
371
334
}
372
335
373
- if (op_val == Engine::EOF_STREAM) {
374
- return make_unexpected (make_error_code (errc::connection_reset));
336
+ if (spin_count.Check (op_val <= 0 )) {
337
+ // Once every 30 seconds.
338
+ LOG_EVERY_T (WARNING, 30 ) << " IO loop spin limit reached. Limit: " << spin_count.Limit ()
339
+ << " Spins: " << spin_count.Spins ();
375
340
}
376
341
377
- if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
378
- ec = HandleSocketRead ();
379
- if (ec)
380
- return make_unexpected (ec);
381
- }
342
+ error_code ec = HandleOp (op_val);
343
+ if (ec)
344
+ return make_unexpected (ec);
382
345
}
383
346
384
347
return send_total;
@@ -395,6 +358,9 @@ SSL* TlsSocket::ssl_handle() {
395
358
}
396
359
397
360
auto TlsSocket::MaybeSendOutput () -> error_code {
361
+ if (engine_->OutputPending () == 0 )
362
+ return {};
363
+
398
364
// This function is present in both read and write paths.
399
365
// meaning that both of them can be called concurrently from differrent fibers and then
400
366
// race over flushing the output buffer. We use state_ to prevent that.
@@ -419,6 +385,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
419
385
}
420
386
421
387
auto TlsSocket::HandleSocketRead () -> error_code {
388
+ error_code ec = MaybeSendOutput ();
389
+ if (ec)
390
+ return ec;
391
+
422
392
if (state_ & READ_IN_PROGRESS) {
423
393
// We need to Yield because otherwise we might end up in an infinite loop.
424
394
// See also comments in MaybeSendOutput.
@@ -434,33 +404,54 @@ auto TlsSocket::HandleSocketRead() -> error_code {
434
404
return esz.error ();
435
405
}
436
406
407
+ DVLOG (1 ) << " TlsSocket:Read " << *esz << " bytes" ;
408
+
437
409
engine_->CommitInput (*esz);
438
410
439
411
return error_code{};
440
412
}
441
413
442
414
error_code TlsSocket::HandleSocketWrite () {
443
415
Engine::Buffer buffer = engine_->PeekOutputBuf ();
416
+ DCHECK (!buffer.empty ());
417
+
418
+ if (buffer.empty ())
419
+ return {};
444
420
421
+ // we do not allow concurrent writes from multiple fibers.
422
+ state_ |= WRITE_IN_PROGRESS;
445
423
while (!buffer.empty ()) {
446
- // we do not allow concurrent writes from multiple fibers.
447
- state_ |= WRITE_IN_PROGRESS;
448
424
io::Result<size_t > write_result = next_sock_->WriteSome (buffer);
449
425
450
- // Safe to clear here since the code below is atomic fiber-wise.
451
- state_ &= ~WRITE_IN_PROGRESS;
452
426
DCHECK (engine_);
453
427
if (!write_result) {
428
+ state_ &= ~WRITE_IN_PROGRESS;
429
+
454
430
return write_result.error ();
455
431
}
456
432
CHECK_GT (*write_result, 0u );
457
433
engine_->ConsumeOutputBuf (*write_result);
458
434
buffer.remove_prefix (*write_result);
459
435
}
436
+ DCHECK_EQ (engine_->OutputPending (), 0u );
437
+
438
+ state_ &= ~WRITE_IN_PROGRESS;
460
439
461
440
return error_code{};
462
441
}
463
442
443
+ error_code TlsSocket::HandleOp (int op_val) {
444
+ switch (op_val) {
445
+ case Engine::EOF_STREAM:
446
+ return make_error_code (errc::connection_reset);
447
+ case Engine::NEED_READ_AND_MAYBE_WRITE:
448
+ return HandleSocketRead ();
449
+ default :
450
+ LOG (DFATAL) << " Unsupported " << op_val;
451
+ }
452
+ return {};
453
+ }
454
+
464
455
TlsSocket::endpoint_type TlsSocket::LocalEndpoint () const {
465
456
return next_sock_->LocalEndpoint ();
466
457
}
0 commit comments