diff --git a/Cargo.lock b/Cargo.lock
index 2d73762f..3c6eb93d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -532,6 +532,7 @@ dependencies = [
  "cordyceps",
  "futures-util",
  "loom",
+ "mycelium-bitfield",
  "mycelium-util",
  "mycotest",
  "pin-project",
diff --git a/async/Cargo.toml b/async/Cargo.toml
index 5203394b..e29edcbe 100644
--- a/async/Cargo.toml
+++ b/async/Cargo.toml
@@ -10,6 +10,7 @@ license = "MIT"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+mycelium-bitfield = { path = "../bitfield" }
 mycelium-util = { path = "../util" }
 mycotest = { path = "../mycotest", default-features = false }
 cordyceps = { path = "../cordyceps" }
@@ -20,8 +21,10 @@ package = "tracing"
 default_features = false
 git = "https://github.com/tokio-rs/tracing"
 
+[dev-dependencies]
+futures-util = "0.3"
+
 [target.'cfg(loom)'.dev-dependencies]
 loom = { version = "0.5.5", features = ["futures"] }
 tracing_01 = { package = "tracing", version = "0.1", default_features = false }
-tracing_subscriber_03 = { package = "tracing-subscriber", version = "0.3.11", features = ["fmt"] }
-futures-util = "0.3"
\ No newline at end of file
+tracing_subscriber_03 = { package = "tracing-subscriber", version = "0.3.11", features = ["fmt"] }
\ No newline at end of file
diff --git a/async/src/scheduler.rs b/async/src/scheduler.rs
index 9028d22f..ef7d8f15 100644
--- a/async/src/scheduler.rs
+++ b/async/src/scheduler.rs
@@ -184,7 +184,7 @@ impl Future for Stub {
 
 #[cfg(all(test, not(loom)))]
 mod tests {
-    use super::test_util::Yield;
+    use super::test_util::{Chan, Yield};
     use super::*;
     use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
     use mycelium_util::sync::Lazy;
@@ -229,6 +229,57 @@ mod tests {
         assert!(!tick.has_remaining);
     }
 
+    #[test]
+    fn notify_future() {
+        static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
+        static COMPLETED: AtomicUsize = AtomicUsize::new(0);
+
+        let chan = Chan::new(1);
+
+        SCHEDULER.spawn({
+            let chan = chan.clone();
+            async move {
+                chan.wait().await;
+                COMPLETED.fetch_add(1, Ordering::SeqCst);
+            }
+        });
+
+        SCHEDULER.spawn(async move {
+            Yield::once().await;
+            chan.notify();
+        });
+
+        dbg!(SCHEDULER.tick());
+
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), 1);
+    }
+
+    #[test]
+    fn notify_external() {
+        static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
+        static COMPLETED: AtomicUsize = AtomicUsize::new(0);
+
+        let chan = Chan::new(1);
+
+        SCHEDULER.spawn({
+            let chan = chan.clone();
+            async move {
+                chan.wait().await;
+                COMPLETED.fetch_add(1, Ordering::SeqCst);
+            }
+        });
+
+        dbg!(SCHEDULER.tick());
+
+        std::thread::spawn(move || {
+            chan.notify();
+        });
+
+        dbg!(SCHEDULER.tick());
+
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), 1);
+    }
+
     #[test]
     fn many_yields() {
         static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
@@ -253,7 +304,7 @@ mod tests {
 
 #[cfg(all(test, loom))]
 mod loom {
-    use super::test_util::Yield;
+    use super::test_util::{Chan, Yield};
     use super::*;
     use crate::loom::{
         self,
@@ -261,6 +312,7 @@ mod loom {
         sync::{
             atomic::{AtomicBool, AtomicUsize, Ordering},
             Arc,
         },
+        thread,
     };
     use core::{
         future::Future,
@@ -316,6 +368,61 @@ mod loom {
         })
     }
 
+    #[test]
+    fn notify_external() {
+        loom::model(|| {
+            let scheduler = Scheduler::new();
+            let chan = Chan::new(1);
+            let it_worked = Arc::new(AtomicBool::new(false));
+
+            scheduler.spawn({
+                let it_worked = it_worked.clone();
+                let chan = chan.clone();
+                track_future(async move {
+                    chan.wait().await;
+                    it_worked.store(true, Ordering::Release);
+                })
+            });
+
+            thread::spawn(move || {
+                chan.notify();
+            });
+
+            while scheduler.tick().completed < 1 {
+                thread::yield_now();
+            }
+
+            assert!(it_worked.load(Ordering::Acquire));
+        })
+    }
+
+    #[test]
+    fn notify_future() {
+        loom::model(|| {
+            let scheduler = Scheduler::new();
+            let chan = Chan::new(1);
+            let it_worked = Arc::new(AtomicBool::new(false));
+
+            scheduler.spawn({
+                let it_worked = it_worked.clone();
+                let chan = chan.clone();
+                track_future(async move {
+                    chan.wait().await;
+                    it_worked.store(true, Ordering::Release);
+                })
+            });
+
+            scheduler.spawn(async move {
+                Yield::once().await;
+                chan.notify();
+            });
+
+            test_dbg!(scheduler.tick());
+
+            assert!(it_worked.load(Ordering::Acquire));
+        })
+    }
+
     #[test]
     fn schedule_many() {
         const TASKS: usize = 10;
@@ -391,6 +498,8 @@ mod test_util {
         task::{Context, Poll},
     };
 
+    pub(crate) use crate::wait::cell::test_util::Chan;
+
     pub(crate) struct Yield {
         yields: usize,
     }
diff --git a/async/src/task.rs b/async/src/task.rs
index 6331bf95..5b44d4b4 100644
--- a/async/src/task.rs
+++ b/async/src/task.rs
@@ -9,16 +9,17 @@ use cordyceps::{mpsc_queue, Linked};
 pub use core::task::{Context, Poll, Waker};
 use core::{
     any::type_name,
-    fmt,
     future::Future,
+    mem,
     pin::Pin,
     ptr::NonNull,
     task::{RawWaker, RawWakerVTable},
 };
+use mycelium_util::fmt;
 
 mod state;
 
-use self::state::StateVar;
+use self::state::{OrDrop, ScheduleAction, StateCell};
 
 #[derive(Debug)]
 pub(crate) struct TaskRef(NonNull<Header>);
@@ -27,7 +28,7 @@ pub(crate) struct TaskRef(NonNull<Header>);
 #[derive(Debug)]
 pub(crate) struct Header {
     run_queue: mpsc_queue::Links<Header>,
-    state: StateVar,
+    state: StateCell,
     // task_list: list::Links,
     vtable: &'static Vtable,
 }
@@ -62,8 +63,8 @@ macro_rules! trace_task {
     ($ptr:expr, $f:ty, $method:literal) => {
         tracing::trace!(
             ptr = ?$ptr,
-            concat!("Task::<{}>::", $method),
-            type_name::<<$f>::Output>()
+            output = %type_name::<<$f>::Output>(),
+            concat!("Task::", $method),
         );
     };
 }
@@ -86,7 +87,7 @@ impl<F: Future, S: Schedule> Task<F, S> {
             header: Header {
                 run_queue: mpsc_queue::Links::new(),
                 vtable: &Self::TASK_VTABLE,
-                state: StateVar::new(),
+                state: StateCell::new(),
             },
             scheduler,
             inner: UnsafeCell::new(Cell::Future(future)),
@@ -94,14 +95,19 @@ impl<F: Future, S: Schedule> Task<F, S> {
     }
 
     fn raw_waker(this: *const Self) -> RawWaker {
-        unsafe { (*this).header.state.clone_ref() };
         RawWaker::new(this as *const (), &Self::WAKER_VTABLE)
     }
 
+    #[inline]
+    fn state(&self) -> &StateCell {
+        &self.header.state
+    }
+
     unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
         trace_task!(ptr, F, "clone_waker");
-
-        Self::raw_waker(ptr as *const Self)
+        let this = ptr as *const Self;
+        (*this).state().clone_ref();
+        Self::raw_waker(this)
     }
 
     unsafe fn drop_waker(ptr: *const ()) {
@@ -115,13 +121,31 @@ impl<F: Future, S: Schedule> Task<F, S> {
         trace_task!(ptr, F, "wake_by_val");
 
         let this = non_null(ptr as *mut ()).cast::<Self>();
-        Self::schedule(this);
+        match test_dbg!(this.as_ref().state().wake_by_val()) {
+            OrDrop::Drop => drop(Box::from_raw(this.as_ptr())),
+            OrDrop::Action(ScheduleAction::Enqueue) => {
+                // the task should be enqueued.
+                //
+                // in the case that the task is enqueued, the state
+                // transition does *not* decrement the reference count. this is
+                // in order to avoid dropping the task while it is being
+                // scheduled. one reference is consumed by enqueuing the task...
+                Self::schedule(this);
+                // now that the task has been enqueued, decrement the reference
+                // count to drop the waker that performed the `wake_by_val`.
+                Self::drop_ref(this);
+            }
+            OrDrop::Action(ScheduleAction::None) => {}
+        }
     }
 
     unsafe fn wake_by_ref(ptr: *const ()) {
         trace_task!(ptr, F, "wake_by_ref");
 
-        Self::schedule(non_null(ptr as *mut ()).cast::<Self>())
+        let this = non_null(ptr as *mut ()).cast::<Self>();
+        if this.as_ref().state().wake_by_ref() == ScheduleAction::Enqueue {
+            Self::schedule(this);
+        }
     }
 
     #[inline(always)]
@@ -134,7 +158,7 @@ impl<F: Future, S: Schedule> Task<F, S> {
     #[inline]
     unsafe fn drop_ref(this: NonNull<Self>) {
         trace_task!(this, F, "drop_ref");
-        if !this.as_ref().header.state.drop_ref() {
+        if !this.as_ref().state().drop_ref() {
             return;
         }
 
@@ -143,13 +167,36 @@ impl<F: Future, S: Schedule> Task<F, S> {
     unsafe fn poll(ptr: NonNull<Header>) -> Poll<()> {
         trace_task!(ptr, F, "poll");
-        let ptr = ptr.cast::<Self>();
-        let waker = Waker::from_raw(Self::raw_waker(ptr.as_ptr()));
+        let mut this = ptr.cast::<Self>();
+        test_trace!(task = ?fmt::alt(this.as_ref()));
+        // try to transition the task to the polling state
+        let state = &this.as_ref().state();
+        match test_dbg!(state.start_poll()) {
+            // transitioned successfully!
+            Ok(_) => {}
+            Err(_state) => {
+                // TODO(eliza): could run the dealloc glue here instead of going
+                // through a ref cycle?
+                return Poll::Ready(());
+            }
+        }
+
+        // wrap the waker in `ManuallyDrop` because we're converting it from an
+        // existing task ref, rather than incrementing the task ref count. if
+        // this waker is consumed during the poll, we don't want to decrement
+        // its ref count when the poll ends.
+        let waker = mem::ManuallyDrop::new(Waker::from_raw(Self::raw_waker(this.as_ptr())));
         let cx = Context::from_waker(&waker);
-        let pin = Pin::new_unchecked(ptr.cast::<Self>().as_mut());
+
+        // actually poll the task
+        let pin = Pin::new_unchecked(this.as_mut());
         let poll = pin.poll_inner(cx);
-        if poll.is_ready() {
-            Self::drop_ref(ptr);
+
+        // post-poll state transition
+        match test_dbg!(state.end_poll(poll.is_ready())) {
+            OrDrop::Drop => drop(Box::from_raw(this.as_ptr())),
+            OrDrop::Action(ScheduleAction::Enqueue) => Self::schedule(this),
+            OrDrop::Action(ScheduleAction::None) => {}
         }
 
         poll
@@ -185,9 +232,9 @@ unsafe impl<F: Future, S> Sync for Task<F, S> {}
 impl<F: Future, S> fmt::Debug for Task<F, S> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("Task")
-            .field("future_type", &type_name::<F>())
-            .field("output_type", &type_name::<F::Output>())
-            .field("scheduler_type", &type_name::<S>())
+            // .field("future_type", &fmt::display(type_name::<F>()))
+            .field("output_type", &fmt::display(type_name::<F::Output>()))
+            .field("scheduler_type", &fmt::display(type_name::<S>()))
             .field("header", &self.header)
             .field("inner", &self.inner)
             .finish()
diff --git a/async/src/task/state.rs b/async/src/task/state.rs
index 4a7ef8e2..36b2fc72 100644
--- a/async/src/task/state.rs
+++ b/async/src/task/state.rs
@@ -3,69 +3,180 @@ use crate::loom::sync::atomic::{
     Ordering::{self, *},
 };
 use core::fmt;
-use mycelium_util::bits::PackUsize;
 
-#[derive(Clone, Copy)]
-pub(crate) struct State(usize);
+mycelium_bitfield::bitfield! {
+    /// A snapshot of a task's current state.
+    #[derive(PartialEq, Eq)]
+    pub(crate) struct State<usize> {
+        /// If set, this task is currently being polled.
+        pub(crate) const POLLING: bool;
+
+        /// If set, this task's [`Waker`] has been woken.
+        ///
+        /// [`Waker`]: core::task::Waker
+        pub(crate) const WOKEN: bool;
+
+        /// If set, this task's [`Future`] has completed (i.e., it has returned
+        /// [`Poll::Ready`]).
+        ///
+        /// [`Future`]: core::future::Future
+        /// [`Poll::Ready`]: core::task::Poll::Ready
+        pub(crate) const COMPLETED: bool;
+
+        /// The number of currently live references to this task.
+        ///
+        /// When this is 0, the task may be deallocated.
+        const REFS = ..;
+    }
+
+}
+
+/// An atomic cell that stores a task's current [`State`].
 #[repr(transparent)]
-pub(super) struct StateVar(AtomicUsize);
+pub(super) struct StateCell(AtomicUsize);
 
-impl State {
-    const RUNNING: PackUsize = PackUsize::least_significant(1);
-    const NOTIFIED: PackUsize = Self::RUNNING.next(1);
-    const COMPLETED: PackUsize = Self::NOTIFIED.next(1);
-    const REFS: PackUsize = Self::COMPLETED.remaining();
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub(super) enum ScheduleAction {
+    /// The task should be enqueued.
+    Enqueue,
 
-    const REF_ONE: usize = Self::REFS.first_bit();
-    const REF_MAX: usize = Self::REFS.raw_mask();
-    // const STATE_MASK: usize =
-    //     Self::RUNNING.raw_mask() | Self::NOTIFIED.raw_mask() | Self::COMPLETED.raw_mask();
+    /// The task does not need to be enqueued.
+    None,
+}
 
-    #[inline]
-    pub(crate) fn is_running(self) -> bool {
-        Self::RUNNING.contained_in_all(self.0)
-    }
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub(super) enum OrDrop<T> {
+    /// Another action should be performed.
+    Action(T),
 
-    #[inline]
-    pub(crate) fn is_notified(self) -> bool {
-        Self::NOTIFIED.contained_in_all(self.0)
-    }
+    /// The task should be deallocated.
+    Drop,
+}
 
-    #[inline]
-    pub(crate) fn is_completed(self) -> bool {
-        Self::NOTIFIED.contained_in_all(self.0)
-    }
+pub(super) type WakeAction = OrDrop<ScheduleAction>;
 
+impl State {
     #[inline]
     pub(crate) fn ref_count(self) -> usize {
-        Self::REFS.unpack(self.0)
+        self.get(Self::REFS)
     }
-}
 
-impl fmt::Debug for State {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_struct("State")
-            .field("running", &self.is_running())
-            .field("notified", &self.is_notified())
-            .field("completed", &self.is_completed())
-            .field("ref_count", &self.ref_count())
-            .field("bits", &format_args!("{:#b}", self.0))
-            .finish()
+    fn drop_ref(self) -> Self {
+        Self(self.0 - REF_ONE)
     }
-}
 
-impl fmt::Binary for State {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "State({:#b})", self.0)
+    fn clone_ref(self) -> Self {
+        Self(self.0 + REF_ONE)
     }
 }
 
-// === impl StateVar ===
+const REF_ONE: usize = State::REFS.first_bit();
+const REF_MAX: usize = State::REFS.raw_mask();
+
+// === impl StateCell ===
 
-impl StateVar {
+impl StateCell {
     pub fn new() -> Self {
-        Self(AtomicUsize::new(State::REF_ONE))
+        Self(AtomicUsize::new(REF_ONE))
+    }
+
+    pub(super) fn start_poll(&self) -> Result<State, State> {
+        self.transition(|state| {
+            // Cannot start polling a task which is being polled on another
+            // thread.
+            if test_dbg!(state.get(State::POLLING)) {
+                return Err(*state);
+            }
+
+            // Cannot start polling a completed task.
+            if test_dbg!(state.get(State::COMPLETED)) {
+                return Err(*state);
+            }
+
+            let new_state = state
+                // The task is now being polled.
+                .with(State::POLLING, true)
+                // If the task was woken, consume the wakeup.
+                .with(State::WOKEN, false);
+            *state = new_state;
+            Ok(new_state)
+        })
+    }
+
+    pub(super) fn end_poll(&self, completed: bool) -> WakeAction {
+        self.transition(|state| {
+            // Cannot end a poll if a task is not being polled!
+            debug_assert!(state.get(State::POLLING));
+            debug_assert!(!state.get(State::COMPLETED));
+            let next_state = state
+                .with(State::POLLING, false)
+                .with(State::COMPLETED, completed);
+
+            // Was the task woken during the poll?
+            if !test_dbg!(completed) && test_dbg!(state.get(State::WOKEN)) {
+                *state = test_dbg!(next_state);
+                return OrDrop::Action(ScheduleAction::Enqueue);
+            }
+
+            let next_state = test_dbg!(next_state.drop_ref());
+            *state = next_state;
+
+            if next_state.ref_count() == 0 {
+                OrDrop::Drop
+            } else {
+                OrDrop::Action(ScheduleAction::None)
+            }
+        })
+    }
+
+    /// Transition to the woken state by value, returning a [`WakeAction`].
+    pub(super) fn wake_by_val(&self) -> WakeAction {
+        self.transition(|state| {
+            // If the task was woken *during* a poll, it will be re-queued by the
+            // scheduler at the end of the poll if needed, so don't enqueue it now.
+            if test_dbg!(state.get(State::POLLING)) {
+                *state = state.with(State::WOKEN, true).drop_ref();
+                assert!(state.ref_count() > 0);
+
+                return OrDrop::Action(ScheduleAction::None);
+            }
+
+            // If the task is already completed or woken, we don't need to
+            // requeue it, but decrement the ref count for the waker that was
+            // used for this wakeup.
+            if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
+                let new_state = state.drop_ref();
+                *state = new_state;
+                return if new_state.ref_count() == 0 {
+                    OrDrop::Drop
+                } else {
+                    OrDrop::Action(ScheduleAction::None)
+                };
+            }
+
+            // Otherwise, transition to the notified state and enqueue the task.
+            *state = state.with(State::WOKEN, true).clone_ref();
+            OrDrop::Action(ScheduleAction::Enqueue)
+        })
+    }
+
+    /// Transition to the woken state by ref, returning a [`ScheduleAction`].
+    pub(super) fn wake_by_ref(&self) -> ScheduleAction {
+        self.transition(|state| {
+            if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
+                return ScheduleAction::None;
+            }
+
+            if test_dbg!(state.get(State::POLLING)) {
+                state.set(State::WOKEN, true);
+                return ScheduleAction::None;
+            }
+
+            *state = state.with(State::WOKEN, true).clone_ref();
+            ScheduleAction::Enqueue
+        })
     }
 
     #[inline]
@@ -81,7 +192,8 @@ impl StateVar {
         // another must already provide any required synchronization.
         //
         // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html)
-        let old_refs = test_dbg!(self.0.fetch_add(State::REF_ONE, Relaxed));
+        let old_refs = self.0.fetch_add(REF_ONE, Relaxed);
+        test_dbg!(State::REFS.unpack(old_refs));
 
         // However we need to guard against massive refcounts in case someone
         // is `mem::forget`ing tasks. If we don't do this the count can overflow
@@ -92,19 +204,27 @@ impl StateVar {
         //
         // We abort because such a program is incredibly degenerate, and we
         // don't care to support it.
-        if test_dbg!(old_refs > State::REF_MAX) {
+        if test_dbg!(old_refs > REF_MAX) {
             panic!("task reference count overflow");
         }
     }
 
     #[inline]
     pub(super) fn drop_ref(&self) -> bool {
-        // Because `cores` is already atomic, we do not need to synchronize
-        // with other threads unless we are going to delete the task.
-        let old_refs = test_dbg!(self.0.fetch_sub(State::REF_ONE, Release));
+        // We do not need to synchronize with other cores unless we are going to
+        // delete the task.
+        let old_refs = self.0.fetch_sub(REF_ONE, Release);
+
+        // Manually shift over the refcount to clear the state bits. We don't
+        // use the packing spec here, because it would also mask out any high
+        // bits, and we can avoid doing the bitwise-and (since there are no
+        // higher bits that are not part of the ref count). This is probably a
+        // premature optimization lol.
+        let old_refs = old_refs >> State::REFS.least_significant_index();
+        test_dbg!(State::REFS.unpack(old_refs));
 
         // Did we drop the last ref?
-        if test_dbg!(old_refs != State::REF_ONE) {
+        if test_dbg!(old_refs) > 1 {
             return false;
         }
@@ -115,25 +235,49 @@ impl StateVar {
     pub(super) fn load(&self, order: Ordering) -> State {
         State(self.0.load(order))
     }
+
+    /// Advance this task's state by running the provided
+    /// `transition` function on the current [`State`].
+    fn transition<T>(&self, mut transition: impl FnMut(&mut State) -> T) -> T {
+        let mut current = self.load(Acquire);
+        loop {
+            let mut next = test_dbg!(current);
+            // Run the transition function.
+            let res = transition(&mut next);
+
+            if current.0 == next.0 {
+                return res;
+            }
+
+            match self
+                .0
+                .compare_exchange_weak(current.0, next.0, AcqRel, Acquire)
+            {
+                Ok(_) => return res,
+                Err(actual) => current = State(actual),
+            }
+        }
+    }
 }
 
-impl fmt::Debug for StateVar {
+impl fmt::Debug for StateCell {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         self.load(Relaxed).fmt(f)
     }
 }
 
-#[cfg(test)]
+#[cfg(all(test, not(loom)))]
 mod tests {
     use super::*;
 
     #[test]
     fn packing_specs_valid() {
-        PackUsize::assert_all_valid(&[
-            ("RUNNING", State::RUNNING),
-            ("NOTIFIED", State::NOTIFIED),
-            ("COMPLETED", State::COMPLETED),
-            ("REFS", State::REFS),
-        ])
+        State::assert_valid()
+    }
+
+    #[test]
+    fn debug_alt() {
+        let state = StateCell::new();
+        println!("{:#?}", state);
     }
 }
diff --git a/async/src/util.rs b/async/src/util.rs
index 0b55d33a..d1156ecd 100644
--- a/async/src/util.rs
+++ b/async/src/util.rs
@@ -40,7 +40,7 @@ macro_rules! test_trace {
 macro_rules! fmt_bits {
     ($self: expr, $f: expr, $has_states: ident, $($name: ident),+) => {
         $(
-            if $self.contains(Self::$name) {
+            if $self.is(Self::$name) {
                 if $has_states {
                     $f.write_str(" | ")?;
                 }
diff --git a/async/src/wait.rs b/async/src/wait.rs
index cc4d6841..b7309e52 100644
--- a/async/src/wait.rs
+++ b/async/src/wait.rs
@@ -3,14 +3,14 @@
 //! This module implements two types of structure for waiting: a [`WaitCell`],
 //! which stores a *single* waiting task, and a wait *queue*, which
 //! stores a queue of waiting tasks.
-mod cell;
+pub(crate) mod cell;
 pub use cell::WaitCell;
 
 use core::task::Poll;
 
 /// An error indicating that a [`WaitCell`] or queue was closed while attempting
 /// to register a waiter.
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub struct Closed(());
 
 pub type WaitResult = Result<(), Closed>;
diff --git a/async/src/wait/cell.rs b/async/src/wait/cell.rs
index 55c6bac1..25f6403e 100644
--- a/async/src/wait/cell.rs
+++ b/async/src/wait/cell.rs
@@ -56,16 +56,16 @@ impl WaitCell {
 }
 
 impl WaitCell {
-    pub fn wait(&self, waker: &Waker) -> Poll<WaitResult> {
+    pub fn poll_wait(&self, waker: &Waker) -> Poll<WaitResult> {
         tracing::trace!(wait_cell = ?fmt::ptr(self), ?waker, "registering waker");
 
         // this is based on tokio's AtomicWaker synchronization strategy
         match test_dbg!(self.compare_exchange(State::WAITING, State::PARKING, Acquire)) {
             // someone else is notifying, so don't wait!
-            Err(actual) if test_dbg!(actual.contains(State::CLOSED)) => {
+            Err(actual) if test_dbg!(actual.is(State::CLOSED)) => {
                 return wait::closed();
             }
-            Err(actual) if test_dbg!(actual.contains(State::NOTIFYING)) => {
+            Err(actual) if test_dbg!(actual.is(State::NOTIFYING)) => {
                 waker.wake_by_ref();
                 crate::loom::hint::spin_loop();
                 return wait::notified();
@@ -111,7 +111,7 @@ impl WaitCell {
             waker.wake();
         }
 
-        if test_dbg!(state.contains(State::CLOSED)) {
+        if test_dbg!(state.is(State::CLOSED)) {
             wait::closed()
         } else {
             wait::notified()
@@ -197,7 +197,7 @@ impl State {
     const NOTIFYING: Self = Self(0b10);
     const CLOSED: Self = Self(0b100);
 
-    fn contains(self, Self(state): Self) -> bool {
+    fn is(self, Self(state): Self) -> bool {
         self.0 & state == state
     }
 }
@@ -238,87 +238,114 @@ impl fmt::Debug for State {
     }
 }
 
-#[cfg(all(loom, test))]
-mod loom {
+#[cfg(test)]
+#[allow(dead_code)]
+pub(crate) mod test_util {
     use super::*;
-    use crate::loom::{
-        future,
-        sync::atomic::{AtomicUsize, Ordering::Relaxed},
-        thread,
-    };
-    use core::task::Poll;
+
+    use crate::loom::sync::atomic::{AtomicUsize, Ordering::Relaxed};
     use std::sync::Arc;
 
-    struct Chan {
+    #[derive(Debug)]
+    pub(crate) struct Chan {
         num: AtomicUsize,
         task: WaitCell,
+        num_notify: usize,
     }
 
-    const NUM_NOTIFY: usize = 2;
+    impl Chan {
+        pub(crate) fn new(num_notify: usize) -> Arc<Self> {
+            Arc::new(Self {
+                num: AtomicUsize::new(0),
+                task: WaitCell::new(),
+                num_notify,
+            })
+        }
 
-    async fn wait_on(chan: Arc<Chan>) {
-        futures_util::future::poll_fn(move |cx| {
-            let res = test_dbg!(chan.task.wait(cx.waker()));
+        pub(crate) async fn wait(self: Arc<Self>) {
+            let this = Arc::downgrade(&self);
+            drop(self);
+            futures_util::future::poll_fn(move |cx| {
+                let this = match this.upgrade() {
+                    Some(this) => this,
+                    None => return Poll::Ready(()),
+                };
 
-            if NUM_NOTIFY == chan.num.load(Relaxed) {
-                return Poll::Ready(());
-            }
+                let res = test_dbg!(this.task.poll_wait(cx.waker()));
 
-            if res.is_ready() {
-                return Poll::Ready(());
-            }
+                if this.num_notify == this.num.load(Relaxed) {
+                    return Poll::Ready(());
+                }
+
+                if res.is_ready() {
+                    return Poll::Ready(());
+                }
+
+                Poll::Pending
+            })
+            .await
+        }
+
+        pub(crate) fn notify(&self) {
+            self.num.fetch_add(1, Relaxed);
+            self.task.notify();
+        }
 
-            Poll::Pending
-        })
-        .await
+        pub(crate) fn close(&self) {
+            self.num.fetch_add(1, Relaxed);
+            self.task.close();
+        }
     }
 
+    impl Drop for Chan {
+        fn drop(&mut self) {
+            tracing::debug!(chan = ?fmt::alt(self), "drop")
+        }
+    }
+}
+
+#[cfg(all(loom, test))]
+mod loom {
+    use super::*;
+    use crate::loom::{future, thread};
+
+    const NUM_NOTIFY: usize = 2;
+
     #[test]
-    fn basic_notification() {
+    fn basic_latch() {
         crate::loom::model(|| {
-            let chan = Arc::new(Chan {
-                num: AtomicUsize::new(0),
-                task: WaitCell::new(),
-            });
+            let chan = test_util::Chan::new(NUM_NOTIFY);
 
             for _ in 0..NUM_NOTIFY {
                 let chan = chan.clone();
-                thread::spawn(move || {
-                    chan.num.fetch_add(1, Relaxed);
-                    chan.task.notify();
-                });
+                thread::spawn(move || chan.notify());
             }
 
-            future::block_on(wait_on(chan));
+            future::block_on(chan.wait());
         });
     }
 
     #[test]
     fn close() {
         crate::loom::model(|| {
-            let chan = Arc::new(Chan {
-                num: AtomicUsize::new(0),
-                task: WaitCell::new(),
-            });
+            let chan = test_util::Chan::new(NUM_NOTIFY);
 
             thread::spawn({
                 let chan = chan.clone();
                 move || {
-                    chan.num.fetch_add(1, Relaxed);
-                    chan.task.notify();
+                    chan.notify();
                 }
             });
 
             thread::spawn({
                 let chan = chan.clone();
                 move || {
-                    chan.num.fetch_add(1, Relaxed);
-                    chan.task.close();
+                    chan.close();
                 }
             });
 
-            future::block_on(wait_on(chan));
+            future::block_on(chan.wait());
         });
     }
 }
diff --git a/bitfield/src/bitfield.rs b/bitfield/src/bitfield.rs
index db3af151..f2d2ad54 100644
--- a/bitfield/src/bitfield.rs
+++ b/bitfield/src/bitfield.rs
@@ -169,14 +169,15 @@
 ///
 /// let my_bitfield = TypedBitfield::from_bits(0b0011_0101_1001_1110);
 /// let formatted = format!("{my_bitfield}");
+/// println!("{formatted}");
 /// let expected = r#"
-/// 00000000000000000011010110011110
-///               └┬─────┘││└┬───┘└┤
-///                │      ││ │     └ ENUM_VALUE: Baz (10)
-///                │      ││ └────── SOME_BITS: 39 (100111)
-///                │      │└─────────── FLAG_1: true (1)
-///                │      └──────────── FLAG_2: false (0)
-///                └─────────────────── A_BYTE: 13 (00001101)
+/// 000011010110011110
+/// └┬─────┘││└┬───┘└┤
+///  │      ││ │     └ ENUM_VALUE: Baz (10)
+///  │      ││ └────── SOME_BITS: 39 (100111)
+///  │      │└─────────── FLAG_1: true (1)
+///  │      └──────────── FLAG_2: false (0)
+///  └─────────────────── A_BYTE: 13 (00001101)
 /// "#.trim_start();
 /// assert_eq!(formatted, expected);
 /// ```
@@ -332,10 +333,25 @@ macro_rules! bitfield {
         #[automatically_derived]
         impl core::fmt::Display for $Name {
             fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+                let mut truncated = self.0.leading_zeros();
+                let mut width = $T::BITS - truncated;
+                let most_sig_field = Self::FIELDS[Self::FIELDS.len() - 1].1;
+                if width < most_sig_field.least_significant_index() + 1 {
+                    width = most_sig_field.least_significant_index() + 3;
+                    truncated = $T::BITS - (most_sig_field.least_significant_index() + 3);
+                }
+                let diff = most_sig_field.most_significant_index() as i32 - (width as i32);
+                if diff > 5 {
+                    width += 1;
+                    truncated -= 1;
+                } else if diff > 0 {
+                    width += diff as u32;
+                    truncated -= diff as u32;
+                }
                 f.pad("")?;
-                writeln!(f, "{:0width$b}", self.0, width = $T::BITS as usize)?;
+                writeln!(f, "{:0width$b}", self.0, width = width as usize)?;
                 f.pad("")?;
-                let mut cur_pos = $T::BITS;
+                let mut cur_pos = width;
                 let mut max_len = 0;
                 let mut rem = 0;
                 let mut fields = Self::FIELDS.iter().rev().peekable();
@@ -345,6 +361,7 @@ macro_rules! bitfield {
                         cur_pos -= 1;
                     }
                     let bits = field.bits();
+                    let mut sub_bits = bits;
                     match (name, bits) {
                         (name, bits) if name.starts_with("_") => {
                             for _ in 0..bits {
@@ -356,8 +373,15 @@ macro_rules! bitfield {
                         (_, 1) => f.write_str("│")?,
                         (_, 2) => f.write_str("└┤")?,
                         (_, bits) => {
-                            f.write_str("└┬")?;
-                            for _ in 0..(bits - 3) {
+                            let n_underlines = if field.most_significant_index() > cur_pos {
+                                f.write_str("⋯ ┬")?;
+                                sub_bits -= (field.most_significant_index() - width);
+                                4
+                            } else {
+                                f.write_str("└┬")?;
+                                3
+                            };
+                            for _ in 0..(sub_bits.saturating_sub(n_underlines)) {
                                 f.write_str("─")?;
                             }
                             f.write_str("┘")?;
@@ -365,11 +389,11 @@ macro_rules! bitfield {
                     }
 
                     if fields.peek().is_none() {
-                        rem = cur_pos - (bits - 1);
+                        rem = cur_pos - (sub_bits - 1);
                     }
 
                     max_len = core::cmp::max(max_len, name.len());
-                    cur_pos -= field.bits()
+                    cur_pos -= sub_bits;
                 }
 
                 f.write_str("\n")?;
@@ -379,7 +403,7 @@ macro_rules! bitfield {
                     let name = stringify!($Field);
                     if !name.starts_with("_") {
                         f.pad("")?;
-                        cur_pos = $T::BITS;
+                        cur_pos = width;
 
                         for (cur_name, cur_field) in Self::FIELDS.iter().rev() {
                             while cur_pos > cur_field.most_significant_index() {
                                 f.write_str(" ")?;
@@ -390,7 +414,13 @@ macro_rules! bitfield {
                                 break;
                             }
 
-                            let bits = cur_field.bits();
+                            let mut bits = cur_field.bits();
+                            let whitespace = if cur_field.most_significant_index() > cur_pos {
+                                bits -= (cur_field.most_significant_index() - width);
+                                true
+                            } else {
+                                false
+                            };
                             match (cur_name, bits) {
                                 (name, bits) if name.starts_with("_") => {
                                     for _ in 0..bits {
@@ -398,6 +428,12 @@ macro_rules! bitfield {
                                         f.write_str(" ")?;
                                     }
                                 }
                                 (_, 1) => f.write_str("│")?,
+                                (_, bits) if whitespace => {
+                                    f.write_str(" │")?;
+                                    for _ in 0..bits.saturating_sub(3) {
+                                        f.write_str(" ")?;
+                                    }
+                                }
                                 (_, bits) => {
                                     f.write_str(" │")?;
                                     for _ in 0..(bits - 2) {
@@ -409,10 +445,14 @@ macro_rules! bitfield {
                                 cur_pos -= bits;
                             }
 
-                            let field_bits = field.bits();
+                            let mut field_bits = field.bits();
                             if field_bits == 1 {
                                 f.write_str("└")?;
                                 cur_pos -= 1;
+                            } else if field.most_significant_index() > width {
+                                f.write_str(" └")?;
+                                cur_pos -= 3;
+                                field_bits -= truncated;
                             } else {
                                 f.write_str(" └")?;
                                 cur_pos -= 2;
@@ -633,7 +673,7 @@ mod tests {
             .with(TestBitfield::OF, 0)
             .with(TestBitfield::FUN, 9);
         println!("{}", test_bitfield);
-
+        println!("empty:\n{}", TestBitfield::new());
         let test_debug = TestDebug {
             value: 42,
             bits: test_bitfield,
@@ -644,6 +684,24 @@ mod tests {
         println!("test_debug: {:?}", test_debug)
     }
 
+    #[test]
+    fn many_leading_zeros() {
+        bitfield! {
+            #[allow(dead_code)]
+            struct ManyLeadingZeros<usize> {
+                const A = 4;
+                const B: bool;
+                const C: bool;
+                const D = ..;
+            }
+        }
+
+        let bitfield = ManyLeadingZeros::from_bits(0b1100_1011_0110);
+        println!("{bitfield}");
+        let empty = ManyLeadingZeros::new();
+        println!("{empty}");
+    }
+
     #[test]
     fn macro_bitfield_valid() {
         TestBitfield::assert_valid();
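
A note on the pattern at the heart of this patch: every task-state change in async/src/task/state.rs (start_poll, end_poll, wake_by_val, wake_by_ref) funnels through StateCell::transition, which retries a caller-supplied closure in a compare_exchange_weak loop, so each transition is a single atomic read-modify-write and a lost race simply retries against the freshly observed state. The following standalone Rust sketch illustrates that loop outside the patch; the Cell type and the WOKEN/REF_ONE constants here are simplified stand-ins for StateCell and the generated State packing specs, not code from this diff:

    use std::sync::atomic::{
        AtomicUsize,
        Ordering::{AcqRel, Acquire},
    };

    // One flag bit plus a refcount in the remaining bits, loosely mirroring
    // how `State` packs POLLING/WOKEN/COMPLETED alongside REFS.
    const WOKEN: usize = 0b1;
    const REF_ONE: usize = 0b10;

    struct Cell(AtomicUsize);

    impl Cell {
        /// Retry `f` until the compare-exchange succeeds. On each retry, `f`
        /// sees the freshly loaded value, so the published state is always
        /// derived from the latest snapshot.
        fn transition<T>(&self, mut f: impl FnMut(&mut usize) -> T) -> T {
            let mut current = self.0.load(Acquire);
            loop {
                let mut next = current;
                let res = f(&mut next);
                if next == current {
                    // No change requested; nothing to publish.
                    return res;
                }
                match self.0.compare_exchange_weak(current, next, AcqRel, Acquire) {
                    Ok(_) => return res,
                    // Raced with another thread (or failed spuriously): retry.
                    Err(actual) => current = actual,
                }
            }
        }

        /// Set the WOKEN bit and take a reference, roughly as `wake_by_ref`
        /// does; returns `false` if the wakeup coalesced with an earlier one.
        fn wake(&self) -> bool {
            self.transition(|state| {
                if *state & WOKEN != 0 {
                    return false;
                }
                *state = (*state | WOKEN) + REF_ONE;
                true
            })
        }
    }

    fn main() {
        let cell = Cell(AtomicUsize::new(REF_ONE));
        assert!(cell.wake());
        assert!(!cell.wake()); // second wake is coalesced
    }

The design choice this buys the patch: because the closure re-runs on contention, checks like "is the task already POLLING?" and the resulting ref-count adjustment always agree with the value actually stored, which is what the loom tests above exercise.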