Skip to content

Commit

Permalink
added unsubscribe methods to pub_sub_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
mcatanzariti committed Dec 2, 2023
1 parent d2d9ff0 commit 546fd29
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 21 deletions.
2 changes: 1 addition & 1 deletion redis/cluster.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ COPY cluster.conf .
RUN chown redis:redis /redis/cluster.conf
EXPOSE 6379
COPY cluster-entrypoint.sh .
ENTRYPOINT ["/redis/cluster-entrypoint.sh"]
ENTRYPOINT ["/redis/cluster-entrypoint.sh"]
2 changes: 1 addition & 1 deletion redis/docker_down.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
./set_host_ip.sh
docker-compose down
docker-compose down
2 changes: 1 addition & 1 deletion redis/sentinel.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ COPY sentinel.conf .
RUN chown redis:redis /redis/sentinel.conf
EXPOSE 26379
COPY sentinel-entrypoint.sh .
ENTRYPOINT ["/redis/sentinel-entrypoint.sh"]
ENTRYPOINT ["/redis/sentinel-entrypoint.sh"]
2 changes: 1 addition & 1 deletion redis/set_host_ip.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
IP=`ifconfig eth0 | grep -Eo 'inet (addr:)?([0-9]*\.){3}[0-9]*' | grep -Eo '([0-9]*\.){3}[0-9]*'`
echo IP=$IP
echo "HOST_IP=$IP" > .env
echo "HOST_IP=$IP" > .env
51 changes: 42 additions & 9 deletions src/client/pub_sub_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ impl PubSubStream {
.subscribe_from_pub_sub_sender(&channels, &self.sender)
.await?;

let mut existing_channels = CommandArgs::default();
std::mem::swap(&mut existing_channels, &mut self.channels);
self.channels = existing_channels.arg(channels).build();
self.channels = self.channels.arg(channels).build();

Ok(())
}
Expand All @@ -210,9 +208,7 @@ impl PubSubStream {
.psubscribe_from_pub_sub_sender(&patterns, &self.sender)
.await?;

let mut existing_patterns = CommandArgs::default();
std::mem::swap(&mut existing_patterns, &mut self.patterns);
self.patterns = existing_patterns.arg(patterns).build();
self.patterns = self.patterns.arg(patterns).build();

Ok(())
}
Expand All @@ -229,9 +225,46 @@ impl PubSubStream {
.ssubscribe_from_pub_sub_sender(&shardchannels, &self.sender)
.await?;

let mut existing_shardchannels = CommandArgs::default();
std::mem::swap(&mut existing_shardchannels, &mut self.shardchannels);
self.shardchannels = existing_shardchannels.arg(shardchannels).build();
self.shardchannels = self.shardchannels.arg(shardchannels).build();

Ok(())
}

/// Unsubscribe from the given channels
pub async fn unsubscribe<C, CC>(&mut self, channels: CC) -> Result<()>
where
C: SingleArg + Send,
CC: SingleArgCollection<C>,
{
let channels = CommandArgs::default().arg(channels).build();
self.channels.retain(|channel| channels.iter().all(|c| c != channel));
self.client.unsubscribe(channels).await?;

Ok(())
}

/// Unsubscribe from the given patterns
pub async fn punsubscribe<C, CC>(&mut self, patterns: CC) -> Result<()>
where
C: SingleArg + Send,
CC: SingleArgCollection<C>,
{
let patterns = CommandArgs::default().arg(patterns).build();
self.patterns.retain(|pattern| patterns.iter().all(|p| p != pattern));
self.client.punsubscribe(patterns).await?;

Ok(())
}

/// Unsubscribe from the given patterns
pub async fn sunsubscribe<C, CC>(&mut self, shardchannels: CC) -> Result<()>
where
C: SingleArg + Send,
CC: SingleArgCollection<C>,
{
let shardchannels = CommandArgs::default().arg(shardchannels).build();
self.shardchannels.retain(|shardchannel| shardchannels.iter().all(|sc: &Vec<u8>| sc != shardchannel));
self.client.punsubscribe(shardchannels).await?;

Ok(())
}
Expand Down
24 changes: 19 additions & 5 deletions src/resp/command_args.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use smallvec::SmallVec;

use crate::resp::ToArgs;
use std::{fmt};
use std::fmt;

/// Collection of arguments of [`Command`](crate::resp::Command).
#[derive(Clone, Default)]
pub struct CommandArgs {
args: SmallVec<[Vec<u8>;10]>,
args: SmallVec<[Vec<u8>; 10]>,
}

