Skip to content
This repository has been archived by the owner on Nov 9, 2022. It is now read-only.

Commit

Permalink
Make the API more thread-safe.
Browse files Browse the repository at this point in the history
Before this change, `ThreadTraceCollector` was exposed to the consumer
of the hwtracer API. Calling `TraceCollector::thread_tracer()` would
give back a fresh instance each time. This API would allow the user to
request multiple `ThreadTraceCollector`s on the same thread and start
them all "collecting traces" concurrently. This leads to chaos as any
given thread can only be traced at most once at a time.

Until now we've swept this under the rug, warning the user not to do the
naughty thing unless they want to invoke UB, but the API still allows it
to happen.

This change hides `ThreadTraceCollector`s from the public API entirely,
leaving `TraceCollector`s to manage thread local instances of
`ThreadTraceCollector`.

The user now starts and stops trace collection of a thread with
`TraceCollector::start_thread_collector()` and
`TraceCollector::stop_thread_collector()`. `ThreadTraceCollector`s are
still there behind the scenes, but the user no longer sees them and thus
can't start two collectors tracing the same thread any more.

One subtle side-effect of this new API is that we can no longer
automatically stop a running trace collector when the
`ThreadTraceCollector` goes out of scope (via `Drop`). Since
`ThreadTraceCollector`s are stored (and owned) by thread locals, they
never really "fall out of scope", so the user must ensure they
consistently stop trace collection before trying to re-trace the same
thread.

It may be tempting to stop the current thread's `ThreadTraceCollector`
when a `TraceCollector` goes out of scope, but then we'd have to track
which `TraceCollector` started the `ThreadTraceCollector` and only stop
the collector if they match. This starts to sound pretty over the top,
and I think the proposed API is already a lot better than what we had
before. Further, `TraceCollector::start_thread_collector()` will return a
meaningful error message if the user has forgotten to stop a collector,
so it shouldn't be hard to track down instances of this blunder.

Fixes #101.
  • Loading branch information
