@@ -136,7 +136,7 @@ auto TlsSocket::Accept() -> AcceptResult {
136
136
return make_unexpected (make_error_code (errc::connection_reset));
137
137
}
138
138
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
139
- ec = HandleRead ();
139
+ ec = HandleSocketRead ();
140
140
if (ec)
141
141
return make_unexpected (ec);
142
142
}
@@ -145,19 +145,53 @@ auto TlsSocket::Accept() -> AcceptResult {
145
145
return nullptr ;
146
146
}
147
147
148
- auto TlsSocket::Connect (const endpoint_type& endpoint) -> error_code {
148
+ error_code TlsSocket::Connect (const endpoint_type& endpoint,
149
+ std::function<void (int )> on_pre_connect) {
149
150
DCHECK (engine_);
150
- auto io_result = engine_->Handshake (Engine::HandshakeType::CLIENT);
151
- if (!io_result. has_value () ) {
152
- return std::error_code (io_result .error (), std::system_category ());
151
+ Engine::OpResult op_result = engine_->Handshake (Engine::HandshakeType::CLIENT);
152
+ if (!op_result ) {
153
+ return std::error_code (op_result .error (), std::system_category ());
153
154
}
154
155
155
156
// If the socket is already open, we should not call connect on it
156
- if (IsOpen ()) {
157
- return {};
157
+ if (!IsOpen ()) {
158
+ error_code ec = next_sock_->Connect (endpoint, std::move (on_pre_connect));
159
+ if (ec)
160
+ return ec;
161
+ }
162
+
163
+ // Flush the ssl data to the socket and run the loop that ensures handshaking converges.
164
+ int op_val = *op_result;
165
+ error_code ec;
166
+
167
+ // it should guide us to write and then read.
168
+ DCHECK_EQ (op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
169
+ while (op_val < 0 ) {
170
+ if (op_val == Engine::EOF_STREAM) {
171
+ return make_error_code (errc::connection_reset);
172
+ }
173
+
174
+ if (op_val == Engine::NEED_WRITE) {
175
+ ec = HandleSocketWrite ();
176
+ if (ec)
177
+ return ec;
178
+ } else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
179
+ ec = HandleSocketWrite ();
180
+ if (ec)
181
+ return ec;
182
+
183
+ ec = HandleSocketRead ();
184
+ if (ec)
185
+ return ec;
186
+ }
187
+ op_result = engine_->Handshake (Engine::HandshakeType::CLIENT);
188
+ if (!op_result) {
189
+ return std::error_code (op_result.error (), std::system_category ());
190
+ }
191
+ op_val = *op_result;
158
192
}
159
193
160
- return next_sock_-> Connect (endpoint) ;
194
+ return ec ;
161
195
}
162
196
163
197
auto TlsSocket::Close () -> error_code {
@@ -249,7 +283,7 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
249
283
}
250
284
251
285
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
252
- ec = HandleRead ();
286
+ ec = HandleSocketRead ();
253
287
if (ec)
254
288
return make_unexpected (ec);
255
289
}
@@ -341,7 +375,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
341
375
}
342
376
343
377
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
344
- ec = HandleRead ();
378
+ ec = HandleSocketRead ();
345
379
if (ec)
346
380
return make_unexpected (ec);
347
381
}
@@ -381,28 +415,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
381
415
return error_code{};
382
416
}
383
417
384
- auto buf_result = engine_->PeekOutputBuf ();
385
- CHECK (buf_result);
386
-
387
- if (!buf_result->empty ()) {
388
- // we do not allow concurrent writes from multiple fibers.
389
- state_ |= WRITE_IN_PROGRESS;
390
- io::Result<size_t > write_result = next_sock_->WriteSome (*buf_result);
391
-
392
- // Safe to clear here since the code below is atomic fiber-wise.
393
- state_ &= ~WRITE_IN_PROGRESS;
394
- DCHECK (engine_);
395
- if (!write_result) {
396
- return write_result.error ();
397
- }
398
- CHECK_GT (*write_result, 0u );
399
- engine_->ConsumeOutputBuf (*write_result);
400
- }
401
-
402
- return error_code{};
418
+ return HandleSocketWrite ();
403
419
}
404
420
405
- auto TlsSocket::HandleRead () -> error_code {
421
+ auto TlsSocket::HandleSocketRead () -> error_code {
406
422
if (state_ & READ_IN_PROGRESS) {
407
423
// We need to Yield because otherwise we might end up in an infinite loop.
408
424
// See also comments in MaybeSendOutput.
@@ -423,6 +439,28 @@ auto TlsSocket::HandleRead() -> error_code {
423
439
return error_code{};
424
440
}
425
441
442
+ error_code TlsSocket::HandleSocketWrite () {
443
+ Engine::Buffer buffer = engine_->PeekOutputBuf ();
444
+
445
+ while (!buffer.empty ()) {
446
+ // we do not allow concurrent writes from multiple fibers.
447
+ state_ |= WRITE_IN_PROGRESS;
448
+ io::Result<size_t > write_result = next_sock_->WriteSome (buffer);
449
+
450
+ // Safe to clear here since the code below is atomic fiber-wise.
451
+ state_ &= ~WRITE_IN_PROGRESS;
452
+ DCHECK (engine_);
453
+ if (!write_result) {
454
+ return write_result.error ();
455
+ }
456
+ CHECK_GT (*write_result, 0u );
457
+ engine_->ConsumeOutputBuf (*write_result);
458
+ buffer.remove_prefix (*write_result);
459
+ }
460
+
461
+ return error_code{};
462
+ }
463
+
426
464
TlsSocket::endpoint_type TlsSocket::LocalEndpoint () const {
427
465
return next_sock_->LocalEndpoint ();
428
466
}
0 commit comments