Skip to content

Commit

Permalink
Merge pull request danieldg#11 from zeenix/fix-join-multiple
Browse files Browse the repository at this point in the history
Fix join multiple
  • Loading branch information
zeenix authored Dec 27, 2022
2 parents 364b5ea + 6729e9a commit 86e70c2
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 13 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ keywords = ["async", "stream", "timestamp"]
[dependencies]
futures-core = "0.3"
pin-project-lite = "0.2"

[dev-dependencies]
futures-executor = "0.3.25"
futures-util = "0.3.25"
6 changes: 6 additions & 0 deletions src/adapters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,3 +1006,9 @@ impl<S: OrderedStream> OrderedStream for Peekable<S> {
}
}
}

impl<S: OrderedStream> FusedOrderedStream for Peekable<S> {
fn is_terminated(&self) -> bool {
self.stream.is_none()
}
}
36 changes: 36 additions & 0 deletions src/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,39 @@ where
matches!(self.state, JoinState::Terminated)
}
}

#[cfg(test)]
mod test {
use crate::join;
use crate::FromStream;
use crate::OrderedStreamExt;

pub struct Message {
serial: u32,
}

#[test]
fn join_two() {
futures_executor::block_on(async {
let stream1 = futures_util::stream::iter([
Message { serial: 1 },
Message { serial: 3 },
Message { serial: 5 },
]);

let stream2 = futures_util::stream::iter([
Message { serial: 2 },
Message { serial: 4 },
Message { serial: 6 },
]);
let mut joined = join(
FromStream::with_ordering(stream1, |m| m.serial),
FromStream::with_ordering(stream2, |m| m.serial),
);
for i in 0..6 {
let msg = joined.next().await.unwrap();
assert_eq!(msg.serial, i as u32 + 1);
}
});
}
}
85 changes: 72 additions & 13 deletions src/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ where
// The stream with the earliest item that is actually before the given point
let mut best: Option<Pin<P>> = None;
let mut has_data = false;
let mut has_pending = true;
let mut has_pending = false;
for mut stream in streams {
let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0));
let before = match (before, best_before) {
Expand All @@ -32,24 +32,20 @@ where
Poll::Ready(PollResult::NoneBefore) => {
has_data = true;
}
Poll::Ready(PollResult::Item { ordering, .. }) => {
match before {
// skip the compare if it doesn't matter
_ if has_pending => continue,
Some(max) if max < ordering => continue,
_ => {
best = Some(stream);
}
Poll::Ready(PollResult::Item { ordering, .. }) => match before {
Some(max) if max < ordering => continue,
_ => {
best = Some(stream);
}
}
},
}
}
match best {
_ if has_pending => Poll::Pending,
// This is guaranteed to return PollResult::Item
Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
None if has_data => Poll::Ready(PollResult::NoneBefore),
None if has_pending => Poll::Pending,
None => Poll::Ready(PollResult::Terminated),
// This is guaranteed to return PollResult::Item
Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
}
}

Expand Down Expand Up @@ -90,6 +86,17 @@ where
}
}

impl<C, S> FusedOrderedStream for JoinMultiple<C>
where
for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>,
S: OrderedStream + Unpin,
{
fn is_terminated(&self) -> bool {
self.0.into_iter().all(|peekable| peekable.is_terminated())
}
}

pin_project_lite::pin_project! {
/// Join a collection of pinned [`OrderedStream`]s.
///
Expand Down Expand Up @@ -127,3 +134,55 @@ where
poll_multiple(self.as_pin_mut(), cx, before)
}
}

#[cfg(test)]
mod test {
extern crate alloc;

use crate::FromStream;
use crate::JoinMultiple;
use crate::OrderedStreamExt;
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::pin::Pin;
use futures_core::Stream;

#[test]
fn join_mutiple() {
futures_executor::block_on(async {
pub struct Message {
serial: u32,
}

pub struct RemoteLogSource {
stream: Pin<Box<dyn Stream<Item = Message>>>,
}

let mut logs = [
RemoteLogSource {
stream: Box::pin(futures_util::stream::iter([
Message { serial: 1 },
Message { serial: 3 },
Message { serial: 5 },
])),
},
RemoteLogSource {
stream: Box::pin(futures_util::stream::iter([
Message { serial: 2 },
Message { serial: 4 },
Message { serial: 6 },
])),
},
];
let streams: Vec<_> = logs
.iter_mut()
.map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable())
.collect();
let mut joined = JoinMultiple(streams);
for i in 0..6 {
let msg = joined.next().await.unwrap();
assert_eq!(msg.serial, i as u32 + 1);
}
});
}
}

0 comments on commit 86e70c2

Please sign in to comment.