From 13c466e323e4379f913b3428aa734e88ef0916a4 Mon Sep 17 00:00:00 2001 From: Edd Barrett Date: Wed, 19 Oct 2022 15:46:35 +0100 Subject: [PATCH] Make the API more thread-safe. Before this change, `ThreadTraceCollector` was exposed to the consumer of the hwtracer API. Calling `TraceCollector::thread_tracer()` would give back a fresh instance each time. This API would allow the user request multiple `ThreadTraceCollector`s on the same thread and start them all "collecting traces" concurrently. This leads to chaos as any given thread can only be traced at most once at a time. Until now we've swept this under the rug, warning the user not to do the naughty thing unless they want to invoke UB, but the API still allows it to happen. This change hides `ThreadTraceCollector`s from the public API entirely, leaving `TraceCollector`s to manage thread local instances of `ThreadTraceCollector`. The user now starts and stops trace collection of a thread with `TraceCollector::start_thread_collector()` and `TraceCollector::stop_thread_collector()`. `ThreadTraceCollector`s are still there behind the scenes, but the user no longer sees them and thus can't start two tracing the same thread any more. One subtle side-effect of this new API is that we can no longer automatically stop a running trace collector when the `ThreadTraceCollector` goes out of scope (via `Drop`). Since `ThreadTraceCollector`s are stored (and owned) by thread locals, they never really "fall out of scope", so the user must ensure they consistently stop trace collection before trying to re-trace the same thread. It may be tempting to stop the current thread's `ThreadTraceCollector` when a `TraceCollector` goes out of scope, but then we'd have to track which `TraceCollector` started the `ThreadTraceCollector` and only stop the collector if they match. This starts to sound pretty over the top, and I think the proposed API is already a lot better than what we had before. Further `TraceCollector::start_thread_collector()` will return a meaningful error message if the user has forgotten to stop a collector, so it shouldn't be hard to track down instances of this blunder. Fixes #101. --- examples/simple_example.rs | 7 +- src/collect/mod.rs | 138 ++++++++++++++++++++++++------------- src/collect/perf/mod.rs | 68 ++++++++++-------- src/decode/libipt/mod.rs | 32 ++++----- src/decode/mod.rs | 13 ++-- tests/pt_chdir_rel.rs | 8 +-- 6 files changed, 151 insertions(+), 115 deletions(-) diff --git a/examples/simple_example.rs b/examples/simple_example.rs index a96db11..6c136eb 100644 --- a/examples/simple_example.rs +++ b/examples/simple_example.rs @@ -35,15 +35,14 @@ fn work() -> u32 { /// The results are printed to discourage the compiler from optimising the computation out. fn main() { let bldr = TraceCollectorBuilder::new(); - let col = bldr.build().unwrap(); - let mut thr_col = unsafe { col.thread_collector() }; + let tc = bldr.build().unwrap(); for i in 1..4 { - thr_col.start_collector().unwrap_or_else(|e| { + tc.start_thread_collector().unwrap_or_else(|e| { panic!("Failed to start collector: {}", e); }); let res = work(); - let trace = thr_col.stop_collector().unwrap(); + let trace = tc.stop_thread_collector().unwrap(); let name = format!("trace{}", i); print_trace(trace, &name, res, 10); } diff --git a/src/collect/mod.rs b/src/collect/mod.rs index ff011a8..85e1fad 100644 --- a/src/collect/mod.rs +++ b/src/collect/mod.rs @@ -3,7 +3,7 @@ use crate::{errors::HWTracerError, Trace}; use core::arch::x86_64::__cpuid_count; use libc::{size_t, sysconf, _SC_PAGESIZE}; -use std::{convert::TryFrom, sync::LazyLock}; +use std::{cell::RefCell, convert::TryFrom, sync::LazyLock}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -22,19 +22,59 @@ static PERF_DFLT_AUX_BUFSIZE: LazyLock = LazyLock::new(|| { const PERF_DFLT_INITIAL_TRACE_BUFSIZE: size_t = 1024 * 1024; // 1MiB -/// The interface offered by all trace collectors. -pub trait TraceCollector: Send + Sync { - /// Obtain a `ThreadTraceCollector` for the current thread. - /// - /// A thread may obtain multiple `ThreadTraceCollector`s but must only collect a trace with one - /// at a time. - /// - /// FIXME: This API needs to be fixed: - /// https://github.com/ykjit/hwtracer/issues/101 +thread_local! { + /// When `Some` holds the `ThreadTraceCollector` that is collecting a trace of the current + /// thread. + static THREAD_TRACE_COLLECTOR: RefCell>> = RefCell::new(None); +} + +/// The private innards of a `TraceCollector`. +pub(crate) trait TraceCollectorImpl: Send + Sync { unsafe fn thread_collector(&self) -> Box; } -pub trait ThreadTraceCollector { +/// The public interface offered by all trace collectors. +pub struct TraceCollector { + col_impl: Box, +} + +impl TraceCollector { + pub(crate) fn new(col_impl: Box) -> Self { + Self { col_impl } + } + + /// Start collecting a trace of the current thread. + pub fn start_thread_collector(&self) -> Result<(), HWTracerError> { + THREAD_TRACE_COLLECTOR.with(|inner| { + let mut inner = inner.borrow_mut(); + if inner.is_some() { + Err(HWTracerError::AlreadyCollecting) + } else { + let mut thr_col = unsafe { self.col_impl.thread_collector() }; + thr_col.start_collector()?; + *inner = Some(thr_col); + Ok(()) + } + }) + } + + /// Stop collecting a trace of the current thread. + pub fn stop_thread_collector(&self) -> Result, HWTracerError> { + THREAD_TRACE_COLLECTOR.with(|inner| { + let mut inner = inner.borrow_mut(); + if let Some(thr_col) = &mut *inner { + let ret = thr_col.stop_collector(); + *inner = None; + ret + } else { + Err(HWTracerError::AlreadyStopped) + } + }) + } +} + +/// Represents a trace collection session for a single thread. +pub(crate) trait ThreadTraceCollector { /// Start recording a trace. /// /// Tracing continues until [stop_collector] is called. @@ -182,13 +222,15 @@ impl TraceCollectorBuilder { /// /// An error is returned if the requested collector is inappropriate for the platform or not /// compiled in to hwtracer. - pub fn build(self) -> Result, HWTracerError> { + pub fn build(self) -> Result { let kind = self.config.kind(); kind.match_platform()?; match self.config { TraceCollectorConfig::Perf(_pt_conf) => { #[cfg(collector_perf)] - return Ok(Box::new(PerfTraceCollector::new(_pt_conf)?)); + return Ok(TraceCollector::new(Box::new(PerfTraceCollector::new( + _pt_conf, + )?))); #[cfg(not(collector_perf))] unreachable!(); } @@ -198,81 +240,79 @@ impl TraceCollectorBuilder { #[cfg(test)] pub(crate) mod test_helpers { - use crate::{ - collect::{ThreadTraceCollector, TraceCollector}, - errors::HWTracerError, - test_helpers::work_loop, - Trace, - }; + use crate::{collect::TraceCollector, errors::HWTracerError, test_helpers::work_loop, Trace}; use std::thread; /// Trace a closure that returns a u64. - pub fn trace_closure(tc: &mut dyn ThreadTraceCollector, f: F) -> Box + pub fn trace_closure(tc: &TraceCollector, f: F) -> Box where F: FnOnce() -> u64, { - tc.start_collector().unwrap(); + tc.start_thread_collector().unwrap(); let res = f(); - let trace = tc.stop_collector().unwrap(); + let trace = tc.stop_thread_collector().unwrap(); println!("traced closure with result: {}", res); // To avoid over-optimisation. trace } /// Check that starting and stopping a trace collector works. - pub fn basic_collection(mut tracer: T) - where - T: ThreadTraceCollector, - { - let trace = trace_closure(&mut tracer, || work_loop(500)); + pub fn basic_collection(tc: TraceCollector) { + let trace = trace_closure(&tc, || work_loop(500)); assert_ne!(trace.len(), 0); } /// Check that repeated usage of the same trace collector works. - pub fn repeated_collection(mut tracer: T) - where - T: ThreadTraceCollector, - { + pub fn repeated_collection(tc: TraceCollector) { for _ in 0..10 { - trace_closure(&mut tracer, || work_loop(500)); + trace_closure(&tc, || work_loop(500)); + } + } + + /// Check that repeated collection using different collectors works. + pub fn repeated_collection_different_collectors(tcs: [TraceCollector; 10]) { + for i in 0..10 { + trace_closure(&tcs[i], || work_loop(500)); } } /// Check that starting a trace collector twice (without stopping maktracing inbetween) makes /// an appropriate error. - pub fn already_started(mut tc: T) - where - T: ThreadTraceCollector, - { - tc.start_collector().unwrap(); - match tc.start_collector() { + pub fn already_started(tc: TraceCollector) { + tc.start_thread_collector().unwrap(); + match tc.start_thread_collector() { Err(HWTracerError::AlreadyCollecting) => (), _ => panic!(), }; - tc.stop_collector().unwrap(); + tc.stop_thread_collector().unwrap(); + } + + /// Check that an attempt to trace the same thread using different collectors fails. + pub fn already_started_different_collectors(tc1: TraceCollector, tc2: TraceCollector) { + tc1.start_thread_collector().unwrap(); + match tc2.start_thread_collector() { + Err(HWTracerError::AlreadyCollecting) => (), + _ => panic!(), + }; + tc1.stop_thread_collector().unwrap(); } /// Check that stopping an unstarted trace collector makes an appropriate error. - pub fn not_started(mut tc: T) - where - T: ThreadTraceCollector, - { - match tc.stop_collector() { + pub fn not_started(tc: TraceCollector) { + match tc.stop_thread_collector() { Err(HWTracerError::AlreadyStopped) => (), _ => panic!(), }; } /// Check that traces can be collected concurrently. - pub fn concurrent_collection(tc: &dyn TraceCollector) { + pub fn concurrent_collection(tc: TraceCollector) { for _ in 0..10 { thread::scope(|s| { let hndl = s.spawn(|| { - let mut thr_c1 = unsafe { tc.thread_collector() }; - trace_closure(&mut *thr_c1, || work_loop(500)); + trace_closure(&tc, || work_loop(500)); }); - let mut thr_c2 = unsafe { tc.thread_collector() }; - trace_closure(&mut *thr_c2, || work_loop(500)); + trace_closure(&tc, || work_loop(500)); hndl.join().unwrap(); }); } diff --git a/src/collect/perf/mod.rs b/src/collect/perf/mod.rs index 7a8b732..52fdb7e 100644 --- a/src/collect/perf/mod.rs +++ b/src/collect/perf/mod.rs @@ -3,7 +3,7 @@ use super::PerfCollectorConfig; use crate::{ c_errors::PerfPTCError, - collect::{ThreadTraceCollector, TraceCollector}, + collect::{ThreadTraceCollector, TraceCollectorImpl}, errors::HWTracerError, Trace, }; @@ -76,7 +76,7 @@ impl PerfTraceCollector { } } -impl TraceCollector for PerfTraceCollector { +impl TraceCollectorImpl for PerfTraceCollector { unsafe fn thread_collector(&self) -> Box { Box::new(PerfThreadTraceCollector::new(self.config.clone())) } @@ -88,8 +88,6 @@ pub struct PerfThreadTraceCollector { config: PerfCollectorConfig, // Opaque C pointer representing the collector context. ctx: *mut c_void, - // The state of the collector. - is_tracing: bool, // The trace currently being collected, or `None`. trace: Option>, } @@ -99,7 +97,6 @@ impl PerfThreadTraceCollector { Self { config, ctx: ptr::null_mut(), - is_tracing: false, trace: None, } } @@ -111,21 +108,8 @@ impl Default for PerfThreadTraceCollector { } } -impl Drop for PerfThreadTraceCollector { - fn drop(&mut self) { - if self.is_tracing { - // If we haven't stopped the collector already, stop it now. - self.stop_collector().unwrap(); - } - } -} - impl ThreadTraceCollector for PerfThreadTraceCollector { fn start_collector(&mut self) -> Result<(), HWTracerError> { - if self.is_tracing { - return Err(HWTracerError::AlreadyCollecting); - } - // At the time of writing, we have to use a fresh Perf file descriptor to ensure traces // start with a `PSB+` packet sequence. This is required for correct instruction-level and // block-level decoding. Therefore we have to re-initialise for each new tracing session. @@ -147,18 +131,13 @@ impl ThreadTraceCollector for PerfThreadTraceCollector { if !unsafe { hwt_perf_start_collector(self.ctx, &mut *trace, &mut cerr) } { return Err(cerr.into()); } - self.is_tracing = true; self.trace = Some(trace); Ok(()) } fn stop_collector(&mut self) -> Result, HWTracerError> { - if !self.is_tracing { - return Err(HWTracerError::AlreadyStopped); - } let mut cerr = PerfPTCError::new(); let rc = unsafe { hwt_perf_stop_collector(self.ctx, &mut cerr) }; - self.is_tracing = false; if !rc { return Err(cerr.into()); } @@ -257,36 +236,65 @@ mod tests { use super::{PerfCollectorConfig, PerfThreadTraceCollector}; use crate::{ collect::{ - test_helpers, ThreadTraceCollector, TraceCollectorBuilder, TraceCollectorConfig, - TraceCollectorKind, + test_helpers, ThreadTraceCollector, TraceCollector, TraceCollectorBuilder, + TraceCollectorConfig, TraceCollectorKind, }, errors::HWTracerError, test_helpers::work_loop, }; + fn mk_collector() -> TraceCollector { + TraceCollectorBuilder::new() + .kind(TraceCollectorKind::Perf) + .build() + .unwrap() + } + #[test] fn basic_collection() { - test_helpers::basic_collection(PerfThreadTraceCollector::default()); + test_helpers::basic_collection(mk_collector()); } #[test] pub fn repeated_collection() { - test_helpers::repeated_collection(PerfThreadTraceCollector::default()); + test_helpers::repeated_collection(mk_collector()); + } + + #[test] + pub fn repeated_collection_different_collectors() { + let tcs = [ + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + mk_collector(), + ]; + test_helpers::repeated_collection_different_collectors(tcs); } #[test] pub fn already_started() { - test_helpers::already_started(PerfThreadTraceCollector::default()); + test_helpers::already_started(mk_collector()); + } + + #[test] + pub fn already_started_different_collectors() { + test_helpers::already_started_different_collectors(mk_collector(), mk_collector()); } #[test] pub fn not_started() { - test_helpers::not_started(PerfThreadTraceCollector::default()); + test_helpers::not_started(mk_collector()); } #[test] fn concurrent_collection() { - test_helpers::concurrent_collection(&*TraceCollectorBuilder::new().build().unwrap()); + test_helpers::concurrent_collection(mk_collector()); } /// Check that a long trace causes the trace buffer to reallocate. diff --git a/src/decode/libipt/mod.rs b/src/decode/libipt/mod.rs index 4f21a80..a9d70f4 100644 --- a/src/decode/libipt/mod.rs +++ b/src/decode/libipt/mod.rs @@ -157,8 +157,7 @@ mod tests { use super::{LibIPTBlockIterator, PerfPTCError}; use crate::{ collect::{ - perf::PerfTrace, test_helpers::trace_closure, ThreadTraceCollector, - TraceCollectorBuilder, + perf::PerfTrace, test_helpers::trace_closure, TraceCollector, TraceCollectorBuilder, }, decode::{test_helpers, TraceDecoderKind}, errors::HWTracerError, @@ -311,11 +310,11 @@ mod tests { } /// Trace a closure and then decode it and check the block iterator agrees with ptxed. - fn trace_and_check_blocks(tracer: &mut dyn ThreadTraceCollector, f: F) + fn trace_and_check_blocks(tc: &TraceCollector, f: F) where F: FnOnce() -> u64, { - let trace = trace_closure(tracer, f); + let trace = trace_closure(&tc, f); let expects = get_expected_blocks(&trace); test_helpers::test_expected_blocks(trace, TraceDecoderKind::LibIPT, expects.iter()); } @@ -323,15 +322,15 @@ mod tests { /// Check that the block decoder agrees with the reference implementation in ptxed. #[test] fn versus_ptxed_short_trace() { - let tracer = TraceCollectorBuilder::new().build().unwrap(); - trace_and_check_blocks(&mut *unsafe { tracer.thread_collector() }, || work_loop(10)); + let tc = TraceCollectorBuilder::new().build().unwrap(); + trace_and_check_blocks(&tc, || work_loop(10)); } /// Check that the block decoder agrees ptxed on a (likely) empty trace; #[test] fn versus_ptxed_empty_trace() { - let tracer = TraceCollectorBuilder::new().build().unwrap(); - trace_and_check_blocks(&mut *unsafe { tracer.thread_collector() }, || work_loop(0)); + let tc = TraceCollectorBuilder::new().build().unwrap(); + trace_and_check_blocks(&tc, || work_loop(0)); } /// Check that our block decoder deals with traces involving the VDSO correctly. @@ -339,8 +338,8 @@ mod tests { fn versus_ptxed_vdso() { use libc::{clock_gettime, timespec, CLOCK_MONOTONIC}; - let tracer = TraceCollectorBuilder::new().build().unwrap(); - trace_and_check_blocks(&mut *unsafe { tracer.thread_collector() }, || { + let tc = TraceCollectorBuilder::new().build().unwrap(); + trace_and_check_blocks(&tc, || { let mut res = 0; let mut tv = timespec { tv_sec: 0, @@ -359,10 +358,8 @@ mod tests { /// Check that the block decoder agrees with ptxed on long trace. #[test] fn versus_ptxed_long_trace() { - let tracer = TraceCollectorBuilder::new().build().unwrap(); - trace_and_check_blocks(&mut *unsafe { tracer.thread_collector() }, || { - work_loop(3000) - }); + let tc = TraceCollectorBuilder::new().build().unwrap(); + trace_and_check_blocks(&tc, || work_loop(3000)); } /// Check that a block iterator returns none after an error. @@ -393,10 +390,7 @@ mod tests { #[test] fn ten_times_as_many_blocks() { - let col = TraceCollectorBuilder::new().build().unwrap(); - test_helpers::ten_times_as_many_blocks( - &mut *unsafe { col.thread_collector() }, - TraceDecoderKind::LibIPT, - ); + let tc = TraceCollectorBuilder::new().build().unwrap(); + test_helpers::ten_times_as_many_blocks(tc, TraceDecoderKind::LibIPT); } } diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 965db4c..c512e64 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -86,7 +86,7 @@ impl TraceDecoderBuilder { mod test_helpers { use super::{TraceDecoder, TraceDecoderBuilder, TraceDecoderKind}; use crate::{ - collect::{test_helpers::trace_closure, ThreadTraceCollector}, + collect::{test_helpers::trace_closure, TraceCollector}, test_helpers::work_loop, Block, Trace, }; @@ -121,12 +121,9 @@ mod test_helpers { /// Trace two loops, one 10x larger than the other, then check the proportions match the number /// of block the trace passes through. - pub fn ten_times_as_many_blocks( - thr_col: &mut dyn ThreadTraceCollector, - decoder_kind: TraceDecoderKind, - ) { - let trace1 = trace_closure(thr_col, || work_loop(10)); - let trace2 = trace_closure(thr_col, || work_loop(100)); + pub fn ten_times_as_many_blocks(mut tc: TraceCollector, decoder_kind: TraceDecoderKind) { + let trace1 = trace_closure(&mut tc, || work_loop(10)); + let trace2 = trace_closure(&mut tc, || work_loop(100)); let dec: Box = TraceDecoderBuilder::new() .kind(decoder_kind) @@ -138,6 +135,6 @@ mod test_helpers { // Should be roughly 10x more blocks in trace2. It won't be exactly 10x, due to the stuff // we trace either side of the loop itself. On a smallish trace, that will be significant. - assert!(ct2 > ct1 * 9); + assert!(ct2 > ct1 * 8); } } diff --git a/tests/pt_chdir_rel.rs b/tests/pt_chdir_rel.rs index bc24243..9fff7bb 100644 --- a/tests/pt_chdir_rel.rs +++ b/tests/pt_chdir_rel.rs @@ -42,12 +42,10 @@ fn pt_chdir_rel() { // When we get here, we have a process that was invoked with a relative path. - let tcol = TraceCollectorBuilder::new().build().unwrap(); - let mut thr_col = unsafe { tcol.thread_collector() }; - - thr_col.start_collector().unwrap(); + let tc = TraceCollectorBuilder::new().build().unwrap(); + tc.start_thread_collector().unwrap(); println!("{}", work_loop(env::args().len() as u64)); - let trace = thr_col.stop_collector().unwrap(); + let trace = tc.stop_thread_collector().unwrap(); // Now check that the trace decoder can still find its objects after we change dir. let tdec = TraceDecoderBuilder::new().build().unwrap();