vext01 committed Oct 21, 2022
1 parent 63cfd16 commit 13c466e
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 115 deletions.
7 changes: 3 additions & 4 deletions examples/simple_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ fn work() -> u32 {
/// The results are printed to discourage the compiler from optimising the computation out.
fn main() {
let bldr = TraceCollectorBuilder::new();
let col = bldr.build().unwrap();
let mut thr_col = unsafe { col.thread_collector() };
let tc = bldr.build().unwrap();

for i in 1..4 {
thr_col.start_collector().unwrap_or_else(|e| {
tc.start_thread_collector().unwrap_or_else(|e| {
panic!("Failed to start collector: {}", e);
});
let res = work();
let trace = thr_col.stop_collector().unwrap();
let trace = tc.stop_thread_collector().unwrap();
let name = format!("trace{}", i);
print_trace(trace, &name, res, 10);
}
Expand Down
138 changes: 89 additions & 49 deletions src/collect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::{errors::HWTracerError, Trace};
use core::arch::x86_64::__cpuid_count;
use libc::{size_t, sysconf, _SC_PAGESIZE};
use std::{convert::TryFrom, sync::LazyLock};
use std::{cell::RefCell, convert::TryFrom, sync::LazyLock};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;

Expand All @@ -22,19 +22,59 @@ static PERF_DFLT_AUX_BUFSIZE: LazyLock<size_t> = LazyLock::new(|| {

const PERF_DFLT_INITIAL_TRACE_BUFSIZE: size_t = 1024 * 1024; // 1MiB

/// The interface offered by all trace collectors.
pub trait TraceCollector: Send + Sync {
/// Obtain a `ThreadTraceCollector` for the current thread.
///
/// A thread may obtain multiple `ThreadTraceCollector`s but must only collect a trace with one
/// at a time.
///
/// FIXME: This API needs to be fixed:
/// https://github.com/ykjit/hwtracer/issues/101
thread_local! {
/// When `Some` holds the `ThreadTraceCollector` that is collecting a trace of the current
/// thread.
static THREAD_TRACE_COLLECTOR: RefCell<Option<Box<dyn ThreadTraceCollector>>> = RefCell::new(None);
}

/// The private innards of a `TraceCollector`.
pub(crate) trait TraceCollectorImpl: Send + Sync {
unsafe fn thread_collector(&self) -> Box<dyn ThreadTraceCollector>;
}

pub trait ThreadTraceCollector {
/// The public interface offered by all trace collectors.
pub struct TraceCollector {
col_impl: Box<dyn TraceCollectorImpl>,
}

impl TraceCollector {
pub(crate) fn new(col_impl: Box<dyn TraceCollectorImpl>) -> Self {
Self { col_impl }
}

/// Start collecting a trace of the current thread.
pub fn start_thread_collector(&self) -> Result<(), HWTracerError> {
THREAD_TRACE_COLLECTOR.with(|inner| {
let mut inner = inner.borrow_mut();
if inner.is_some() {
Err(HWTracerError::AlreadyCollecting)
} else {
let mut thr_col = unsafe { self.col_impl.thread_collector() };
thr_col.start_collector()?;
*inner = Some(thr_col);
Ok(())
}
})
}

/// Stop collecting a trace of the current thread.
pub fn stop_thread_collector(&self) -> Result<Box<dyn Trace>, HWTracerError> {
THREAD_TRACE_COLLECTOR.with(|inner| {
let mut inner = inner.borrow_mut();
if let Some(thr_col) = &mut *inner {
let ret = thr_col.stop_collector();
*inner = None;
ret
} else {
Err(HWTracerError::AlreadyStopped)
}
})
}
}

/// Represents a trace collection session for a single thread.
pub(crate) trait ThreadTraceCollector {
/// Start recording a trace.
///
/// Tracing continues until [stop_collector] is called.
Expand Down Expand Up @@ -182,13 +222,15 @@ impl TraceCollectorBuilder {
///
/// An error is returned if the requested collector is inappropriate for the platform or not
/// compiled in to hwtracer.
pub fn build(self) -> Result<Box<dyn TraceCollector>, HWTracerError> {
pub fn build(self) -> Result<TraceCollector, HWTracerError> {
let kind = self.config.kind();
kind.match_platform()?;
match self.config {
TraceCollectorConfig::Perf(_pt_conf) => {
#[cfg(collector_perf)]
return Ok(Box::new(PerfTraceCollector::new(_pt_conf)?));
return Ok(TraceCollector::new(Box::new(PerfTraceCollector::new(
_pt_conf,
)?)));
#[cfg(not(collector_perf))]
unreachable!();
}
Expand All @@ -198,81 +240,79 @@ impl TraceCollectorBuilder {

#[cfg(test)]
pub(crate) mod test_helpers {
use crate::{
collect::{ThreadTraceCollector, TraceCollector},
errors::HWTracerError,
test_helpers::work_loop,
Trace,
};
use crate::{collect::TraceCollector, errors::HWTracerError, test_helpers::work_loop, Trace};
use std::thread;

/// Trace a closure that returns a u64.
pub fn trace_closure<F>(tc: &mut dyn ThreadTraceCollector, f: F) -> Box<dyn Trace>
pub fn trace_closure<F>(tc: &TraceCollector, f: F) -> Box<dyn Trace>
where
F: FnOnce() -> u64,
{
tc.start_collector().unwrap();
tc.start_thread_collector().unwrap();
let res = f();
let trace = tc.stop_collector().unwrap();
let trace = tc.stop_thread_collector().unwrap();
println!("traced closure with result: {}", res); // To avoid over-optimisation.
trace
}

/// Check that starting and stopping a trace collector works.
pub fn basic_collection<T>(mut tracer: T)
where
T: ThreadTraceCollector,
{
let trace = trace_closure(&mut tracer, || work_loop(500));
pub fn basic_collection(tc: TraceCollector) {
let trace = trace_closure(&tc, || work_loop(500));
assert_ne!(trace.len(), 0);
}

/// Check that repeated usage of the same trace collector works.
pub fn repeated_collection<T>(mut tracer: T)
where
T: ThreadTraceCollector,
{
pub fn repeated_collection(tc: TraceCollector) {
for _ in 0..10 {
trace_closure(&mut tracer, || work_loop(500));
trace_closure(&tc, || work_loop(500));
}
}

/// Check that repeated collection using different collectors works.
pub fn repeated_collection_different_collectors(tcs: [TraceCollector; 10]) {
for i in 0..10 {
trace_closure(&tcs[i], || work_loop(500));
}
}

/// Check that starting a trace collector twice (without stopping tracing in between) makes
/// an appropriate error.
pub fn already_started<T>(mut tc: T)
where
T: ThreadTraceCollector,
{
tc.start_collector().unwrap();
match tc.start_collector() {
pub fn already_started(tc: TraceCollector) {
tc.start_thread_collector().unwrap();
match tc.start_thread_collector() {
Err(HWTracerError::AlreadyCollecting) => (),
_ => panic!(),
};
tc.stop_collector().unwrap();
tc.stop_thread_collector().unwrap();
}

/// Check that an attempt to trace the same thread using different collectors fails.
pub fn already_started_different_collectors(tc1: TraceCollector, tc2: TraceCollector) {
tc1.start_thread_collector().unwrap();
match tc2.start_thread_collector() {
Err(HWTracerError::AlreadyCollecting) => (),
_ => panic!(),
};
tc1.stop_thread_collector().unwrap();
}

/// Check that stopping an unstarted trace collector makes an appropriate error.
pub fn not_started<T>(mut tc: T)
where
T: ThreadTraceCollector,
{
match tc.stop_collector() {
pub fn not_started(tc: TraceCollector) {
match tc.stop_thread_collector() {
Err(HWTracerError::AlreadyStopped) => (),
_ => panic!(),
};
}

/// Check that traces can be collected concurrently.
pub fn concurrent_collection(tc: &dyn TraceCollector) {
pub fn concurrent_collection(tc: TraceCollector) {
for _ in 0..10 {
thread::scope(|s| {
let hndl = s.spawn(|| {
let mut thr_c1 = unsafe { tc.thread_collector() };
trace_closure(&mut *thr_c1, || work_loop(500));
trace_closure(&tc, || work_loop(500));
});

let mut thr_c2 = unsafe { tc.thread_collector() };
trace_closure(&mut *thr_c2, || work_loop(500));
trace_closure(&tc, || work_loop(500));
hndl.join().unwrap();
});
}
Expand Down
68 changes: 38 additions & 30 deletions src/collect/perf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::PerfCollectorConfig;
use crate::{
c_errors::PerfPTCError,
collect::{ThreadTraceCollector, TraceCollector},
collect::{ThreadTraceCollector, TraceCollectorImpl},
errors::HWTracerError,
Trace,
};
Expand Down Expand Up @@ -76,7 +76,7 @@ impl PerfTraceCollector {
}
}

impl TraceCollector for PerfTraceCollector {
impl TraceCollectorImpl for PerfTraceCollector {
unsafe fn thread_collector(&self) -> Box<dyn ThreadTraceCollector> {
Box::new(PerfThreadTraceCollector::new(self.config.clone()))
}
Expand All @@ -88,8 +88,6 @@ pub struct PerfThreadTraceCollector {
config: PerfCollectorConfig,
// Opaque C pointer representing the collector context.
ctx: *mut c_void,
// The state of the collector.
is_tracing: bool,
// The trace currently being collected, or `None`.
trace: Option<Box<PerfTrace>>,
}
Expand All @@ -99,7 +97,6 @@ impl PerfThreadTraceCollector {
Self {
config,
ctx: ptr::null_mut(),
is_tracing: false,
trace: None,
}
}
Expand All @@ -111,21 +108,8 @@ impl Default for PerfThreadTraceCollector {
}
}

impl Drop for PerfThreadTraceCollector {
fn drop(&mut self) {
if self.is_tracing {
// If we haven't stopped the collector already, stop it now.
self.stop_collector().unwrap();
}
}
}

impl ThreadTraceCollector for PerfThreadTraceCollector {
fn start_collector(&mut self) -> Result<(), HWTracerError> {
if self.is_tracing {
return Err(HWTracerError::AlreadyCollecting);
}

// At the time of writing, we have to use a fresh Perf file descriptor to ensure traces
// start with a `PSB+` packet sequence. This is required for correct instruction-level and
// block-level decoding. Therefore we have to re-initialise for each new tracing session.
Expand All @@ -147,18 +131,13 @@ impl ThreadTraceCollector for PerfThreadTraceCollector {
if !unsafe { hwt_perf_start_collector(self.ctx, &mut *trace, &mut cerr) } {
return Err(cerr.into());
}
self.is_tracing = true;
self.trace = Some(trace);
Ok(())
}

fn stop_collector(&mut self) -> Result<Box<dyn Trace>, HWTracerError> {
if !self.is_tracing {
return Err(HWTracerError::AlreadyStopped);
}
let mut cerr = PerfPTCError::new();
let rc = unsafe { hwt_perf_stop_collector(self.ctx, &mut cerr) };
self.is_tracing = false;
if !rc {
return Err(cerr.into());
}
Expand Down Expand Up @@ -257,36 +236,65 @@ mod tests {
use super::{PerfCollectorConfig, PerfThreadTraceCollector};
use crate::{
collect::{
test_helpers, ThreadTraceCollector, TraceCollectorBuilder, TraceCollectorConfig,
TraceCollectorKind,
test_helpers, ThreadTraceCollector, TraceCollector, TraceCollectorBuilder,
TraceCollectorConfig, TraceCollectorKind,
},
errors::HWTracerError,
test_helpers::work_loop,
};

fn mk_collector() -> TraceCollector {
TraceCollectorBuilder::new()
.kind(TraceCollectorKind::Perf)
.build()
.unwrap()
}

#[test]
fn basic_collection() {
test_helpers::basic_collection(PerfThreadTraceCollector::default());
test_helpers::basic_collection(mk_collector());
}

#[test]
pub fn repeated_collection() {
test_helpers::repeated_collection(PerfThreadTraceCollector::default());
test_helpers::repeated_collection(mk_collector());
}

#[test]
pub fn repeated_collection_different_collectors() {
let tcs = [
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
mk_collector(),
];
test_helpers::repeated_collection_different_collectors(tcs);
}

#[test]
pub fn already_started() {
test_helpers::already_started(PerfThreadTraceCollector::default());
test_helpers::already_started(mk_collector());
}

#[test]
pub fn already_started_different_collectors() {
test_helpers::already_started_different_collectors(mk_collector(), mk_collector());
}

#[test]
pub fn not_started() {
test_helpers::not_started(PerfThreadTraceCollector::default());
test_helpers::not_started(mk_collector());
}

#[test]
fn concurrent_collection() {
test_helpers::concurrent_collection(&*TraceCollectorBuilder::new().build().unwrap());
test_helpers::concurrent_collection(mk_collector());
}

/// Check that a long trace causes the trace buffer to reallocate.
Expand Down
Loading

0 comments on commit 13c466e

Please sign in to comment.