Skip to content

Commit

Permalink
fix concurrency problem in pub sub
Browse files Browse the repository at this point in the history
  • Loading branch information
mcatanzariti committed Nov 9, 2023
1 parent 1de0113 commit a93b687
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 59 deletions.
120 changes: 64 additions & 56 deletions src/network/network_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};
use futures_channel::{mpsc, oneshot};
use futures_util::{select, FutureExt, SinkExt, StreamExt};
use log::{trace, debug, error, info, log_enabled, warn, Level};
use log::{debug, error, info, log_enabled, trace, warn, Level};
use smallvec::SmallVec;
use std::collections::{HashMap, VecDeque};
use tokio::sync::broadcast;
Expand Down Expand Up @@ -73,6 +73,14 @@ impl MessageToReceive {
}
}

struct PendingSubscription {
pub channel_or_pattern: Vec<u8>,
pub subscription_type: SubscriptionType,
pub sender: PubSubSender,
/// indicates if more subscriptions will arrive in the same batch
pub more_to_come: bool,
}

pub(crate) struct NetworkHandler {
status: Status,
connection: Connection,
Expand All @@ -81,7 +89,7 @@ pub(crate) struct NetworkHandler {
msg_receiver: MsgReceiver,
messages_to_send: VecDeque<MessageToSend>,
messages_to_receive: VecDeque<MessageToReceive>,
pending_subscriptions: HashMap<Vec<u8>, (SubscriptionType, PubSubSender)>,
pending_subscriptions: VecDeque<PendingSubscription>,
pending_unsubscriptions: VecDeque<HashMap<Vec<u8>, SubscriptionType>>,
subscriptions: HashMap<Vec<u8>, (SubscriptionType, PubSubSender)>,
is_reply_on: bool,
Expand Down Expand Up @@ -113,7 +121,7 @@ impl NetworkHandler {
msg_receiver,
messages_to_send: VecDeque::new(),
messages_to_receive: VecDeque::new(),
pending_subscriptions: HashMap::new(),
pending_subscriptions: VecDeque::new(),
pending_unsubscriptions: VecDeque::new(),
subscriptions: HashMap::new(),
is_reply_on: true,
Expand All @@ -128,10 +136,7 @@ impl NetworkHandler {

let join_handle = spawn(async move {
if let Err(e) = network_handler.network_loop().await {
error!(
"[{}] network loop ended in error: {e}",
network_handler.tag
);
error!("[{}] network loop ended in error: {e}", network_handler.tag);
}
});

Expand Down Expand Up @@ -172,9 +177,15 @@ impl NetworkHandler {
_ => unreachable!(),
};

let pending_subscriptions = pub_sub_senders
.into_iter()
.map(|(channel, sender)| (channel, (subscription_type, sender)));
let num_pending_subscriptions = pub_sub_senders.len();
let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map(
|(index, (channel_or_pattern, sender))| PendingSubscription {
channel_or_pattern,
subscription_type,
sender,
more_to_come: index < num_pending_subscriptions - 1,
},
);

self.pending_subscriptions.extend(pending_subscriptions);
}
Expand Down Expand Up @@ -226,8 +237,7 @@ impl NetworkHandler {
Status::Disconnected => {
debug!(
"[{}] network disconnected, queuing command: {:?}",
self.tag,
msg.commands
self.tag, msg.commands
);
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Expand Down Expand Up @@ -276,11 +286,7 @@ impl NetworkHandler {
.iter()
.fold(0, |sum, msg| sum + msg.message.commands.len());
if num_commands > 1 {
debug!(
"[{}] sending batch of {} commands",
self.tag,
num_commands
);
debug!("[{}] sending batch of {} commands", self.tag, num_commands);
}
}

Expand Down Expand Up @@ -377,10 +383,7 @@ impl NetworkHandler {
Ok(resp_buf) if resp_buf.is_push_message() => match &mut self.push_sender {
Some(push_sender) => {
if let Err(e) = push_sender.send(result).await {
warn!(
"[{}] Cannot send monitor result to caller: {e}",
self.tag
);
warn!("[{}] Cannot send monitor result to caller: {e}", self.tag);
}
}
None => {
Expand Down Expand Up @@ -415,10 +418,7 @@ impl NetworkHandler {
Ok(resp_buf) if resp_buf.is_monitor_message() => {
if let Some(push_sender) = &mut self.push_sender {
if let Err(e) = push_sender.send(result).await {
warn!(
"[{}] Cannot send monitor result to caller: {e}",
self.tag
);
warn!("[{}] Cannot send monitor result to caller: {e}", self.tag);
}
}
}
Expand All @@ -428,10 +428,7 @@ impl NetworkHandler {
Ok(resp_buf) if resp_buf.is_monitor_message() => {
if let Some(push_sender) = &mut self.push_sender {
if let Err(e) = push_sender.send(result).await {
warn!(
"[{}] Cannot send monitor result to caller: {e}",
self.tag
);
warn!("[{}] Cannot send monitor result to caller: {e}", self.tag);
}
}
}
Expand Down Expand Up @@ -477,7 +474,11 @@ impl NetworkHandler {
error!("[{}] Cannot retry message: {e}", self.tag);
}
} else {
trace!("[{}] Will respond to: {:?}", self.tag, message_to_receive.message);
trace!(
"[{}] Will respond to: {:?}",
self.tag,
message_to_receive.message
);
match message_to_receive.message.commands {
Commands::Single(_, Some(result_sender)) => {
if let Err(e) = result_sender.send(result) {
Expand Down Expand Up @@ -507,7 +508,8 @@ impl NetworkHandler {
}
},
Commands::None | Commands::Single(_, None) => {
debug!("[{}] forget value {result:?}", self.tag) // fire & forget
debug!("[{}] forget value {result:?}", self.tag)
// fire & forget
}
}
}
Expand Down Expand Up @@ -580,14 +582,27 @@ impl NetworkHandler {
RefPubSubMessage::Subscribe(channel_or_pattern)
| RefPubSubMessage::PSubscribe(channel_or_pattern)
| RefPubSubMessage::SSubscribe(channel_or_pattern) => {
if let Some(pub_sub_sender) =
self.pending_subscriptions.remove(channel_or_pattern)
{
self.subscriptions
.insert(channel_or_pattern.to_vec(), pub_sub_sender);
}
if !self.pending_subscriptions.is_empty() {
return None;
if let Some(pending_sub) = self.pending_subscriptions.pop_front() {
if pending_sub.channel_or_pattern == channel_or_pattern {
self.subscriptions
.insert(channel_or_pattern.to_vec(), (pending_sub.subscription_type, pending_sub.sender));

if pending_sub.more_to_come {
return None;
}
} else {
error!(
"[{}] Unexpected subscription confirmation on channel '{:?}'",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
} else {
error!(
"[{}] Cannot find pending subscription for channel '{:?}'",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
Some(Ok(RespBuf::ok()))
}
Expand Down Expand Up @@ -669,9 +684,7 @@ impl NetworkHandler {
message_to_receive.attempts += 1;
debug!(
"[{}]: {:?}: attempt {}",
self.tag,
message_to_receive.message.commands,
message_to_receive.attempts
self.tag, message_to_receive.message.commands, message_to_receive.attempts
);
}
}
Expand All @@ -682,8 +695,7 @@ impl NetworkHandler {
{
debug!(
"[{}] {:?}, max attempts reached",
self.tag,
message_to_receive.message.commands
self.tag, message_to_receive.message.commands
);
if let Some(message_to_receive) = self.messages_to_receive.pop_front() {
match message_to_receive.message.commands {
Expand Down Expand Up @@ -720,9 +732,7 @@ impl NetworkHandler {
message_to_send.attempts += 1;
debug!(
"[{}] {:?}: attempt {}",
self.tag,
message_to_send.message.commands,
message_to_send.attempts
self.tag, message_to_send.message.commands, message_to_send.attempts
);
}
}
Expand All @@ -733,8 +743,7 @@ impl NetworkHandler {
{
debug!(
"[{}] {:?}, max attempts reached",
self.tag,
message_to_send.message.commands
self.tag, message_to_send.message.commands
);
if let Some(message_to_send) = self.messages_to_send.pop_front() {
match message_to_send.message.commands {
Expand Down Expand Up @@ -838,29 +847,28 @@ impl NetworkHandler {
}

if !self.pending_subscriptions.is_empty() {
for (channel_or_pattern, (subscription_type, sender)) in
self.pending_subscriptions.drain()
for pending_sub in self.pending_subscriptions.drain(..)
{
match subscription_type {
match pending_sub.subscription_type {
SubscriptionType::Channel => {
self.connection
.subscribe(channel_or_pattern.clone())
.subscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
SubscriptionType::Pattern => {
self.connection
.psubscribe(channel_or_pattern.clone())
.psubscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
SubscriptionType::ShardChannel => {
self.connection
.ssubscribe(channel_or_pattern.clone())
.ssubscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
}

self.subscriptions
.insert(channel_or_pattern, (subscription_type, sender));
.insert(pending_sub.channel_or_pattern, (pending_sub.subscription_type, pending_sub.sender));
}
}

Expand Down
43 changes: 40 additions & 3 deletions src/tests/pub_sub_commands.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::collections::{HashMap, HashSet};

use crate::{
client::{Client, IntoConfig},
commands::{
ClientKillOptions, ClusterCommands, ClusterShardResult, ConnectionCommands, FlushingMode,
PubSubChannelsOptions, PubSubCommands, ServerCommands, StringCommands,
ListCommands, PubSubChannelsOptions, PubSubCommands, ServerCommands, StringCommands,
},
tests::{get_cluster_test_client, get_default_addr, get_test_client, log_try_init},
Result,
};
use futures_util::{FutureExt, StreamExt, TryStreamExt};
use serial_test::serial;
use std::{
collections::{HashMap, HashSet},
future::IntoFuture,
};

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
Expand Down Expand Up @@ -610,3 +612,38 @@ async fn no_auto_resubscribe() -> Result<()> {

Ok(())
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[serial]
async fn concurrent_subscribe() -> Result<()> {
let pub_sub_client1 = get_test_client().await?;
let pub_sub_client2 = pub_sub_client1.clone();
let regular_client = get_test_client().await?;

// cleanup
regular_client.flushdb(FlushingMode::Sync).await?;

regular_client.lpush("key", ["value1", "value2"]).await?;

let results = tokio::join!(
pub_sub_client1.subscribe("mychannel1"),
pub_sub_client2.subscribe("mychannel2"),
regular_client.lpop("key", 2).into_future(),
regular_client.lpop("key", 2).into_future(),
regular_client
.publish("mychannel1", "new")
.into_future()
);

let mut pub_sub_stream1 = results.0?;
let _pub_sub_stream2 = results.1?;
let values1: Vec<String> = results.2?;
let values2: Vec<String> = results.3?;
let message1 = pub_sub_stream1.next().await.unwrap()?;

assert_eq!(vec!["value2".to_owned(), "value1".to_owned()], values1);
assert_eq!(Vec::<String>::new(), values2);
assert_eq!(b"new".to_vec(), message1.payload);

Ok(())
}

0 comments on commit a93b687

Please sign in to comment.