Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add test_ringbuf_spsc_with_notify and wait_result tests #17

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::{str::from_utf8, time::Duration};

use shm_ringbuf::{
consumer::process::DataProcess,
error::DataProcessResult,
producer::{prealloc::PreAlloc, RingbufProducer},
};
use tokio::{sync::mpsc::Sender, time::sleep};

pub struct MsgForward {
pub sender: Sender<String>,
}

impl DataProcess for MsgForward {
type Error = Error;

async fn process(&self, data: &[u8]) -> Result<(), Self::Error> {
let msg = from_utf8(data).map_err(|_| Error::DecodeError)?;

let _ = self.sender.send(msg.to_string()).await;

Ok(())
}
}

#[derive(Debug)]
pub enum Error {
DecodeError,
ProcessError,
}

impl Error {
pub fn status_code(&self) -> u32 {
match self {
Error::DecodeError => 1001,
Error::ProcessError => 1002,
}
}

pub fn message(&self) -> String {
match self {
Error::DecodeError => "decode error".to_string(),
Error::ProcessError => "process error".to_string(),
}
}
}

impl From<Error> for DataProcessResult {
fn from(err: Error) -> DataProcessResult {
DataProcessResult {
status_code: err.status_code(),
message: err.message(),
}
}
}

pub fn msg_num() -> usize {
std::env::var("MSG_NUM")
.unwrap_or_else(|_| "100000".to_string())
.parse()
.unwrap()
}

pub async fn reserve_with_retry(
producer: &RingbufProducer,
size: usize,
retry_num: usize,
retry_interval: Duration,
) -> Result<PreAlloc, String> {
for _ in 0..retry_num {
let err = match producer.reserve(size) {
Ok(pre) => return Ok(pre),
Err(e) => e,
};

if !matches!(err, shm_ringbuf::error::Error::NotEnoughSpace { .. }) {
break;
}
sleep(retry_interval).await;
}

Err("reserve failed".to_string())
}

pub async fn wait_consumer_online(
p: &RingbufProducer,
retry_num: usize,
retry_interval: Duration,
) -> Result<(), String> {
for _ in 0..retry_num {
if p.server_online() && p.result_fetch_normal() {
return Ok(());
}

sleep(retry_interval).await;
}

Err("wait consumer online timeout".to_string())
}
233 changes: 141 additions & 92 deletions tests/ringbuf_spsc.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use std::str::from_utf8;
mod common;

use std::sync::Arc;
use std::time::Duration;

use shm_ringbuf::consumer::process::DataProcess;
use common::{msg_num, reserve_with_retry, wait_consumer_online, MsgForward};
use shm_ringbuf::consumer::settings::ConsumerSettingsBuilder;
use shm_ringbuf::consumer::RingbufConsumer;
use shm_ringbuf::error::DataProcessResult;
use shm_ringbuf::error::{self};
use shm_ringbuf::producer::prealloc::PreAlloc;
use shm_ringbuf::producer::settings::ProducerSettingsBuilder;
use shm_ringbuf::producer::RingbufProducer;
use tokio::sync::mpsc::Sender;
use tokio::time::sleep;

