Skip to content

Commit

Permalink
fix: use non-blocking sockets. (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored May 23, 2024
1 parent b8f57a0 commit 9a1e284
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 36 deletions.
2 changes: 2 additions & 0 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s

Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize);
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
socketPool->setTurbo(true);

Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed);
Expand Down
73 changes: 43 additions & 30 deletions src/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@
#include <fcntl.h>
#include <ctime>
#include <unistd.h>
#include "socket.hpp"
#include <stdexcept>
#include <vector>
#include "socket.hpp"

#define SOCKET_LAST_ERRCODE errno
#define SOCKET_LAST_ERROR strerror(errno)

#define AUTO_NON_BLOCKING_MODULO 10000
#define AUTO_NON_BLOCKING_TIMEOUT_SECONDS 3

static inline void setNonBlocking(int socket, bool enabled) {
int flags = fcntl(socket, F_GETFL, 0);
if (enabled) {
Expand All @@ -35,6 +32,14 @@ static inline void setNoDelay(int socket) {
throw std::runtime_error("Error setting socket to no-delay");
}

// Turns on the TCP_QUICKACK option for `socket` on platforms that define it;
// where the macro is absent this function compiles to a no-op.
// Throws std::runtime_error if setsockopt() reports a failure.
// NOTE(review): Linux tcp(7) describes TCP_QUICKACK as non-permanent - confirm
// that setting it once right after connect/accept is sufficient here.
static inline void setQuickAck(int socket) {
#ifdef TCP_QUICKACK
    const int quickAck = 1;
    const int status = setsockopt(socket, IPPROTO_TCP, TCP_QUICKACK,
                                  reinterpret_cast<const char*>(&quickAck), sizeof(quickAck));
    if (status < 0) {
        throw std::runtime_error("Error setting quick ack");
    }
#endif
}

static inline void writeSocket(int socket, const void* data, size_t size) {
while (size > 0) {
int s = send(socket, (char*)data, size, 0);
Expand All @@ -51,22 +56,17 @@ static inline void writeSocket(int socket, const void* data, size_t size) {
}
}

static inline void readSocket(bool* isNonBlocking, int socket, void* data, size_t size) {
unsigned int attempt = 0;
time_t startTime;

while (size > 0) {
int r = recv(socket, (char*)data, size, 0);
static inline bool tryReadSocket(int socket, void* data, size_t size, unsigned long maxAttempts) {
// maxAttempts = 0 means infinite attempts
size_t s = size;
while (s > 0) {
int r = recv(socket, data, s, 0);
if (r < 0) {
if (*isNonBlocking && SOCKET_LAST_ERRCODE == EAGAIN) {
attempt++;
if (attempt % AUTO_NON_BLOCKING_MODULO == 0) {
time_t now = time(NULL);
if (attempt == AUTO_NON_BLOCKING_MODULO) {
startTime = now;
} else if (now - startTime > AUTO_NON_BLOCKING_TIMEOUT_SECONDS) {
setNonBlocking(socket, false);
*isNonBlocking = false;
if (SOCKET_LAST_ERRCODE == EAGAIN) {
if (s == size && maxAttempts > 0) {
maxAttempts--;
if (maxAttempts == 0) {
return false;
}
}
continue;
Expand All @@ -76,13 +76,13 @@ static inline void readSocket(bool* isNonBlocking, int socket, void* data, size_
throw ReadSocketException(0, "Socket closed");
}
data = (char*)data + r;
size -= r;

if (!*isNonBlocking) {
setNonBlocking(socket, true);
*isNonBlocking = true;
}
s -= r;
}
return true;
}

// Reads exactly `size` bytes from `socket` into `data`.
// Delegates to tryReadSocket with maxAttempts = 0, which that helper documents
// as "infinite attempts", so a `false` result is not expected here.
// Exceptions raised by tryReadSocket (e.g. ReadSocketException) propagate.
static inline void readSocket(int socket, void* data, size_t size) {
    // The read must happen outside assert(): when NDEBUG is defined the whole
    // assert expression is removed, which would silently skip the read.
    const bool completed = tryReadSocket(socket, data, size, 0);
    assert(completed);
    (void)completed; // avoids an unused-variable warning in NDEBUG builds
}

ReadSocketException::ReadSocketException(int code, const char* message) {
Expand Down Expand Up @@ -116,6 +116,7 @@ SocketPool* SocketPool::connect(unsigned int nSockets, char** hosts, int* ports)
}

setNoDelay(clientSocket);
setQuickAck(clientSocket);
sockets[i] = clientSocket;
}
return new SocketPool(nSockets, sockets);
Expand All @@ -124,7 +125,6 @@ SocketPool* SocketPool::connect(unsigned int nSockets, char** hosts, int* ports)
SocketPool::SocketPool(unsigned int nSockets, int* sockets) {
this->nSockets = nSockets;
this->sockets = sockets;
this->isNonBlocking = new bool[nSockets];
this->sentBytes.exchange(0);
this->recvBytes.exchange(0);
}
Expand All @@ -135,7 +135,12 @@ SocketPool::~SocketPool() {
close(sockets[i]);
}
delete[] sockets;
delete[] isNonBlocking;
}

// Applies the same mode to every socket in the pool by forwarding `enabled`
// to the file-local setNonBlocking() helper, in index order.
void SocketPool::setTurbo(bool enabled) {
    int* const end = sockets + nSockets;
    for (int* fd = sockets; fd != end; ++fd) {
        ::setNonBlocking(*fd, enabled);
    }
}

void SocketPool::write(unsigned int socketIndex, const void* data, size_t size) {
Expand All @@ -147,7 +152,7 @@ void SocketPool::write(unsigned int socketIndex, const void* data, size_t size)
void SocketPool::read(unsigned int socketIndex, void* data, size_t size) {
assert(socketIndex >= 0 && socketIndex < nSockets);
recvBytes += size;
readSocket(&isNonBlocking[socketIndex], sockets[socketIndex], data, size);
readSocket(sockets[socketIndex], data, size);
}

void SocketPool::writeMany(unsigned int n, SocketIo* ios) {
Expand Down Expand Up @@ -224,26 +229,34 @@ Socket SocketServer::accept() {
if (clientSocket < 0)
throw std::runtime_error("Error accepting connection");
setNoDelay(clientSocket);
setQuickAck(clientSocket);
printf("Client connected\n");
return Socket(clientSocket);
}

Socket::Socket(int socket) {
this->socket = socket;
this->isNonBlocking = false;
}

// Shuts the connection down and then closes the descriptor.
// The return values of shutdown() and close() are intentionally not checked,
// matching the previous behaviour.
Socket::~Socket() {
    // `how` argument for shutdown(); 2 is presumably SHUT_RDWR (both directions)
    // on the supported platforms - confirm before replacing with the macro.
    const int how = 2;
    shutdown(this->socket, how);
    close(this->socket);
}

void Socket::setTurbo(bool enabled) {
::setNonBlocking(socket, enabled);
}

void Socket::write(const void* data, size_t size) {
writeSocket(socket, data, size);
}

void Socket::read(void* data, size_t size) {
readSocket(&isNonBlocking, socket, data, size);
readSocket(socket, data, size);
}

bool Socket::tryRead(void* data, size_t size, unsigned long maxAttempts) {
return tryReadSocket(socket, data, size, maxAttempts);
}

std::vector<char> Socket::readHttpRequest() {
Expand Down
5 changes: 3 additions & 2 deletions src/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ struct SocketIo {
class SocketPool {
private:
int* sockets;
bool* isNonBlocking;
std::atomic_uint sentBytes;
std::atomic_uint recvBytes;

Expand All @@ -40,6 +39,7 @@ class SocketPool {
SocketPool(unsigned int nSockets, int* sockets);
~SocketPool();

void setTurbo(bool enabled);
void write(unsigned int socketIndex, const void* data, size_t size);
void read(unsigned int socketIndex, void* data, size_t size);
void writeMany(unsigned int n, SocketIo* ios);
Expand All @@ -50,14 +50,15 @@ class SocketPool {
class Socket {
private:
int socket;
bool isNonBlocking;

public:
Socket(int socket);
~Socket();

void setTurbo(bool enabled);
void write(const void* data, size_t size);
void read(void* data, size_t size);
bool tryRead(void* data, size_t size, unsigned long maxAttempts);
std::vector<char> readHttpRequest();
};

Expand Down
26 changes: 23 additions & 3 deletions src/tasks.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "tasks.hpp"
#include <cassert>
#include <cstring>
#include <cstdio>

TransformerArch::TransformerArch() {
inference.nTasks = 0;
Expand Down Expand Up @@ -175,8 +176,8 @@ void sendPos(TASK_ARGS) {
}
}

void waitForPos(Transformer* transformer, Socket* socket) {
socket->read(&transformer->pos, sizeof(pos_t));
// Tries to receive the next position value from `socket` into
// `transformer->pos` (sizeof(pos_t) bytes) using Socket::tryRead.
// `maxAttempts` is forwarded as-is; its type matches Socket::tryRead
// (unsigned long) so the caller's value is not narrowed on the way through.
// Returns the result of Socket::tryRead.
bool tryWaitForPos(Transformer* transformer, Socket* socket, unsigned long maxAttempts) {
    return socket->tryRead(&transformer->pos, sizeof(pos_t), maxAttempts);
}

Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool) {
Expand Down Expand Up @@ -226,8 +227,27 @@ Worker::~Worker() {
}

void Worker::work() {
const unsigned long maxAttempts = 10000;

bool turbo = false;
while (true) {
waitForPos(transformer, socket);
const clock_t start = clock();

while (!tryWaitForPos(transformer, socket, maxAttempts)) {
if (turbo) {
// After one second of waiting with non-blocking read, we switch to blocking mode to not burn CPU.
if (clock() - start > CLOCKS_PER_SEC) {
socket->setTurbo(false);
turbo = false;
printf("🚁 Socket is in blocking mode\n");
}
}
}
if (!turbo) {
socket->setTurbo(true);
turbo = true;
printf("🚁 Socket is in non-blocking mode\n");
}

context.currentBlockIndex = 0;
taskLoop->run();
Expand Down
2 changes: 1 addition & 1 deletion src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket) {
}

float kbs = blockBytes / (float)(timeMs() - t0);
printf("⏩ Received %ld bytes for block %d (%.0f kB/s)\n", blockBytes, i, kbs);
printf("⏩ Received %ld kB for block %d (%.0f kB/s)\n", blockBytes / 1024, i, kbs);
}
return transformer;
}

0 comments on commit 9a1e284

Please sign in to comment.