-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathfiber_socket_base.h
212 lines (151 loc) · 6.23 KB
/
fiber_socket_base.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// Copyright 2023, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/base/attributes.h>
// for tcp::endpoint. Consider introducing our own.
#include <boost/asio/ip/tcp.hpp>
#include <functional>
#include "io/io.h"
namespace util {
namespace fb2 {
// Forward declaration only: sockets in this header hold ProactorBase through a pointer,
// so the full definition is not required here.
class ProactorBase;
} // namespace fb2
//! Abstract base for fiber-aware sockets.
//!
//! Exposes the io::Sink, io::AsyncSink and io::Source interfaces together with a set of
//! pure-virtual socket primitives (connect / accept / listen / recv / ...) that concrete socket
//! types implement. Each socket references the ProactorBase it was constructed with
//! (see proactor()). Instances are neither copyable nor movable.
class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source {
 public:
  FiberSocketBase(const FiberSocketBase&) = delete;
  FiberSocketBase& operator=(const FiberSocketBase&) = delete;
  FiberSocketBase(FiberSocketBase&&) = delete;
  FiberSocketBase& operator=(FiberSocketBase&&) = delete;

  // Accept() hands sockets out as FiberSocketBase*, so destroying one through a base pointer
  // must reach the concrete destructor regardless of how the io:: bases declare theirs.
  virtual ~FiberSocketBase() = default;

 protected:
  explicit FiberSocketBase(fb2::ProactorBase* pb) : proactor_(pb) {
  }

 public:
  using endpoint_type = ::boost::asio::ip::tcp::endpoint;
  using error_code = std::error_code;
  using AcceptResult = ::io::Result<FiberSocketBase*>;
  using io::AsyncSink::AsyncProgressCb;
  using ProactorBase = fb2::ProactorBase;

  // NOTE: ABSL_MUST_USE_RESULT may expand to [[nodiscard]] and must therefore be the first
  // token of a declaration; it always precedes `virtual` below.

  //! Shuts the socket down. `how` is forwarded to the implementation
  //! (presumably SHUT_RD / SHUT_WR / SHUT_RDWR — TODO confirm).
  ABSL_MUST_USE_RESULT virtual error_code Shutdown(int how) = 0;

  //! Accepts a connection; on success the result holds a socket for the peer.
  ABSL_MUST_USE_RESULT virtual AcceptResult Accept() = 0;

  //! Connects to `ep`. `on_pre_connect`, if non-empty, is called with an int argument —
  //! presumably the native descriptor, before connecting (TODO confirm in implementations).
  ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep,
                                                  std::function<void(int)> on_pre_connect = {}) = 0;

  ABSL_MUST_USE_RESULT virtual error_code Close() = 0;
  virtual bool IsOpen() const = 0;

  //! Receives into the buffers described by `msg`; `flags` are passed to the implementation.
  virtual ::io::Result<size_t> RecvMsg(const msghdr& msg, int flags) = 0;

  //! Receives into an array of `len` iovec buffers (defined out of line).
  ::io::Result<size_t> Recv(const iovec* ptr, size_t len);

  //! io::Source entry point.
  //! A single buffer is forwarded to the MutableBytes overload; zero or several buffers go
  //! through the iovec overload. Testing `len == 1` (instead of `len > 1`) ensures `v` is
  //! never dereferenced when the caller passes an empty buffer array.
  ::io::Result<size_t> ReadSome(const iovec* v, uint32_t len) final {
    if (len == 1) {
      return Recv(io::MutableBytes{reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len}, 0);
    }
    return Recv(v, len);
  }

  virtual ::io::Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) = 0;

  //! Returns true if `ec` is connection_aborted or connection_reset.
  static bool IsConnClosed(const error_code& ec) {
    return (ec == std::errc::connection_aborted) || (ec == std::errc::connection_reset);
  }

  virtual void SetProactor(ProactorBase* p);

  ProactorBase* proactor() {
    return proactor_;
  }

  const ProactorBase* proactor() const {
    return proactor_;
  }

  //! Timeout in milliseconds; UINT32_MAX disables the timeout.
  virtual void set_timeout(uint32_t msec) = 0;
  virtual uint32_t timeout() const = 0;

  using AsyncSink::AsyncWrite;
  using AsyncSink::AsyncWriteSome;

  virtual endpoint_type LocalEndpoint() const = 0;
  virtual endpoint_type RemoteEndpoint() const = 0;

  //! Registers a callback that will be called if the socket is closed or has an error.
  //! Should not be called if a callback is already registered.
  virtual void RegisterOnErrorCb(std::function<void(uint32_t)> cb) = 0;

  //! Cancels a callback that was registered with RegisterOnErrorCb. Must be reentrant.
  virtual void CancelOnErrorCb() = 0;

  //! True for Unix domain sockets.
  virtual bool IsUDS() const = 0;

  using native_handle_type = int;
  virtual native_handle_type native_handle() const = 0;

  //! Creates a socket. By default with AF_INET family (2).
  virtual error_code Create(unsigned short protocol_family = 2) = 0;

  ABSL_MUST_USE_RESULT virtual error_code Bind(const struct sockaddr* bind_addr,
                                               unsigned addr_len) = 0;

  ABSL_MUST_USE_RESULT virtual error_code Listen(unsigned backlog) = 0;

  //! Listens on all interfaces. If port is 0 then a random available port is chosen
  //! by the OS.
  ABSL_MUST_USE_RESULT virtual error_code Listen(uint16_t port, unsigned backlog) = 0;

  //! Listens on a UDS socket. Must be created with Create(AF_UNIX) first.
  ABSL_MUST_USE_RESULT virtual error_code ListenUDS(const char* path, mode_t permissions,
                                                    unsigned backlog) = 0;

 protected:
  //! Customization hooks for derived classes; no-ops by default.
  //! NOTE(review): presumably invoked from SetProactor() — confirm in the implementation.
  virtual void OnSetProactor() {
  }

  virtual void OnResetProactor() {
  }

 private:
  // We must reference proactor in each socket so that we could support write_some/read_some
  // with predefined interface and be compliant with SyncWriteStream/SyncReadStream concepts.
  ProactorBase* proactor_;
};
//! Common base for sockets that own a Linux file descriptor.
//!
//! The descriptor is stored in fd_ shifted left by kFdShift; the low kFdShift bits carry state
//! flags (IS_SHUTDOWN, IS_UDS). A negative fd_ (e.g. -1 after Detach()) means "no descriptor".
class LinuxSocketBase : public FiberSocketBase {
 public:
  using FiberSocketBase::native_handle_type;

