From 890977133a4844f347d89a1978927bea6196db2b Mon Sep 17 00:00:00 2001 From: fys Date: Tue, 14 Jan 2025 22:23:17 +0800 Subject: [PATCH] feat: expose result_sender to DataProcess trait --- examples/consumer.rs | 16 +++++++++++----- src/consumer.rs | 33 +++++++++++++++------------------ src/consumer/process.rs | 26 ++++++++++++++++++++++---- tests/common.rs | 14 +++++++++++--- 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/examples/consumer.rs b/examples/consumer.rs index e49399d..2563cbb 100644 --- a/examples/consumer.rs +++ b/examples/consumer.rs @@ -1,7 +1,7 @@ use std::str::from_utf8; use std::time::Duration; -use shm_ringbuf::consumer::process::DataProcess; +use shm_ringbuf::consumer::process::{DataProcess, ResultSender}; use shm_ringbuf::consumer::settings::ConsumerSettingsBuilder; use shm_ringbuf::consumer::RingbufConsumer; use shm_ringbuf::error::DataProcessResult; @@ -23,13 +23,19 @@ async fn main() { pub struct StringPrint; impl DataProcess for StringPrint { - type Error = Error; + async fn process(&self, data: &[u8], result_sender: ResultSender) { + if let Err(e) = self.do_process(data).await { + result_sender.push_result(e).await; + } else { + result_sender.push_ok().await; + } + } +} - async fn process(&self, data: &[u8]) -> Result<(), Self::Error> { +impl StringPrint { + async fn do_process(&self, data: &[u8]) -> Result<(), Error> { let msg = from_utf8(data).map_err(|_| Error::DecodeError)?; - info!("receive: {}", msg); - Ok(()) } } diff --git a/src/consumer.rs b/src/consumer.rs index 7e57002..54a40d0 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -3,13 +3,13 @@ pub mod settings; pub(crate) mod session_manager; -use std::fmt::Debug; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; use process::DataProcess; +use process::ResultSender; use session_manager::SessionManager; use session_manager::SessionManagerRef; use session_manager::SessionRef; @@ -86,10 +86,9 @@ impl RingbufConsumer { } /// Run the consumer, which will block the current thread. - pub async fn run(&self, processor: P) + pub async fn run

(&self, processor: P) where - P: DataProcess, - E: Into + Debug + Send, + P: DataProcess, { if self .started @@ -153,14 +152,13 @@ impl RingbufConsumer { } /// The main loop to process the ringbufs. - async fn process_loop( + async fn process_loop

( &self, processor: &P, interval: Duration, cancel: Option, ) where - P: DataProcess, - E: Into + Debug + Send, + P: DataProcess, { loop { process_all_sessions(&self.session_manager, processor).await; @@ -183,22 +181,20 @@ impl RingbufConsumer { } } -async fn process_all_sessions( +async fn process_all_sessions

( session_manager: &SessionManagerRef, processor: &P, ) where - P: DataProcess, - E: Into, + P: DataProcess, { for (_, session) in session_manager.iter() { process_session(&session, processor).await; } } -async fn process_session(session: &SessionRef, processor: &P) +async fn process_session

(session: &SessionRef, processor: &P) where - P: DataProcess, - E: Into, + P: DataProcess, { let ringbuf = session.ringbuf(); let enable_checksum = session.enable_checksum(); @@ -227,11 +223,12 @@ where continue; } - if let Err(e) = processor.process(data_slice).await { - session.push_result(req_id, e).await; - } else { - session.push_ok(req_id).await; - } + let result_sender = ResultSender { + request_id: req_id, + session: session.clone(), + }; + + processor.process(data_slice, result_sender).await; unsafe { ringbuf.advance_consume_offset(data_block.total_len()) } } diff --git a/src/consumer/process.rs b/src/consumer/process.rs index a42c9da..378c38b 100644 --- a/src/consumer/process.rs +++ b/src/consumer/process.rs @@ -1,14 +1,32 @@ use std::fmt::Debug; use std::future::Future; -use std::result::Result as StdResult; use crate::error::DataProcessResult; -pub trait DataProcess: Send + Sync { - type Error: Into + Debug + Send + 'static; +use super::session_manager::SessionRef; +pub trait DataProcess: Send + Sync { fn process( &self, data: &[u8], - ) -> impl Future>; + result_sender: ResultSender, + ) -> impl Future; +} + +pub struct ResultSender { + pub(crate) request_id: u32, + pub(crate) session: SessionRef, +} + +impl ResultSender { + pub async fn push_ok(&self) { + self.session.push_ok(self.request_id).await + } + + pub async fn push_result( + &self, + result: impl Into + Debug + Send + 'static, + ) { + self.session.push_result(self.request_id, result).await + } } diff --git a/tests/common.rs b/tests/common.rs index 2b2001f..c210468 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -1,7 +1,7 @@ use std::{str::from_utf8, sync::Arc, time::Duration}; use shm_ringbuf::{ - consumer::process::DataProcess, + consumer::process::{DataProcess, ResultSender}, error::DataProcessResult, producer::{prealloc::PreAlloc, RingbufProducer}, }; @@ -13,9 +13,17 @@ pub struct MsgForward { } impl DataProcess for MsgForward { - type Error = Error; + async fn process(&self, data: &[u8], result_sender: ResultSender) { + if let Err(e) = self.do_process(data).await { + result_sender.push_result(e).await; + } else { + result_sender.push_ok().await; + } + } +} - async fn process(&self, data: &[u8]) -> Result<(), Self::Error> { +impl MsgForward { + async fn do_process(&self, data: &[u8]) -> Result<(), Error> { let msg = from_utf8(data).map_err(|_| Error::DecodeError)?; let _ = self.sender.send(msg.to_string()).await;