Skip to content

Commit

Permalink
Introduce an asynchronous version
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanUkhov committed Dec 25, 2024
1 parent cf5f910 commit 9965827
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 71 deletions.
13 changes: 12 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "loop"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
license = "Apache-2.0/MIT"
authors = ["Ivan Ukhov <ivan.ukhov@gmail.com>"]
Expand All @@ -10,3 +10,14 @@ homepage = "https://github.com/stainless-steel/loop"
repository = "https://github.com/stainless-steel/loop"
categories = ["algorithms"]
keywords = ["parallel"]

[features]
asynchronous = ["futures", "tokio", "tokio-stream"]

[dependencies]
futures = { version = "0.3", default-features = false, optional = true }
tokio = { version = "1", features = ["rt-multi-thread", "sync"], optional = true }
tokio-stream = { version = "0.1", default-features = false, optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["macros"] }
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@ The package allows for processing iterators in parallel.
# Example

```rust
let map = |item: &_, context| std::io::Result::Ok(*item * context);
let (items, results): (Vec<_>, Vec<_>) = r#loop::parallelize(0..10, map, 2, None).unzip();
let map = |item, context| item * context;
let _ = r#loop::parallelize(0..10, map, 2, None).collect::<Vec<_>>();
```

```rust
use futures::stream::StreamExt;

let map = |item, context| async move { item * context };
let _ = r#loop::parallelize(0..10, map, 2, None).collect::<Vec<_>>().await;
```

## Contribution
Expand Down
69 changes: 69 additions & 0 deletions src/asynchronous.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::sync::Arc;

use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;

/// Process an iterator in parallel.
///
/// Items are dispatched to a pool of asynchronous workers, each of which
/// applies `map` and forwards the output. The returned stream yields results
/// in completion order, which need not match the input order. When `workers`
/// is `None`, the pool size defaults to the available parallelism.
pub fn parallelize<Items, Map, Context, Item, Future, Result>(
    items: Items,
    map: Map,
    context: Context,
    workers: Option<usize>,
) -> impl futures::stream::Stream<Item = Result>
where
    Items: std::iter::Iterator<Item = Item> + Send + 'static,
    Map: Fn(Item, Context) -> Future + Copy + Send + 'static,
    Context: Clone + Send + 'static,
    Item: Copy + Send + 'static,
    Future: std::future::Future<Output = Result> + Send,
    Result: Send + 'static,
{
    let workers = crate::support::workers(workers);
    // Items flow to the workers over the forward channel, and results come
    // back over the backward one; both are bounded to limit buffering.
    let (forward_sender, forward_receiver) = mpsc::channel::<Item>(workers);
    let (backward_sender, backward_receiver) = mpsc::channel::<Result>(workers);
    // The single receiver is shared among the workers behind a mutex.
    let forward_receiver = Arc::new(Mutex::new(forward_receiver));
    let mut _handlers = Vec::with_capacity(workers + 1);
    for _ in 0..workers {
        let forward_receiver = forward_receiver.clone();
        let backward_sender = backward_sender.clone();
        let context = context.clone();
        _handlers.push(tokio::task::spawn(async move {
            loop {
                // Confine the mutex guard to this scope. Receiving in a
                // `while let` scrutinee would keep the guard alive for the
                // whole loop body (2021-edition temporary-scope rules),
                // holding the lock across `map(...).await` and thereby
                // serializing all workers.
                let item = {
                    let mut receiver = forward_receiver.lock().await;
                    receiver.recv().await
                };
                let item = match item {
                    Some(item) => item,
                    // The feeder is done and the channel has been drained.
                    None => break,
                };
                if backward_sender
                    .send(map(item, context.clone()).await)
                    .await
                    .is_err()
                {
                    // The consumer dropped the stream; stop working.
                    break;
                }
            }
        }));
    }
    // Feed the items from a dedicated task so that this function can return
    // immediately; the feeder stops once all workers are gone.
    _handlers.push(tokio::task::spawn(async move {
        for item in items {
            if forward_sender.send(item).await.is_err() {
                break;
            }
        }
    }));
    ReceiverStream::new(backward_receiver)
}