  virtual ~LinuxSocketBase();

  native_handle_type native_handle() const override {
    // ShiftedFd() relies on an arithmetic right shift so that -1 maps back to -1.
    static_assert(int32_t(-1) >> kFdShift == -1);
    return ShiftedFd();
  }

  //! Creates a socket. By default with AF_INET family (2).
  error_code Create(unsigned short protocol_family = 2) override;

  ABSL_MUST_USE_RESULT error_code Bind(const struct sockaddr* bind_addr,
                                       unsigned addr_len) override;

  ABSL_MUST_USE_RESULT error_code Listen(unsigned backlog) override;

  //! Listens on all interfaces. If port is 0 then a random available port is chosen
  //! by the OS.
  ABSL_MUST_USE_RESULT error_code Listen(uint16_t port, unsigned backlog) override;

  //! Listens on a UDS socket. Must be created with Create(AF_UNIX) first.
  ABSL_MUST_USE_RESULT error_code ListenUDS(const char* path, mode_t permissions,
                                            unsigned backlog) override;

  error_code Shutdown(int how) override;

  //! Timeout in milliseconds; UINT32_MAX disables the timeout.
  void set_timeout(uint32_t msec) final {
    timeout_ = msec;
  }

  uint32_t timeout() const final {
    return timeout_;
  }

  //! Removes the ownership over file descriptor. Use with caution.
  void Detach() {
    fd_ = -1;
  }

  //! IsOpen does not promise that the socket is TCP connected or live,
  //! just that the file descriptor is valid and its state is open.
  bool IsOpen() const final {
    // -1 has every bit set, so an absent descriptor also reads as shut down here.
    return (fd_ & IS_SHUTDOWN) == 0;
  }

  endpoint_type LocalEndpoint() const override;
  endpoint_type RemoteEndpoint() const override;

  //! True if this socket holds a descriptor whose IS_UDS flag is set.
  bool IsUDS() const override {
    // Check fd_ >= 0 first: -1 (no descriptor) has every bit set, including IS_UDS, and would
    // otherwise report a detached or never-opened socket as a Unix domain socket.
    return fd_ >= 0 && (fd_ & IS_UDS) != 0;
  }

 protected:
  static constexpr unsigned kFdShift = 4;

  //! Stores `fd` in the shifted representation; negative values ("no descriptor") are kept
  //! as-is so that -1 still reads back as -1 from ShiftedFd().
  LinuxSocketBase(int fd, ProactorBase* pb)
      : FiberSocketBase(pb), fd_(fd >= 0 ? fd << kFdShift : fd) {
  }

  //! Returns the raw descriptor with the state bits shifted out.
  int ShiftedFd() const {
    return fd_ >> kFdShift;
  }

  enum {
    IS_SHUTDOWN = 0x1,
    IS_UDS = 0x2,
  };

  // Flags which are passed on to peers produced by Accept().
  // constexpr makes this member implicitly inline (C++17), so ODR-uses need no separate
  // out-of-line definition.
  static constexpr int32_t kInheritedFlags = IS_UDS;

  // Descriptor shifted left by kFdShift; the low kFdShift bits hold the state flags above.
  // With 4 flag bits in a signed 32-bit value, descriptors must stay below 2^27 (~134M) for
  // the shifted value to remain non-negative.
  int32_t fd_;

 private:
  uint32_t timeout_ = UINT32_MAX;
};
// Descriptor helpers, declared here and defined elsewhere. Going by their names they
// presumably set O_NONBLOCK and FD_CLOEXEC on `fd` respectively —
// NOTE(review): confirm against the implementation (including error handling).
void SetNonBlocking(int fd);
void SetCloexec(int fd);
} // namespace util