diff --git a/Cargo.toml b/Cargo.toml index 1a0e7af..3ef4cb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,7 @@ keywords = ["async", "stream", "timestamp"] [dependencies] futures-core = "0.3" pin-project-lite = "0.2" + +[dev-dependencies] +futures-executor = "0.3.25" +futures-util = "0.3.25" diff --git a/src/adapters.rs b/src/adapters.rs index 56ab453..922fb75 100644 --- a/src/adapters.rs +++ b/src/adapters.rs @@ -1006,3 +1006,9 @@ impl OrderedStream for Peekable { } } } + +impl FusedOrderedStream for Peekable { + fn is_terminated(&self) -> bool { + self.stream.is_none() + } +} diff --git a/src/join.rs b/src/join.rs index 63d3823..66fc821 100644 --- a/src/join.rs +++ b/src/join.rs @@ -379,3 +379,39 @@ where matches!(self.state, JoinState::Terminated) } } + +#[cfg(test)] +mod test { + use crate::join; + use crate::FromStream; + use crate::OrderedStreamExt; + + pub struct Message { + serial: u32, + } + + #[test] + fn join_two() { + futures_executor::block_on(async { + let stream1 = futures_util::stream::iter([ + Message { serial: 1 }, + Message { serial: 3 }, + Message { serial: 5 }, + ]); + + let stream2 = futures_util::stream::iter([ + Message { serial: 2 }, + Message { serial: 4 }, + Message { serial: 6 }, + ]); + let mut joined = join( + FromStream::with_ordering(stream1, |m| m.serial), + FromStream::with_ordering(stream2, |m| m.serial), + ); + for i in 0..6 { + let msg = joined.next().await.unwrap(); + assert_eq!(msg.serial, i as u32 + 1); + } + }); + } +} diff --git a/src/multi.rs b/src/multi.rs index 6e5627d..45bf5a7 100644 --- a/src/multi.rs +++ b/src/multi.rs @@ -16,7 +16,7 @@ where // The stream with the earliest item that is actually before the given point let mut best: Option> = None; let mut has_data = false; - let mut has_pending = true; + let mut has_pending = false; for mut stream in streams { let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0)); let before = match (before, best_before) { @@ -32,24 +32,20 @@ where Poll::Ready(PollResult::NoneBefore) => { has_data = true; } - Poll::Ready(PollResult::Item { ordering, .. }) => { - match before { - // skip the compare if it doesn't matter - _ if has_pending => continue, - Some(max) if max < ordering => continue, - _ => { - best = Some(stream); - } + Poll::Ready(PollResult::Item { ordering, .. }) => match before { + Some(max) if max < ordering => continue, + _ => { + best = Some(stream); } - } + }, } } match best { - _ if has_pending => Poll::Pending, - // This is guaranteed to return PollResult::Item - Some(mut stream) => stream.as_mut().poll_next_before(cx, before), None if has_data => Poll::Ready(PollResult::NoneBefore), + None if has_pending => Poll::Pending, None => Poll::Ready(PollResult::Terminated), + // This is guaranteed to return PollResult::Item + Some(mut stream) => stream.as_mut().poll_next_before(cx, before), } } @@ -90,6 +86,17 @@ where } } +impl FusedOrderedStream for JoinMultiple +where + for<'a> &'a mut C: IntoIterator>, + for<'a> &'a C: IntoIterator>, + S: OrderedStream + Unpin, +{ + fn is_terminated(&self) -> bool { + self.0.into_iter().all(|peekable| peekable.is_terminated()) + } +} + pin_project_lite::pin_project! { /// Join a collection of pinned [`OrderedStream`]s. /// @@ -127,3 +134,55 @@ where poll_multiple(self.as_pin_mut(), cx, before) } } + +#[cfg(test)] +mod test { + extern crate alloc; + + use crate::FromStream; + use crate::JoinMultiple; + use crate::OrderedStreamExt; + use alloc::boxed::Box; + use alloc::vec::Vec; + use core::pin::Pin; + use futures_core::Stream; + + #[test] + fn join_mutiple() { + futures_executor::block_on(async { + pub struct Message { + serial: u32, + } + + pub struct RemoteLogSource { + stream: Pin>>, + } + + let mut logs = [ + RemoteLogSource { + stream: Box::pin(futures_util::stream::iter([ + Message { serial: 1 }, + Message { serial: 3 }, + Message { serial: 5 }, + ])), + }, + RemoteLogSource { + stream: Box::pin(futures_util::stream::iter([ + Message { serial: 2 }, + Message { serial: 4 }, + Message { serial: 6 }, + ])), + }, + ]; + let streams: Vec<_> = logs + .iter_mut() + .map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable()) + .collect(); + let mut joined = JoinMultiple(streams); + for i in 0..6 { + let msg = joined.next().await.unwrap(); + assert_eq!(msg.serial, i as u32 + 1); + } + }); + } +}