#[cfg(test)]
mod tests {
    use futures::stream::StreamExt;

    #[test]
    fn parallelize() {
        tokio::runtime::Runtime::new().unwrap().block_on(async {
            let stream = super::parallelize(0..10, scale, 2, None);
            let mut values = stream.collect::<Vec<_>>().await;
            values.sort();
            let expected = (0..20).step_by(2).collect::<Vec<_>>();
            assert_eq!(values, expected);
        });
    }

    /// Multiply the item by the context asynchronously.
    async fn scale(item: i32, context: i64) -> usize {
        item as usize * context as usize
    }
}
96 changes: 28 additions & 68 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,36 @@
//! # Example
//!
//! ```
//! let map = |item: &_, context| std::io::Result::Ok(*item * context);
//! let (items, results): (Vec<_>, Vec<_>) = r#loop::parallelize(0..10, map, 2, None).unzip();
//! # #[cfg(not(feature = "asynchronous"))]
//! fn main() {
//! let map = |item, context| item * context;
//! let _ = r#loop::parallelize(0..10, map, 2, None).collect::<Vec<_>>();
//! }
//! # #[cfg(feature = "asynchronous")]
//! # fn main() {}
//!```
//!
//!```
//! # #[cfg(feature = "asynchronous")]
//! #[tokio::main]
//! async fn main() {
//! use futures::stream::StreamExt;
//!
//! let map = |item, context| async move { item * context };
//! let _ = r#loop::parallelize(0..10, map, 2, None).collect::<Vec<_>>().await;
//! }
//! # #[cfg(not(feature = "asynchronous"))]
//! # fn main() {}
//! ```
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// Process an iterator in parallel.
///
/// Each item is handed to a pool of worker threads, which apply `map` and
/// send back the item paired with its result. The returned iterator yields
/// pairs in completion order, which need not match the input order. When
/// `workers` is `None`, the pool size defaults to the available parallelism.
///
/// NOTE(review): every item is pushed into an unbounded channel (and counted)
/// before the result iterator is returned, so the whole input is buffered up
/// front — confirm this is acceptable for large iterators.
pub fn parallelize<Iterator, Item, Map, Context, Value, Error>(
    iterator: Iterator,
    map: Map,
    context: Context,
    workers: Option<usize>,
) -> impl DoubleEndedIterator<Item = (Item, Result<Value, Error>)>
where
    Iterator: std::iter::Iterator<Item = Item>,
    Map: Fn(&Item, Context) -> Result<Value, Error> + Copy + Send + 'static,
    Item: Send + 'static,
    Context: Clone + Send + 'static,
    Value: Send + 'static,
    Error: Send + 'static,
{
    // Items flow to the workers over the forward channel; (item, result)
    // pairs flow back over the backward one. Both are unbounded.
    let (forward_sender, forward_receiver) = mpsc::channel::<Item>();
    let (backward_sender, backward_receiver) = mpsc::channel::<(Item, Result<Value, Error>)>();
    // The single receiver is shared among the workers behind a mutex.
    let forward_receiver = Arc::new(Mutex::new(forward_receiver));

    // Default to the number of available CPUs, falling back to one worker.
    let workers = workers.unwrap_or_else(|| {
        std::thread::available_parallelism()
            .map(|value| value.get())
            .unwrap_or(1)
    });
    // The join handles are deliberately discarded: workers terminate on their
    // own once `forward_sender` is dropped at the end of this function and
    // `recv` starts failing.
    let _ = (0..workers)
        .map(|_| {
            let forward_receiver = forward_receiver.clone();
            let backward_sender = backward_sender.clone();
            let context = context.clone();
            thread::spawn(move || loop {
                // The guard created in the scrutinee lives only until the end
                // of this `match`, so the lock is released before `map` runs.
                let entry = match forward_receiver.lock().unwrap().recv() {
                    Ok(entry) => entry,
                    Err(_) => break,
                };
                let result = map(&entry, context.clone());
                backward_sender.send((entry, result)).unwrap();
            })
        })
        .collect::<Vec<_>>();
    // Eagerly feed all items and count them; the count bounds the returned
    // iterator, which then drains exactly that many results.
    let mut count = 0;
    for entry in iterator {
        forward_sender.send(entry).unwrap();
        count += 1;
    }
    (0..count).map(move |_| backward_receiver.recv().unwrap())
}
#[cfg(feature = "asynchronous")]
#[path = "asynchronous.rs"]
mod implementation;

