Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

msgq: refactor blocking recv for improved robustness and performance #616

Merged
merged 6 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ msgq_objects = env.SharedObject([
'msgq/msgq.cc',
])
msgq = env.Library('msgq', msgq_objects)
msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common])
msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", 'pthread', common])

# Build Vision IPC
vipc_files = ['visionipc.cc', 'visionipc_server.cc', 'visionipc_client.cc', 'visionbuf.cc']
Expand All @@ -31,7 +31,7 @@ visionipc = env.Library('visionipc', vipc_objects)


vipc_frameworks = []
vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq"]
vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq", 'pthread']
if arch == "Darwin":
vipc_frameworks.append('OpenCL')
else:
Expand All @@ -45,4 +45,5 @@ if GetOption('extras'):
[f'{visionipc_dir.abspath}/test_runner.cc', f'{visionipc_dir.abspath}/visionipc_tests.cc'],
LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks)

msgq = [msgq, 'pthread']
Export('visionipc', 'msgq', 'msgq_python')
104 changes: 45 additions & 59 deletions msgq/impl_msgq.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
#include <cassert>
#include <cstring>
#include <iostream>
#include <cstdlib>
#include <chrono>
#include <csignal>
#include <cerrno>

#include "msgq/impl_msgq.h"


volatile sig_atomic_t msgq_do_exit = 0;

void sig_handler(int signal) {
assert(signal == SIGINT || signal == SIGTERM);
msgq_do_exit = 1;
}

using namespace std::chrono;

MSGQContext::MSGQContext() {
}
Expand Down Expand Up @@ -70,61 +62,55 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a
return 0;
}


Message * MSGQSubSocket::receive(bool non_blocking){
msgq_do_exit = 0;

void (*prev_handler_sigint)(int);
void (*prev_handler_sigterm)(int);
if (!non_blocking){
prev_handler_sigint = std::signal(SIGINT, sig_handler);
prev_handler_sigterm = std::signal(SIGTERM, sig_handler);
}

msgq_msg_t msg;

MSGQMessage *r = NULL;

Message *MSGQSubSocket::receive(bool non_blocking) {
msgq_msg_t msg{};
int rc = msgq_msg_recv(&msg, q);

// Hack to implement blocking read with a poller. Don't use this
while (!non_blocking && rc == 0 && msgq_do_exit == 0){
msgq_pollitem_t items[1];
items[0].q = q;

int t = (timeout != -1) ? timeout : 100;

int n = msgq_poll(items, 1, t);
rc = msgq_msg_recv(&msg, q);

// The poll indicated a message was ready, but the receive failed. Try again
if (n == 1 && rc == 0){
continue;
}

if (timeout != -1){
break;
if (rc == 0 && !non_blocking) {
sigset_t mask;
sigset_t old_mask;
sigemptyset(&mask);
sigaddset(&mask, SIGINT);
sigaddset(&mask, SIGTERM);
sigaddset(&mask, SIGUSR2);

pthread_sigmask(SIG_BLOCK, &mask, &old_mask);

int64_t timieout_ns = ((timeout != -1) ? timeout : 1000) * 1000000;
sshane marked this conversation as resolved.
Show resolved Hide resolved
auto start = steady_clock::now();

// Continue receiving messages until timeout or interruption by SIGINT or SIGTERM
while (rc == 0 && timieout_ns > 0) {
struct timespec ts {
timieout_ns / 1000000000,
timieout_ns % 1000000000,
};

int ret = sigtimedwait(&mask, nullptr, &ts);
if (ret == SIGINT || ret == SIGTERM) {
// Ensure signal handling is not missed
raise(ret);
break;
} else if (ret == -1 && errno == EAGAIN && timeout != -1) {
break; // Timed out
}

rc = msgq_msg_recv(&msg, q);

if (timeout != -1) {
timieout_ns -= duration_cast<nanoseconds>(steady_clock::now() - start).count();
start = steady_clock::now(); // Update start time
}
}
pthread_sigmask(SIG_SETMASK, &old_mask, nullptr);
}


if (!non_blocking){
std::signal(SIGINT, prev_handler_sigint);
std::signal(SIGTERM, prev_handler_sigterm);
}

errno = msgq_do_exit ? EINTR : 0;

if (rc > 0){
if (msgq_do_exit){
msgq_msg_close(&msg); // Free unused message on exit
} else {
r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
}
if (rc > 0) {
MSGQMessage *r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
return r;
}

return (Message*)r;
return nullptr;
}

void MSGQSubSocket::setTimeout(int t){
Expand Down
2 changes: 1 addition & 1 deletion msgq/ipc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ cdef extern from "msgq/ipc.h":
@staticmethod
SubSocket * create()
int connect(Context *, string, string, bool)
Message * receive(bool)
Message * receive(bool) nogil
void setTimeout(int)

cdef cppclass PubSocket:
Expand Down
9 changes: 3 additions & 6 deletions msgq/ipc_pyx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,11 @@ cdef class SubSocket:
self.socket.setTimeout(timeout)

def receive(self, bool non_blocking=False):
msg = self.socket.receive(non_blocking)
cdef cppMessage *msg
with nogil:
msg = self.socket.receive(non_blocking)

if msg == NULL:
# If a blocking read returns no message check errno if SIGINT was caught in the C++ code
if errno.errno == errno.EINTR:
print("SIGINT received, exiting")
sys.exit(1)

return None
else:
sz = msg.getSize()
Expand Down
25 changes: 25 additions & 0 deletions msgq/tests/test_messaging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import pytest
import random
import signal
import threading
import time
import string
import msgq
Expand Down Expand Up @@ -67,3 +70,25 @@ def test_receive_timeout(self):
recvd = sub_sock.receive()
assert (time.monotonic() - start_time) < 0.2
assert recvd is None

def test_receive_interrupts_on_sigint(self):
sock = random_sock()
sub_sock = msgq.sub_sock(sock)

# Send SIGINT after a short delay
pid = os.getpid()
def send_sigint():
time.sleep(.5)
os.kill(pid, signal.SIGINT)

# Start a thread to send SIGINT
thread = threading.Thread(target=send_sigint)
thread.start()

with pytest.raises(KeyboardInterrupt):
start_time = time.monotonic()
recvd = sub_sock.receive()
assert (time.monotonic() - start_time) < 0.5
assert recvd is None

thread.join()
Loading