Skip to content

Commit

Permalink
Check max qos violations in server dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Oct 7, 2022
1 parent 46897f1 commit 463e174
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

* v3: Allow to specify max allowed qos for server publishes

* v5: Check max qos violations in server dispatcher

## [0.8.10] - 2022-09-25

* Add .into_inner() client's helper for publish control message
Expand Down
3 changes: 2 additions & 1 deletion src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,10 @@ where
}
}

// check max allowed qos
if publish.qos > self.max_qos {
log::trace!(
"Max allowed QoS is viaolated, max {:?} provided {:?}",
"Max allowed QoS is violated, max {:?} provided {:?}",
self.max_qos,
publish.qos
);
Expand Down
4 changes: 2 additions & 2 deletions src/v3/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ where
self
}

/// Set max QoS allowed.
/// Set max allowed QoS.
///
/// If peer sends publish with higher qos then ProtocolError::MaxQoSViolated(..)
/// By default max qos is not set.
/// By default max qos is set to `ExactlyOnce`.
pub fn max_qos(mut self, qos: QoS) -> Self {
self.max_qos = qos;
self
Expand Down
22 changes: 21 additions & 1 deletion src/v5/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ use super::control::{ControlMessage, ControlResult};
use super::publish::{Publish, PublishAck};
use super::shared::{Ack, MqttShared};
use super::sink::MqttSink;
use super::{codec, codec::EncodeLtd, Session};
use super::{codec, codec::EncodeLtd, QoS, Session};

/// mqtt3 protocol dispatcher
pub(super) fn factory<St, T, C, E>(
publish: T,
control: C,
max_qos: QoS,
max_inflight_size: usize,
) -> impl ServiceFactory<
DispatchItem<Rc<MqttShared>>,
Expand Down Expand Up @@ -60,6 +61,7 @@ where
max_inflight_size,
Dispatcher::<_, _, E>::new(
cfg.sink().clone(),
max_qos,
max_receive as usize,
max_topic_alias,
publish,
Expand All @@ -85,6 +87,7 @@ pub(crate) struct Dispatcher<T, C: Service<ControlMessage<E>>, E> {
sink: MqttSink,
publish: T,
shutdown: RefCell<Option<Pin<Box<C::Future>>>>,
max_qos: QoS,
max_receive: usize,
max_topic_alias: u16,
inner: Rc<Inner<C>>,
Expand All @@ -111,13 +114,15 @@ where
{
fn new(
sink: MqttSink,
max_qos: QoS,
max_receive: usize,
max_topic_alias: u16,
publish: T,
control: C,
) -> Self {
Self {
publish,
max_qos,
max_receive,
max_topic_alias,
sink: sink.clone(),
Expand Down Expand Up @@ -205,6 +210,21 @@ where
)));
}

// check max allowed qos
if publish.qos > self.max_qos {
log::trace!(
"Max allowed QoS is violated, max {:?} provided {:?}",
self.max_qos,
publish.qos
);
return Either::Right(Either::Right(ControlResponse::new(
ControlMessage::proto_error(ProtocolError::MaxQoSViolated(
publish.qos,
)),
&self.inner,
)));
}

// check for duplicated packet id
if !inner.inflight.insert(pid) {
self.sink.send(codec::Packet::PublishAck(codec::PublishAck {
Expand Down
10 changes: 8 additions & 2 deletions src/v5/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ where

/// Set server max qos setting.
///
/// By default max qos is not set`
/// By default max qos is not set.
pub fn max_qos(mut self, qos: QoS) -> Self {
self.max_qos = Some(qos);
self
Expand Down Expand Up @@ -233,7 +233,12 @@ where
pool: self.pool,
_t: PhantomData,
},
factory(self.srv_publish, self.srv_control, self.max_inflight_size),
factory(
self.srv_publish,
self.srv_control,
self.max_qos.unwrap_or(QoS::ExactlyOnce),
self.max_inflight_size,
),
self.disconnect_timeout,
)
}
Expand All @@ -258,6 +263,7 @@ where
handler: Rc::new(factory(
self.srv_publish,
self.srv_control,
self.max_qos.unwrap_or(QoS::ExactlyOnce),
self.max_inflight_size,
)),
max_size: self.max_size,
Expand Down
53 changes: 52 additions & 1 deletion tests/test_server_v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ntex::{server, service::fn_service, time::sleep};

use ntex_mqtt::v5::{
client, codec, error, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish,
PublishAck, Session,
PublishAck, QoS, Session,
};

struct St;
Expand Down Expand Up @@ -833,3 +833,54 @@ async fn test_handle_incoming() -> std::io::Result<()> {

Ok(())
}

#[ntex::test]
async fn test_max_qos() -> std::io::Result<()> {
let violated = Arc::new(AtomicBool::new(false));
let violated2 = violated.clone();

let srv = server::test_server(move || {
let violated = violated2.clone();
MqttServer::new(handshake)
.max_qos(QoS::AtMostOnce)
.publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack()))
.control(move |msg| {
let violated = violated.clone();
match msg {
ControlMessage::ProtocolError(msg) => {
match msg.get_ref() {
error::ProtocolError::MaxQoSViolated(_) => {
violated.store(true, Relaxed);
}
_ => (),
}
Ready::Ok::<_, TestError>(msg.ack())
}
_ => Ready::Ok(msg.disconnect()),
}
})
.finish()
});

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.encode(
codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))),
&codec,
)
.unwrap();
let _ = io.recv(&codec).await.unwrap().unwrap();

io.encode(pkt_publish().into(), &codec).unwrap();
let pkt = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(
pkt,
codec::Packet::Disconnect(codec::Disconnect {
reason_code: codec::DisconnectReasonCode::ImplementationSpecificError,
..Default::default()
})
);
assert!(violated.load(Relaxed));

Ok(())
}

0 comments on commit 463e174

Please sign in to comment.