#[cfg(test)]
mod tests {
macro_rules! ok(($result:expr) => ($result.unwrap()));
#[cfg(not(feature = "asynchronous"))]
#[path = "synchronous.rs"]
mod implementation;

#[test]
fn parallelize() {
let values = super::parallelize(0..10, map, 2, None)
.map(|(_, result)| ok!(result))
.collect::<std::collections::BTreeSet<_>>()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(values, &[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
}
mod support;

fn map(item: &i32, context: i64) -> std::io::Result<usize> {
Ok(*item as usize * context as usize)
}
}
pub use implementation::parallelize;
7 changes: 7 additions & 0 deletions src/support.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/// Determine the number of workers.
///
/// `value` takes precedence when given; otherwise, the hardware's available
/// parallelism is used, with a single worker as the fallback when it cannot
/// be queried.
pub fn workers(value: Option<usize>) -> usize {
    match value {
        Some(count) => count,
        None => std::thread::available_parallelism().map_or(1, |count| count.get()),
    }
}
57 changes: 57 additions & 0 deletions src/synchronous.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::sync::mpsc;
use std::sync::{Arc, Mutex};

/// Process an iterator in parallel.
///
/// Items are distributed over a pool of worker threads, each of which applies
/// `map` and forwards the outcome. The returned iterator yields results in
/// completion order, which need not match the input order. When `workers` is
/// `None`, the pool size defaults to the available parallelism.
pub fn parallelize<Items, Map, Context, Item, Result>(
    items: Items,
    map: Map,
    context: Context,
    workers: Option<usize>,
) -> impl Iterator<Item = Result>
where
    Items: std::iter::Iterator<Item = Item> + Send + 'static,
    Map: Fn(Item, Context) -> Result + Copy + Send + 'static,
    Context: Clone + Send + 'static,
    Item: Send + 'static,
    Result: Send + 'static,
{
    let count = crate::support::workers(workers);
    // Items travel to the workers over one channel and results come back over
    // the other; both are bounded to limit buffering.
    let (item_sender, item_receiver) = mpsc::sync_channel::<Item>(count);
    let (result_sender, result_receiver) = mpsc::sync_channel::<Result>(count);
    // The single item receiver is shared among the workers behind a mutex.
    let item_receiver = Arc::new(Mutex::new(item_receiver));
    let mut _handlers = Vec::with_capacity(count + 1);
    for _ in 0..count {
        let item_receiver = item_receiver.clone();
        let result_sender = result_sender.clone();
        let context = context.clone();
        _handlers.push(std::thread::spawn(move || loop {
            // The guard is consumed inside the closure, so the lock is
            // released before `map` runs; a poisoned lock or a disconnected
            // channel stops the worker.
            let item = match item_receiver.lock().map(|receiver| receiver.recv()) {
                Ok(Ok(item)) => item,
                _ => break,
            };
            // A failed send means the consumer dropped the iterator.
            if result_sender.send(map(item, context.clone())).is_err() {
                break;
            }
        }));
    }
    // Feed the items from a dedicated thread so that this function can return
    // immediately; the feeder stops once all workers are gone.
    _handlers.push(std::thread::spawn(move || {
        for item in items {
            if item_sender.send(item).is_err() {
                break;
            }
        }
    }));
    result_receiver.into_iter()
}

#[cfg(test)]
mod tests {
    #[test]
    fn parallelize() {
        let mut values: Vec<_> = super::parallelize(0..10, scale, 2, None).collect();
        values.sort();
        let expected = (0..20).step_by(2).collect::<Vec<_>>();
        assert_eq!(values, expected);
    }

    /// Multiply the item by the context.
    fn scale(item: i32, context: i64) -> usize {
        item as usize * context as usize
    }
}

0 comments on commit 9965827

Please sign in to comment.