#[tokio::test]
async fn test_ringbuf_spsc() {
async fn test_ringbuf_spsc_base() {
tracing_subscriber::fmt::init();

let (send_msgs, mut recv_msgs) = tokio::sync::mpsc::channel(100);
Expand All @@ -30,7 +26,7 @@ async fn test_ringbuf_spsc() {
.build();

tokio::spawn(async move {
let string_print = StringPrint { sender: send_msgs };
let string_print = MsgForward { sender: send_msgs };
RingbufConsumer::new(settings).run(string_print).await;
});

Expand All @@ -45,117 +41,170 @@ async fn test_ringbuf_spsc() {
let producer =
Arc::new(RingbufProducer::connect_lazy(settings).await.unwrap());

let msg_num = 100;
let mut joins = Vec::with_capacity(msg_num);
let msg_num = msg_num();

for i in 0..msg_num {
let mut pre_alloc =
reserve_with_retry(&producer, 20, 3, Duration::from_secs(1))
.await
.unwrap();

let write_str = format!("hello, {}", i);
tokio::spawn(async move {
for i in 0..msg_num {
let mut pre_alloc =
reserve_with_retry(&producer, 20, 3, Duration::from_secs(1))
.await
.unwrap();

wait_consumer_online(&producer, 5, Duration::from_secs(3))
.await
.unwrap();
let write_str = format!("hello, {}", i);

pre_alloc.write(write_str.as_bytes()).unwrap();
wait_consumer_online(&producer, 5, Duration::from_secs(3))
.await
.unwrap();

pre_alloc.commit();
pre_alloc.write(write_str.as_bytes()).unwrap();

joins.push(pre_alloc.wait_result());
}
for j in joins {
let _ = j.await;
}
pre_alloc.commit();
}
});

for i in 0..msg_num {
let msg = format!("hello, {}", i);
assert_eq!(recv_msgs.recv().await.unwrap(), msg);
}
}

async fn reserve_with_retry(
producer: &RingbufProducer,
size: usize,
retry_num: usize,
retry_interval: Duration,
) -> Result<PreAlloc, String> {
for _ in 0..retry_num {
let err = match producer.reserve(size) {
Ok(pre) => return Ok(pre),
Err(e) => e,
};

if !matches!(err, error::Error::NotEnoughSpace { .. }) {
break;
}
sleep(retry_interval).await;
}
#[tokio::test]
async fn test_ringbuf_spsc_with_notify() {
tracing_subscriber::fmt::init();

Err("reserve failed".to_string())
}
let (send_msgs, mut recv_msgs) = tokio::sync::mpsc::channel(100);

async fn wait_consumer_online(
p: &RingbufProducer,
retry_num: usize,
retry_interval: Duration,
) -> Result<(), String> {
for _ in 0..retry_num {
if p.server_online() && p.result_fetch_normal() {
return Ok(());
}
let dir = tempfile::tempdir().unwrap();
let grpc_sock_path = dir.path().join("control.sock");
let fdpass_sock_path = dir.path().join("sendfd.sock");

sleep(retry_interval).await;
}
let settings = ConsumerSettingsBuilder::new()
.grpc_sock_path(grpc_sock_path.clone())
.fdpass_sock_path(fdpass_sock_path.clone())
.process_interval(Duration::from_millis(10))
// Set too long interval for testing notify.
.process_interval(Duration::from_millis(1000))
.build();

Err("wait consumer online timeout".to_string())
}
tokio::spawn(async move {
let string_print = MsgForward { sender: send_msgs };
RingbufConsumer::new(settings).run(string_print).await;
});

pub struct StringPrint {
sender: Sender<String>,
}
// Wait for the consumer to start.
tokio::time::sleep(Duration::from_millis(10)).await;

impl DataProcess for StringPrint {
type Error = Error;
let settings = ProducerSettingsBuilder::new()
.grpc_sock_path(grpc_sock_path.clone())
.fdpass_sock_path(fdpass_sock_path.clone())
.build();

async fn process(&self, data: &[u8]) -> Result<(), Self::Error> {
let msg = from_utf8(data).map_err(|_| Error::DecodeError)?;
let producer =
Arc::new(RingbufProducer::connect_lazy(settings).await.unwrap());

let _ = self.sender.send(msg.to_string()).await;
let msg_num = msg_num();

Ok(())
}
}
tokio::spawn(async move {
for i in 0..msg_num {
let mut pre_alloc =
reserve_with_retry(&producer, 20, 3, Duration::from_secs(1))
.await
.unwrap();

#[derive(Debug)]
pub enum Error {
DecodeError,
ProcessError,
}
let write_str = format!("hello, {}", i);

impl Error {
pub fn status_code(&self) -> u32 {
match self {
Error::DecodeError => 1001,
Error::ProcessError => 1002,
}
}
wait_consumer_online(&producer, 5, Duration::from_secs(3))
.await
.unwrap();

pre_alloc.write(write_str.as_bytes()).unwrap();

pre_alloc.commit();

pub fn message(&self) -> String {
match self {
Error::DecodeError => "decode error".to_string(),
Error::ProcessError => "process error".to_string(),
producer.notify_consumer(Some(1000)).await;
}
});

for i in 0..msg_num {
let msg = format!("hello, {}", i);
assert_eq!(recv_msgs.recv().await.unwrap(), msg);
}
}

impl From<Error> for DataProcessResult {
fn from(err: Error) -> DataProcessResult {
DataProcessResult {
status_code: err.status_code(),
message: err.message(),
#[tokio::test]
async fn test_ringbuf_spsc_with_wait_result() {
tracing_subscriber::fmt::init();

let (send_msgs, mut recv_msgs) = tokio::sync::mpsc::channel(100);

let dir = tempfile::tempdir().unwrap();
let grpc_sock_path = dir.path().join("control.sock");
let fdpass_sock_path = dir.path().join("sendfd.sock");

let settings = ConsumerSettingsBuilder::new()
.grpc_sock_path(grpc_sock_path.clone())
.fdpass_sock_path(fdpass_sock_path.clone())
.process_interval(Duration::from_millis(10))
.build();

tokio::spawn(async move {
let string_print = MsgForward { sender: send_msgs };
RingbufConsumer::new(settings).run(string_print).await;
});

// Wait for the consumer to start.
tokio::time::sleep(Duration::from_millis(10)).await;

let settings = ProducerSettingsBuilder::new()
.grpc_sock_path(grpc_sock_path.clone())
.fdpass_sock_path(fdpass_sock_path.clone())
.build();

let producer =
Arc::new(RingbufProducer::connect_lazy(settings).await.unwrap());

let msg_num = msg_num();

tokio::spawn(async move {
let mut joins = Vec::with_capacity(100);
for i in 0..msg_num {
let mut pre_alloc =
reserve_with_retry(&producer, 20, 3, Duration::from_secs(1))
.await
.unwrap();

let write_str = format!("hello, {}", i);

wait_consumer_online(&producer, 5, Duration::from_secs(3))
.await
.unwrap();

pre_alloc.write(write_str.as_bytes()).unwrap();

pre_alloc.commit();

let join = pre_alloc.wait_result();

joins.push(join);

// Wait the result every 1000 messages.
if i % 1000 == 0 {
for join in joins.drain(..) {
let result = join.await.unwrap();
assert_eq!(result.status_code, 0);
}
}
if i == msg_num - 1 {
for join in joins.drain(..) {
let result = join.await.unwrap();
assert_eq!(result.status_code, 0);
}
}
}
});

for i in 0..msg_num {
let msg = format!("hello, {}", i);
assert_eq!(recv_msgs.recv().await.unwrap(), msg);
}
}
Loading