Skip to content

Commit

Permalink
Refactor heartbeat to shutdown cleanly
Browse files Browse the repository at this point in the history
From ZMQ docs: "zmq_proxy() runs in the current thread and returns only
if/when the current context is closed."

The heartbeat socket doesn't need to be global, as nothing else touches
it. BUT, if we create the heartbeat socket in a `Context` that has a global ref,
we can close the context, which will cause zmq_proxy to return and then
that thread to end/finish.

Doing that before shutting down helps avoid a segfault on shutdown.
  • Loading branch information
halleysfifthinc committed Dec 10, 2024
1 parent fb76275 commit 3a9270c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ Conda = "1"
JSON = "0.18,0.19,0.20,0.21,1"
MbedTLS = "0.5,0.6,0.7,1"
SoftGlobalScope = "1"
ZMQ = "1"
ZMQ = "1.3"
julia = "1.6"
3 changes: 3 additions & 0 deletions src/handlers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ function connect_request(socket, msg)
end

function shutdown_request(socket, msg)
# stop heartbeat thread by closing the context
close(zmq_proxy_context[])

send_ipython(requests[], msg_reply(msg, "shutdown_reply",
msg.content))
sleep(0.1) # short delay (like in ipykernel), to hopefully ensure shutdown_reply is sent
Expand Down
29 changes: 19 additions & 10 deletions src/heartbeat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import Libdl

const threadid = zeros(Int, 128) # sizeof(uv_thread_t) <= 8 on Linux, OSX, Win
const zmq_proxy = Ref(C_NULL)
const zmq_proxy_context = Ref{Context}()

# entry point for new thread
function heartbeat_thread(sock::Ptr{Cvoid})
function heartbeat_thread(heartbeat_addr::Cstring)
zmq_proxy_context[] = Context()
heartbeat = Socket(zmq_proxy_context[], ROUTER)
GC.@preserve heartbeat_addr bind(heartbeat, unsafe_string(heartbeat_addr))
@static if VERSION v"1.9.0-DEV.1588" # julia#46609
# julia automatically "adopts" this thread because
# we entered a Julia cfunction. We then have to enable
Expand All @@ -19,14 +22,20 @@ function heartbeat_thread(sock::Ptr{Cvoid})
# (see julia#47196)
ccall(:jl_gc_safe_enter, Int8, ())
end
ccall(zmq_proxy[], Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
sock, sock, C_NULL)
nothing
ret = ZMQ.lib.zmq_proxy(heartbeat, heartbeat, C_NULL)
@static if VERSION v"1.9.0-DEV.1588" # julia#46609
# julia automatically "adopts" this thread because
# we entered a Julia cfunction. We then have to enable
# a GC "safe" region to prevent us from grabbing the
# GC lock with the call to zmq_proxy, which never returns.
# (see julia#47196)
ccall(:jl_gc_safe_leave, Int8, ())
end
return ret
end

function start_heartbeat(sock)
zmq_proxy[] = Libdl.dlsym(Libdl.dlopen(ZMQ.libzmq), :zmq_proxy)
heartbeat_c = @cfunction(heartbeat_thread, Cvoid, (Ptr{Cvoid},))
ccall(:uv_thread_create, Cint, (Ptr{Int}, Ptr{Cvoid}, Ptr{Cvoid}),
threadid, heartbeat_c, sock)
function start_heartbeat(heartbeat_addr)
heartbeat_c = @cfunction(heartbeat_thread, Cint, (Cstring,))
ccall(:uv_thread_create, Cint, (Ptr{Int}, Ptr{Cvoid}, Cstring),
threadid, heartbeat_c, heartbeat_addr)
end
7 changes: 2 additions & 5 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ const publish = Ref{Socket}()
const raw_input = Ref{Socket}()
const requests = Ref{Socket}()
const control = Ref{Socket}()
const heartbeat = Ref{Socket}()
const profile = Dict{String,Any}()
const read_stdout = Ref{Base.PipeEndpoint}()
const read_stderr = Ref{Base.PipeEndpoint}()
Expand Down Expand Up @@ -87,21 +86,19 @@ function init(args)
raw_input[] = Socket(ROUTER)
requests[] = Socket(ROUTER)
control[] = Socket(ROUTER)
heartbeat[] = Socket(ROUTER)
sep = profile["transport"]=="ipc" ? "-" : ":"
bind(publish[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["iopub_port"])")
bind(requests[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["shell_port"])")
bind(control[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["control_port"])")
bind(raw_input[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["stdin_port"])")
bind(heartbeat[], "$(profile["transport"])://$(profile["ip"])$(sep)$(profile["hb_port"])")
start_heartbeat("$(profile["transport"])://$(profile["ip"])$(sep)$(profile["hb_port"])")

# associate a lock with each socket so that multi-part messages
# on a given socket don't get inter-mingled between tasks.
for s in (publish[], raw_input[], requests[], control[], heartbeat[])
for s in (publish[], raw_input[], requests[], control[])
socket_locks[s] = ReentrantLock()
end

start_heartbeat(heartbeat[])
if capture_stdout
read_stdout[], = redirect_stdout()
redirect_stdout(IJuliaStdio(stdout,"stdout"))
Expand Down

0 comments on commit 3a9270c

Please sign in to comment.