diff --git a/src/network/network_handler.rs b/src/network/network_handler.rs index 0c388b2..d335f07 100644 --- a/src/network/network_handler.rs +++ b/src/network/network_handler.rs @@ -150,7 +150,7 @@ impl NetworkHandler { loop { select! { msg = self.msg_receiver.next().fuse() => { - if !self.handle_message(msg).await { break; } + if !self.try_handle_message(msg).await { break; } } , result = self.connection.read().fuse() => { if !self.handle_result(result).await { break; } @@ -162,118 +162,12 @@ impl NetworkHandler { Ok(()) } - async fn handle_message(&mut self, mut msg: Option) -> bool { + async fn try_handle_message(&mut self, mut msg: Option) -> bool { let is_channel_closed: bool; loop { - if let Some(mut msg) = msg { - trace!( - "[{}][{:?}] Will handle message: {msg:?}", - self.tag, - self.status - ); - let pub_sub_senders = msg.pub_sub_senders.take(); - if let Some(pub_sub_senders) = pub_sub_senders { - let subscription_type = match &msg.commands { - Commands::Single(command, _) => match command.name { - "SUBSCRIBE" => SubscriptionType::Channel, - "PSUBSCRIBE" => SubscriptionType::Pattern, - "SSUBSCRIBE" => SubscriptionType::ShardChannel, - _ => unreachable!(), - }, - _ => unreachable!(), - }; - - let num_pending_subscriptions = pub_sub_senders.len(); - let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map( - |(index, (channel_or_pattern, sender))| PendingSubscription { - channel_or_pattern, - subscription_type, - sender, - more_to_come: index < num_pending_subscriptions - 1, - }, - ); - - self.pending_subscriptions.extend(pending_subscriptions); - } - - let push_sender = msg.push_sender.take(); - if let Some(push_sender) = push_sender { - debug!("[{}] Registering push_sender", self.tag); - self.push_sender = Some(push_sender); - } - - match &self.status { - Status::Connected => { - for command in &msg.commands { - match command.name { - "SUBSCRIBE" | "PSUBSCRIBE" | "SSUBSCRIBE" => { - self.status = Status::Subscribing; - } - "MONITOR" => { - self.status = Status::EnteringMonitor; - } - _ => (), - } - } - self.messages_to_send.push_back(MessageToSend::new(msg)); - } - Status::Subscribing => { - self.messages_to_send.push_back(MessageToSend::new(msg)); - } - Status::Subscribed => { - for command in &msg.commands { - let subscription_type = match command.name { - "UNSUBSCRIBE" => Some(SubscriptionType::Channel), - "PUNSUBSCRIBE" => Some(SubscriptionType::Pattern), - "SUNSUBSCRIBE" => Some(SubscriptionType::ShardChannel), - _ => None, - }; - if let Some(subscription_type) = subscription_type { - self.pending_unsubscriptions.push_back( - command - .args - .into_iter() - .map(|a| (a.to_vec(), subscription_type)) - .collect(), - ); - } - } - self.messages_to_send.push_back(MessageToSend::new(msg)); - } - Status::Disconnected => { - if msg.retry_on_error { - debug!( - "[{}] network disconnected, queuing command: {:?}", - self.tag, msg.commands - ); - self.messages_to_send.push_back(MessageToSend::new(msg)); - } else { - debug!( - "[{}] network disconnected, ending command in error: {:?}", - self.tag, msg.commands - ); - msg.commands.send_error( - &self.tag, - Error::Client("Disconnected from server".to_string()), - ); - } - } - Status::EnteringMonitor => { - self.messages_to_send.push_back(MessageToSend::new(msg)) - } - Status::Monitor => { - for command in &msg.commands { - if command.name == "RESET" { - self.status = Status::LeavingMonitor; - } - } - self.messages_to_send.push_back(MessageToSend::new(msg)); - } - Status::LeavingMonitor => { - self.messages_to_send.push_back(MessageToSend::new(msg)); - } - } + if let Some(msg) = msg { + self.handle_message(msg).await; } else { is_channel_closed = true; break; @@ -296,6 +190,136 @@ impl NetworkHandler { !is_channel_closed } + async fn handle_message(&mut self, mut msg: Message) { + trace!( + "[{}][{:?}] Will handle message: {msg:?}", + self.tag, + self.status + ); + let pub_sub_senders = msg.pub_sub_senders.take(); + if let Some(pub_sub_senders) = pub_sub_senders { + let subscription_type = match &msg.commands { + Commands::Single(command, _) => match command.name { + "SUBSCRIBE" => SubscriptionType::Channel, + "PSUBSCRIBE" => SubscriptionType::Pattern, + "SSUBSCRIBE" => SubscriptionType::ShardChannel, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + for (channel_or_pattern, _sender) in pub_sub_senders.iter() { + if self.subscriptions.contains_key(channel_or_pattern) { + debug!( + "[{}][{:?}] There is already a subscription on channel `{}`", + self.tag, + self.status, + String::from_utf8_lossy(channel_or_pattern) + ); + msg.commands.send_error( + &self.tag, + Error::Client( + format!( + "There is already a subscription on channel `{}`", + String::from_utf8_lossy(channel_or_pattern) + ) + .to_string(), + ), + ); + return; + } + } + + let num_pending_subscriptions = pub_sub_senders.len(); + let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map( + |(index, (channel_or_pattern, sender))| PendingSubscription { + channel_or_pattern, + subscription_type, + sender, + more_to_come: index < num_pending_subscriptions - 1, + }, + ); + + self.pending_subscriptions.extend(pending_subscriptions); + } + + let push_sender = msg.push_sender.take(); + if let Some(push_sender) = push_sender { + debug!("[{}] Registering push_sender", self.tag); + self.push_sender = Some(push_sender); + } + + match &self.status { + Status::Connected => { + for command in &msg.commands { + match command.name { + "SUBSCRIBE" | "PSUBSCRIBE" | "SSUBSCRIBE" => { + self.status = Status::Subscribing; + } + "MONITOR" => { + self.status = Status::EnteringMonitor; + } + _ => (), + } + } + self.messages_to_send.push_back(MessageToSend::new(msg)); + } + Status::Subscribing => { + self.messages_to_send.push_back(MessageToSend::new(msg)); + } + Status::Subscribed => { + for command in &msg.commands { + let subscription_type = match command.name { + "UNSUBSCRIBE" => Some(SubscriptionType::Channel), + "PUNSUBSCRIBE" => Some(SubscriptionType::Pattern), + "SUNSUBSCRIBE" => Some(SubscriptionType::ShardChannel), + _ => None, + }; + if let Some(subscription_type) = subscription_type { + self.pending_unsubscriptions.push_back( + command + .args + .into_iter() + .map(|a| (a.to_vec(), subscription_type)) + .collect(), + ); + } + } + self.messages_to_send.push_back(MessageToSend::new(msg)); + } + Status::Disconnected => { + if msg.retry_on_error { + debug!( + "[{}] network disconnected, queuing command: {:?}", + self.tag, msg.commands + ); + self.messages_to_send.push_back(MessageToSend::new(msg)); + } else { + debug!( + "[{}] network disconnected, ending command in error: {:?}", + self.tag, msg.commands + ); + msg.commands.send_error( + &self.tag, + Error::Client("Disconnected from server".to_string()), + ); + } + } + Status::EnteringMonitor => self.messages_to_send.push_back(MessageToSend::new(msg)), + Status::Monitor => { + for command in &msg.commands { + if command.name == "RESET" { + self.status = Status::LeavingMonitor; + } + } + self.messages_to_send.push_back(MessageToSend::new(msg)); + } + Status::LeavingMonitor => { + self.messages_to_send.push_back(MessageToSend::new(msg)); + } + } + } + async fn send_messages(&mut self) { if log_enabled!(Level::Debug) { let num_commands = self @@ -755,7 +779,7 @@ impl NetworkHandler { let delay = end.duration_since(Instant::now()); let result = timeout(delay, self.msg_receiver.next().fuse()).await; if let Ok(msg) = result { - if !self.handle_message(msg).await { + if !self.try_handle_message(msg).await { return false; } } else { diff --git a/src/tests/pub_sub_commands.rs b/src/tests/pub_sub_commands.rs index e1dad9e..339bc7a 100644 --- a/src/tests/pub_sub_commands.rs +++ b/src/tests/pub_sub_commands.rs @@ -787,6 +787,7 @@ async fn subscribe_multiple_times_to_the_same_channel() -> Result<()> { let mut pub_sub_stream = pub_sub_client.subscribe("mychannel").await?; assert!(pub_sub_stream.subscribe("mychannel").await.is_err()); assert!(pub_sub_client.subscribe("mychannel").await.is_err()); + regular_client.publish("mychannel", "mymessage").await?; pub_sub_stream.psubscribe("pattern").await?; assert!(pub_sub_stream.psubscribe("pattern").await.is_err());