diff --git a/src/network/network_handler.rs b/src/network/network_handler.rs index bc99490..9369746 100644 --- a/src/network/network_handler.rs +++ b/src/network/network_handler.rs @@ -7,7 +7,7 @@ use crate::{ }; use futures_channel::{mpsc, oneshot}; use futures_util::{select, FutureExt, SinkExt, StreamExt}; -use log::{trace, debug, error, info, log_enabled, warn, Level}; +use log::{debug, error, info, log_enabled, trace, warn, Level}; use smallvec::SmallVec; use std::collections::{HashMap, VecDeque}; use tokio::sync::broadcast; @@ -73,6 +73,14 @@ impl MessageToReceive { } } +struct PendingSubscription { + pub channel_or_pattern: Vec, + pub subscription_type: SubscriptionType, + pub sender: PubSubSender, + /// indicates if more subscriptions will arrive in the same batch + pub more_to_come: bool, +} + pub(crate) struct NetworkHandler { status: Status, connection: Connection, @@ -81,7 +89,7 @@ pub(crate) struct NetworkHandler { msg_receiver: MsgReceiver, messages_to_send: VecDeque, messages_to_receive: VecDeque, - pending_subscriptions: HashMap, (SubscriptionType, PubSubSender)>, + pending_subscriptions: VecDeque, pending_unsubscriptions: VecDeque, SubscriptionType>>, subscriptions: HashMap, (SubscriptionType, PubSubSender)>, is_reply_on: bool, @@ -113,7 +121,7 @@ impl NetworkHandler { msg_receiver, messages_to_send: VecDeque::new(), messages_to_receive: VecDeque::new(), - pending_subscriptions: HashMap::new(), + pending_subscriptions: VecDeque::new(), pending_unsubscriptions: VecDeque::new(), subscriptions: HashMap::new(), is_reply_on: true, @@ -128,10 +136,7 @@ impl NetworkHandler { let join_handle = spawn(async move { if let Err(e) = network_handler.network_loop().await { - error!( - "[{}] network loop ended in error: {e}", - network_handler.tag - ); + error!("[{}] network loop ended in error: {e}", network_handler.tag); } }); @@ -172,9 +177,15 @@ impl NetworkHandler { _ => unreachable!(), }; - let pending_subscriptions = pub_sub_senders - .into_iter() - .map(|(channel, sender)| (channel, (subscription_type, sender))); + let num_pending_subscriptions = pub_sub_senders.len(); + let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map( + |(index, (channel_or_pattern, sender))| PendingSubscription { + channel_or_pattern, + subscription_type, + sender, + more_to_come: index < num_pending_subscriptions - 1, + }, + ); self.pending_subscriptions.extend(pending_subscriptions); } @@ -226,8 +237,7 @@ impl NetworkHandler { Status::Disconnected => { debug!( "[{}] network disconnected, queuing command: {:?}", - self.tag, - msg.commands + self.tag, msg.commands ); self.messages_to_send.push_back(MessageToSend::new(msg)); } @@ -276,11 +286,7 @@ impl NetworkHandler { .iter() .fold(0, |sum, msg| sum + msg.message.commands.len()); if num_commands > 1 { - debug!( - "[{}] sending batch of {} commands", - self.tag, - num_commands - ); + debug!("[{}] sending batch of {} commands", self.tag, num_commands); } } @@ -377,10 +383,7 @@ impl NetworkHandler { Ok(resp_buf) if resp_buf.is_push_message() => match &mut self.push_sender { Some(push_sender) => { if let Err(e) = push_sender.send(result).await { - warn!( - "[{}] Cannot send monitor result to caller: {e}", - self.tag - ); + warn!("[{}] Cannot send monitor result to caller: {e}", self.tag); } } None => { @@ -415,10 +418,7 @@ impl NetworkHandler { Ok(resp_buf) if resp_buf.is_monitor_message() => { if let Some(push_sender) = &mut self.push_sender { if let Err(e) = push_sender.send(result).await { - warn!( - "[{}] Cannot send monitor result to caller: {e}", - self.tag - ); + warn!("[{}] Cannot send monitor result to caller: {e}", self.tag); } } } @@ -428,10 +428,7 @@ impl NetworkHandler { Ok(resp_buf) if resp_buf.is_monitor_message() => { if let Some(push_sender) = &mut self.push_sender { if let Err(e) = push_sender.send(result).await { - warn!( - "[{}] Cannot send monitor result to caller: {e}", - self.tag - ); + warn!("[{}] Cannot send monitor result to caller: {e}", self.tag); } } } @@ -477,7 +474,11 @@ impl NetworkHandler { error!("[{}] Cannot retry message: {e}", self.tag); } } else { - trace!("[{}] Will respond to: {:?}", self.tag, message_to_receive.message); + trace!( + "[{}] Will respond to: {:?}", + self.tag, + message_to_receive.message + ); match message_to_receive.message.commands { Commands::Single(_, Some(result_sender)) => { if let Err(e) = result_sender.send(result) { @@ -507,7 +508,8 @@ impl NetworkHandler { } }, Commands::None | Commands::Single(_, None) => { - debug!("[{}] forget value {result:?}", self.tag) // fire & forget + debug!("[{}] forget value {result:?}", self.tag) + // fire & forget } } } @@ -580,14 +582,27 @@ impl NetworkHandler { RefPubSubMessage::Subscribe(channel_or_pattern) | RefPubSubMessage::PSubscribe(channel_or_pattern) | RefPubSubMessage::SSubscribe(channel_or_pattern) => { - if let Some(pub_sub_sender) = - self.pending_subscriptions.remove(channel_or_pattern) - { - self.subscriptions - .insert(channel_or_pattern.to_vec(), pub_sub_sender); - } - if !self.pending_subscriptions.is_empty() { - return None; + if let Some(pending_sub) = self.pending_subscriptions.pop_front() { + if pending_sub.channel_or_pattern == channel_or_pattern { + self.subscriptions + .insert(channel_or_pattern.to_vec(), (pending_sub.subscription_type, pending_sub.sender)); + + if pending_sub.more_to_come { + return None; + } + } else { + error!( + "[{}] Unexpected subscription confirmation on channel '{:?}'", + self.tag, + String::from_utf8_lossy(channel_or_pattern) + ); + } + } else { + error!( + "[{}] Cannot find pending subscription for channel '{:?}'", + self.tag, + String::from_utf8_lossy(channel_or_pattern) + ); } Some(Ok(RespBuf::ok())) } @@ -669,9 +684,7 @@ impl NetworkHandler { message_to_receive.attempts += 1; debug!( "[{}]: {:?}: attempt {}", - self.tag, - message_to_receive.message.commands, - message_to_receive.attempts + self.tag, message_to_receive.message.commands, message_to_receive.attempts ); } } @@ -682,8 +695,7 @@ impl NetworkHandler { { debug!( "[{}] {:?}, max attempts reached", - self.tag, - message_to_receive.message.commands + self.tag, message_to_receive.message.commands ); if let Some(message_to_receive) = self.messages_to_receive.pop_front() { match message_to_receive.message.commands { @@ -720,9 +732,7 @@ impl NetworkHandler { message_to_send.attempts += 1; debug!( "[{}] {:?}: attempt {}", - self.tag, - message_to_send.message.commands, - message_to_send.attempts + self.tag, message_to_send.message.commands, message_to_send.attempts ); } } @@ -733,8 +743,7 @@ impl NetworkHandler { { debug!( "[{}] {:?}, max attempts reached", - self.tag, - message_to_send.message.commands + self.tag, message_to_send.message.commands ); if let Some(message_to_send) = self.messages_to_send.pop_front() { match message_to_send.message.commands { @@ -838,29 +847,28 @@ impl NetworkHandler { } if !self.pending_subscriptions.is_empty() { - for (channel_or_pattern, (subscription_type, sender)) in - self.pending_subscriptions.drain() + for pending_sub in self.pending_subscriptions.drain(..) { - match subscription_type { + match pending_sub.subscription_type { SubscriptionType::Channel => { self.connection - .subscribe(channel_or_pattern.clone()) + .subscribe(pending_sub.channel_or_pattern.clone()) .await?; } SubscriptionType::Pattern => { self.connection - .psubscribe(channel_or_pattern.clone()) + .psubscribe(pending_sub.channel_or_pattern.clone()) .await?; } SubscriptionType::ShardChannel => { self.connection - .ssubscribe(channel_or_pattern.clone()) + .ssubscribe(pending_sub.channel_or_pattern.clone()) .await?; } } self.subscriptions - .insert(channel_or_pattern, (subscription_type, sender)); + .insert(pending_sub.channel_or_pattern, (pending_sub.subscription_type, pending_sub.sender)); } } diff --git a/src/tests/pub_sub_commands.rs b/src/tests/pub_sub_commands.rs index 4b16e5d..895ef0e 100644 --- a/src/tests/pub_sub_commands.rs +++ b/src/tests/pub_sub_commands.rs @@ -1,16 +1,18 @@ -use std::collections::{HashMap, HashSet}; - use crate::{ client::{Client, IntoConfig}, commands::{ ClientKillOptions, ClusterCommands, ClusterShardResult, ConnectionCommands, FlushingMode, - PubSubChannelsOptions, PubSubCommands, ServerCommands, StringCommands, + ListCommands, PubSubChannelsOptions, PubSubCommands, ServerCommands, StringCommands, }, tests::{get_cluster_test_client, get_default_addr, get_test_client, log_try_init}, Result, }; use futures_util::{FutureExt, StreamExt, TryStreamExt}; use serial_test::serial; +use std::{ + collections::{HashMap, HashSet}, + future::IntoFuture, +}; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] @@ -610,3 +612,38 @@ async fn no_auto_resubscribe() -> Result<()> { Ok(()) } + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[serial] +async fn concurrent_subscribe() -> Result<()> { + let pub_sub_client1 = get_test_client().await?; + let pub_sub_client2 = pub_sub_client1.clone(); + let regular_client = get_test_client().await?; + + // cleanup + regular_client.flushdb(FlushingMode::Sync).await?; + + regular_client.lpush("key", ["value1", "value2"]).await?; + + let results = tokio::join!( + pub_sub_client1.subscribe("mychannel1"), + pub_sub_client2.subscribe("mychannel2"), + regular_client.lpop("key", 2).into_future(), + regular_client.lpop("key", 2).into_future(), + regular_client + .publish("mychannel1", "new") + .into_future() + ); + + let mut pub_sub_stream1 = results.0?; + let _pub_sub_stream2 = results.1?; + let values1: Vec = results.2?; + let values2: Vec = results.3?; + let message1 = pub_sub_stream1.next().await.unwrap()?; + + assert_eq!(vec!["value2".to_owned(), "value1".to_owned()], values1); + assert_eq!(Vec::::new(), values2); + assert_eq!(b"new".to_vec(), message1.payload); + + Ok(()) +}