impl CommandArgs {
Expand Down Expand Up @@ -70,6 +70,13 @@ impl CommandArgs {
pub(crate) fn write_arg(&mut self, buf: &[u8]) {
self.args.push(buf.to_vec());
}

pub(crate) fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&[u8]) -> bool,
{
self.args.retain(|arg| f(arg))
}
}

impl<'a> IntoIterator for &'a CommandArgs {
Expand All @@ -79,14 +86,14 @@ impl<'a> IntoIterator for &'a CommandArgs {
#[inline]
fn into_iter(self) -> Self::IntoIter {
CommandArgsIterator {
iter: self.args.iter()
iter: self.args.iter(),
}
}
}

/// [`CommandArgs`] iterator
pub struct CommandArgsIterator<'a> {
iter: std::slice::Iter<'a, Vec<u8>>
iter: std::slice::Iter<'a, Vec<u8>>,
}

impl<'a> Iterator for CommandArgsIterator<'a> {
Expand All @@ -110,7 +117,14 @@ impl std::ops::Deref for CommandArgs {
impl fmt::Debug for CommandArgs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CommandArgs")
.field("args", &self.args.iter().map(|a| String::from_utf8_lossy(a.as_slice())).collect::<Vec<_>>())
.field(
"args",
&self
.args
.iter()
.map(|a| String::from_utf8_lossy(a.as_slice()))
.collect::<Vec<_>>(),
)
.finish()
}
}
86 changes: 83 additions & 3 deletions src/tests/pub_sub_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,9 +630,7 @@ async fn concurrent_subscribe() -> Result<()> {
pub_sub_client2.subscribe("mychannel2"),
regular_client.lpop("key", 2).into_future(),
regular_client.lpop("key", 2).into_future(),
regular_client
.publish("mychannel1", "new")
.into_future()
regular_client.publish("mychannel1", "new").into_future()
);

let mut pub_sub_stream1 = results.0?;
Expand All @@ -647,3 +645,85 @@ async fn concurrent_subscribe() -> Result<()> {

Ok(())
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn unsubscribe() -> Result<()> {
let pub_sub_client = get_test_client().await?;
let regular_client = get_test_client().await?;

// cleanup
regular_client.flushdb(FlushingMode::Sync).await?;

let mut pub_sub_stream = pub_sub_client
.subscribe(["mychannel1", "mychannel2"])
.await?;
regular_client.publish("mychannel1", "mymessage1").await?;
regular_client.publish("mychannel2", "mymessage2").await?;

let message = pub_sub_stream.next().await.unwrap()?;
let channel: String = String::from_utf8(message.channel).unwrap();
let payload: String = String::from_utf8(message.payload).unwrap();

assert_eq!("mychannel1", channel);
assert_eq!("mymessage1", payload);

let message = pub_sub_stream.next().await.unwrap()?;
let channel: String = String::from_utf8(message.channel).unwrap();
let payload: String = String::from_utf8(message.payload).unwrap();

assert_eq!("mychannel2", channel);
assert_eq!("mymessage2", payload);

regular_client.publish("mychannel1", "mymessage11").await?;
pub_sub_stream.unsubscribe("mychannel2").await?;
regular_client.publish("mychannel1", "mymessage12").await?;

let message = pub_sub_stream.next().await.unwrap()?;
let channel: String = String::from_utf8(message.channel).unwrap();
let payload: String = String::from_utf8(message.payload).unwrap();

assert_eq!("mychannel1", channel);
assert_eq!("mymessage11", payload);

let message = pub_sub_stream.next().await.unwrap()?;
let channel: String = String::from_utf8(message.channel).unwrap();
let payload: String = String::from_utf8(message.payload).unwrap();

assert_eq!("mychannel1", channel);
assert_eq!("mymessage12", payload);

pub_sub_stream.close().await?;
regular_client.close().await?;

Ok(())
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn punsubscribe() -> Result<()> {
let pub_sub_client = get_test_client().await?;
let regular_client = get_test_client().await?;

// cleanup
regular_client.flushdb(FlushingMode::Sync).await?;

let mut pub_sub_stream = pub_sub_client
.psubscribe(["mychannel1*", "mychannel2*"])
.await?;

let num_patterns = regular_client.pub_sub_numpat().await?;
assert_eq!(2, num_patterns);

pub_sub_stream.punsubscribe("mychannel1*").await?;

let num_patterns = regular_client.pub_sub_numpat().await?;
assert_eq!(1, num_patterns);

pub_sub_stream.close().await?;
regular_client.close().await?;

Ok(())
}

0 comments on commit 546fd29

Please sign in to comment.