Skip to content

Commit faa74e6

Browse files
committed
runtime: pull thread tracking into a header magic system to unify implementations
1 parent aab57be commit faa74e6

File tree

6 files changed

+161
-129
lines changed

6 files changed

+161
-129
lines changed

runtime/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ if(NOT ALASKA_CORE_ONLY) # -----------------------------------------------------
141141
set(SOURCES
142142
rt/init.cpp
143143
rt/halloc.cpp
144-
rt/threads.cpp
145144
rt/compat.c
146145
rt/barrier.cpp
147146
)

runtime/include/alaska/rt/barrier.hpp

+4-14
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#include <pthread.h>
1919

2020

21-
typedef struct {
21+
struct AlaskaThreadState {
22+
23+
2224
// Track the depth of escape of a given thread. If this number is zero,
2325
// the thread is in 'managed' code and will eventually poll the barrier
2426
uint64_t escaped;
@@ -29,28 +31,16 @@ typedef struct {
2931
#define ALASKA_JOIN_REASON_SAFEPOINT 1 // This thread was at a safepoint
3032
#define ALASKA_JOIN_REASON_ORCHESTRATOR 2 // This thread was the orchestrator
3133
#define ALASKA_JOIN_REASON_ABORT 3 // This thread requires the barrier abort (invalid state, for some reason)
32-
33-
// ...
34-
} alaska_thread_state_t;
35-
36-
extern __thread alaska_thread_state_t alaska_thread_state;
37-
34+
};
3835

3936
namespace alaska {
4037
namespace barrier {
4138

42-
// Thread tracking lifetime
43-
void add_self_thread(void);
44-
void remove_self_thread(void);
45-
4639
// Barrier operational lifetime. It is not recommended to use this interface, and
4740
// instead use `with_barrier` interface below:
4841
bool begin();
4942
void end();
5043

51-
// struct BarrierInfo {};
52-
// bool with_barrier(ck::func<void(BarrierInfo &)> &&cb);
53-
5444
// Initialization and deinitialization
5545
void init();
5646
void deinit();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* This file is part of the Alaska Handle-Based Memory Management System
3+
*
4+
* Copyright (c) 2024, Nick Wanninger <ncw@u.northwestern.edu>
5+
* Copyright (c) 2024, The Constellation Project
6+
* All rights reserved.
7+
*
8+
* This is free software. You are permitted to use, redistribute,
9+
* and modify it as specified in the file "LICENSE".
10+
*/
11+
12+
#include <alaska/ThreadRegistry.hpp>
13+
#include <alaska/rt.hpp>
14+
#include <alaska/rt/barrier.hpp>
15+
#include <dlfcn.h>
16+
#include <unistd.h>
17+
#include <pthread.h>
18+
#include <stdio.h>
19+
#include <stdlib.h>
20+
21+
// NOTE: this file is a little funky so far as header files go in C++ because it
22+
// provides implementations based on a macro begin defined *before* this file is
23+
// included.
24+
// To enable thread tracking in your alaska runtime, do the following:
25+
//
26+
// struct my_state {};
27+
// #define ALASKA_THREAD_TRACK_STATE_T struct my_state
28+
// #include <alaska/thread_tracking.in.hpp>
29+
//
30+
// This will *define* several functions in the current file. Namely, it will override
31+
// `pthread_create` to enable tracking. Usually, this should just be included wherever
32+
// you implement your barrier logic - and should only be included *ONE TIME*.
33+
34+
#ifndef ALASKA_THREAD_TRACK_STATE_T
35+
#warning Thread state type not defined. defaulting to int
36+
#define ALASKA_THREAD_TRACK_STATE_T int
37+
#endif
38+
39+
// the following is the main interface to this system:
40+
namespace alaska::thread_tracking {
41+
using StateT = ALASKA_THREAD_TRACK_STATE_T;
42+
static __thread StateT my_state;
43+
44+
alaska::ThreadRegistry<StateT *> &threads();
45+
46+
// Join and leave with the current thread
47+
void join();
48+
void leave();
49+
} // namespace alaska::thread_tracking
50+
51+
52+
53+
// And this is the implementation:
54+
// static auto& threads(void) {
55+
// }
56+
57+
namespace alaska::thread_tracking {
58+
alaska::ThreadRegistry<StateT *> &threads() {
59+
static alaska::ThreadRegistry<StateT *> *g_threads;
60+
if (g_threads == NULL) g_threads = new alaska::ThreadRegistry<StateT *>();
61+
return *g_threads;
62+
}
63+
64+
65+
void join(void) {
66+
#ifdef ALASKA_THREAD_TRACK_INIT
67+
ALASKA_THREAD_TRACK_INIT;
68+
#endif
69+
alaska::thread_tracking::threads().join(&alaska::thread_tracking::my_state);
70+
}
71+
72+
void leave(void) { alaska::thread_tracking::threads().leave(); }
73+
} // namespace alaska::thread_tracking
74+
75+
76+
77+
78+
79+
// now, the icky part about wrapping around pthread_create
80+
struct alaska_pthread_trampoline_arg {
81+
void* arg;
82+
void* (*start)(void*);
83+
};
84+
85+
86+
static void* alaska_pthread_trampoline(void* varg) {
87+
void* (*start)(void*);
88+
auto* arg = (struct alaska_pthread_trampoline_arg*)varg;
89+
void* thread_arg = arg->arg;
90+
start = arg->start;
91+
free(arg);
92+
93+
alaska::thread_tracking::join();
94+
void* ret = start(thread_arg);
95+
alaska::thread_tracking::leave();
96+
97+
return ret;
98+
}
99+
100+
101+
// Hook into thread creation by overriding the pthread_create function
102+
#undef pthread_create
103+
extern "C" int pthread_create(pthread_t* __restrict thread, const pthread_attr_t* __restrict attr,
104+
void* (*start)(void*), void* __restrict arg) {
105+
int rc;
106+
static int (*real_create)(pthread_t* __restrict thread, const pthread_attr_t* __restrict attr,
107+
void* (*start)(void*), void* __restrict arg) = NULL;
108+
if (!real_create) real_create = (decltype(real_create))dlsym(RTLD_NEXT, "pthread_create");
109+
110+
auto* args = (struct alaska_pthread_trampoline_arg*)calloc(
111+
1, sizeof(struct alaska_pthread_trampoline_arg));
112+
args->arg = arg;
113+
args->start = start;
114+
rc = real_create(thread, attr, alaska_pthread_trampoline, args);
115+
return rc;
116+
}
117+
118+
static void __attribute__((constructor(102))) thread_tracking_init(void) {
119+
alaska::thread_tracking::join();
120+
}
121+
122+
static void __attribute__((destructor)) alaska_deinit(void) {
123+
alaska::thread_tracking::leave();
124+
}

runtime/rt/barrier.cpp

+31-36
Original file line numberDiff line numberDiff line change
@@ -103,35 +103,33 @@ static void patchNop(void) {
103103
}
104104
}
105105

106+
static void setup_signal_handlers(void);
107+
static void clear_pending_signals(void);
106108

107109

108110

109111
enum class JoinReason { Signal, Safepoint };
110112

111-
static alaska::ThreadRegistry<alaska_thread_state_t*>* g_threads;
112-
static auto& threads(void) {
113-
if (g_threads == NULL) g_threads = new alaska::ThreadRegistry<alaska_thread_state_t*>();
114-
return *g_threads;
115-
}
113+
114+
#define ALASKA_THREAD_TRACK_STATE_T AlaskaThreadState
115+
#define ALASKA_THREAD_TRACK_INIT setup_signal_handlers();
116+
#include <alaska/thread_tracking.in.hpp>
116117

117118

118119

119-
__thread alaska_thread_state_t alaska_thread_state;
120-
static pthread_mutex_t barrier_lock = PTHREAD_MUTEX_INITIALIZER;
121120

122121
// This is *the* barrier used in alaska_barrier to make sure threads are stopped correctly.
123122
static pthread_barrier_t the_barrier;
124123
static long barrier_last_num_threads = 0;
124+
static pthread_mutex_t barrier_lock = PTHREAD_MUTEX_INITIALIZER;
125125

126-
static void setup_signal_handlers(void);
127-
static void clear_pending_signals(void);
128126

129127

130128

131129
void alaska_remove_from_local_lock_list(void* ptr) { return; }
132130
static void alaska_dump_thread_states_r(void) {
133-
struct alaska_thread_info* pos;
134-
threads().for_each_locked([&](auto thread, alaska_thread_state_t* state) {
131+
struct info* pos;
132+
alaska::thread_tracking::threads().for_each_locked([&](auto thread, AlaskaThreadState* state) {
135133
if (state->escaped == 0) {
136134
printf("\e[0m. "); // a thread will join a barrier (not out to lunch)
137135
} else {
@@ -142,7 +140,7 @@ static void alaska_dump_thread_states_r(void) {
142140
}
143141

144142
void alaska_dump_thread_states(void) {
145-
auto lk = threads().take_lock();
143+
auto lk = alaska::thread_tracking::threads().take_lock();
146144
alaska_dump_thread_states_r();
147145
}
148146

@@ -266,13 +264,11 @@ void alaska::barrier::get_pinned_handles(ck::set<void*>& out) {
266264

267265

268266
static void participant_join(bool leader, const ck::set<void*>& ps) {
269-
printf("join from %lx with state %p%s\n", pthread_self(), &alaska_thread_state,
270-
leader ? " as leader" : "");
271267
for (auto* p : ps) {
272268
record_handle(p, true);
273269
}
274270
// Wait on the barrier so everyone's state has been commited.
275-
if (threads().num_threads() > 1) {
271+
if (alaska::thread_tracking::threads().num_threads() > 1) {
276272
pthread_barrier_wait(&the_barrier);
277273
}
278274
}
@@ -282,7 +278,7 @@ static void participant_join(bool leader, const ck::set<void*>& ps) {
282278

283279
static void participant_leave(bool leader, const ck::set<void*>& ps) {
284280
// wait for the the leader (and everyone else to catch up)
285-
if (threads().num_threads() > 1) {
281+
if (alaska::thread_tracking::threads().num_threads() > 1) {
286282
pthread_barrier_wait(&the_barrier);
287283
}
288284

@@ -294,8 +290,8 @@ static void participant_leave(bool leader, const ck::set<void*>& ps) {
294290

295291

296292
void dump_thread_states(void) {
297-
struct alaska_thread_info* pos;
298-
threads().for_each_locked([&](auto thread, alaska_thread_state_t* state) {
293+
struct info* pos;
294+
alaska::thread_tracking::threads().for_each_locked([&](auto thread, AlaskaThreadState* state) {
299295
switch (state->join_status) {
300296
case ALASKA_JOIN_REASON_NOT_JOINED:
301297
printf("\e[41m! "); // a thread will need to be interrupted (out to lunch)
@@ -334,27 +330,27 @@ bool alaska::barrier::begin(void) {
334330

335331
// Take locks so nobody else tries to signal a barrier.
336332
pthread_mutex_lock(&barrier_lock);
337-
threads().lock_thread_creation();
333+
alaska::thread_tracking::threads().lock_thread_creation();
338334

339335

340336

341-
auto num_threads = threads().num_threads();
337+
auto num_threads = alaska::thread_tracking::threads().num_threads();
342338
alaska::printf("Barrier begin from %lx:\n", pthread_self());
343339
alaska::printf(" num threads: %lu\n", num_threads);
344340

345341
printf(" threads:\n");
346-
threads().for_each_locked([](auto thread, alaska_thread_state_t* state) {
342+
alaska::thread_tracking::threads().for_each_locked([](auto thread, AlaskaThreadState* state) {
347343
alaska::printf(" - %lx %p %d\n", thread, state, state->join_status);
348344
});
349345

350346

351347
// First, mark everyone as *not* in the barrier.
352-
threads().for_each_locked([](auto thread, alaska_thread_state_t* state) {
348+
alaska::thread_tracking::threads().for_each_locked([](auto thread, AlaskaThreadState* state) {
353349
state->join_status = ALASKA_JOIN_REASON_NOT_JOINED;
354350
});
355351

356352
// Mark the orch thread (us) as joined
357-
alaska_thread_state.join_status = ALASKA_JOIN_REASON_ORCHESTRATOR;
353+
alaska::thread_tracking::my_state.join_status = ALASKA_JOIN_REASON_ORCHESTRATOR;
358354

359355
// If the barrier needs resizing, do so.
360356
if (barrier_last_num_threads != num_threads) {
@@ -382,9 +378,8 @@ bool alaska::barrier::begin(void) {
382378
bool sent_signal = false;
383379
bool aborted = false;
384380

385-
threads().for_each_locked([&](auto thread, auto* state) {
381+
alaska::thread_tracking::threads().for_each_locked([&](auto thread, auto* state) {
386382
if (state->join_status == ALASKA_JOIN_REASON_NOT_JOINED) {
387-
printf("killing %lu\n", thread);
388383
pthread_kill(thread, SIGUSR2);
389384
sent_signal = true;
390385
signals_sent++;
@@ -422,7 +417,7 @@ void alaska::barrier::end(void) {
422417
participant_leave(true, locked);
423418

424419
// Unlock all the locks we took.
425-
threads().unlock_thread_creation();
420+
alaska::thread_tracking::threads().unlock_thread_creation();
426421
pthread_mutex_unlock(&barrier_lock);
427422
}
428423

@@ -467,26 +462,26 @@ static void alaska_barrier_signal_handler(int sig, siginfo_t* info, void* ptr) {
467462
printf(
468463
"ManagedUntracked: pc:0x%zx sig:%d invl:%d!\n", return_address, sig, invalid_state_abort);
469464
if (sig == SIGILL) {
470-
alaska_thread_state.join_status = ALASKA_JOIN_REASON_ABORT;
465+
alaska::thread_tracking::my_state.join_status = ALASKA_JOIN_REASON_ABORT;
471466
break;
472467
}
473468
if (not invalid_state_abort) {
474469
invalid_state_abort = true;
475470
return;
476471
}
477-
alaska_thread_state.join_status = ALASKA_JOIN_REASON_ABORT;
472+
alaska::thread_tracking::my_state.join_status = ALASKA_JOIN_REASON_ABORT;
478473
invalid_state_abort = false;
479474
break;
480475

481476
case StackState::ManagedTracked:
482477
// it's possible to be at a managed poll point *and* get interrupted
483478
// through SIGUSR2
484-
alaska_thread_state.join_status = ALASKA_JOIN_REASON_SAFEPOINT;
479+
alaska::thread_tracking::my_state.join_status = ALASKA_JOIN_REASON_SAFEPOINT;
485480
break;
486481

487482
case StackState::Unmanaged:
488483
assert(sig == SIGUSR2 && "Unmanaged code got into the barrier handler w/ the wrong signal");
489-
alaska_thread_state.join_status = ALASKA_JOIN_REASON_SIGNAL;
484+
alaska::thread_tracking::my_state.join_status = ALASKA_JOIN_REASON_SIGNAL;
490485
break;
491486
}
492487

@@ -544,13 +539,13 @@ static void clear_pending_signals(void) {
544539
}
545540

546541

547-
void alaska::barrier::add_self_thread(void) {
548-
setup_signal_handlers();
549-
alaska_thread_state.escaped = 0;
550-
threads().join(&alaska_thread_state);
551-
}
542+
// void alaska::barrier::add_self_thread(void) {
543+
// setup_signal_handlers();
544+
// alaska::thread_tracking::join();
545+
// alaska::thread_tracking::my_state.escaped = 0;
546+
// }
552547

553-
void alaska::barrier::remove_self_thread(void) { threads().leave(); }
548+
// void alaska::barrier::remove_self_thread(void) { alaska::thread_tracking::leave(); }
554549

555550
/**
556551
* This function parses a stackmap emitted from LLVM and pushes all

0 commit comments

Comments
 (0)