From 7b8e3c585ea920a6cc5d1eb265b17ec596f467a9 Mon Sep 17 00:00:00 2001 From: Max Bruckner Date: Wed, 19 Jun 2024 15:51:34 +0200 Subject: [PATCH 01/20] Fix subscribe_many for MQTTv5 --- rumqttc/src/v5/client.rs | 69 +++++++++++------------- rumqttc/src/v5/mqttbytes/v5/subscribe.rs | 12 +++++ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..d726197eb 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -6,7 +6,7 @@ use super::mqttbytes::v5::{ Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; -use super::mqttbytes::{valid_filter, QoS}; +use super::mqttbytes::QoS; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; @@ -256,13 +256,14 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request: Request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::Request(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(Request::Subscribe(subscribe))); } - self.request_tx.send_async(request).await?; + + self.request_tx + .send_async(Request::Subscribe(subscribe)) + .await?; Ok(()) } @@ -287,13 +288,12 @@ impl AsyncClient { properties: Option, ) -> Result<(), ClientError> { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::TryRequest(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(Request::Subscribe(subscribe))); } - self.request_tx.try_send(request)?; + + self.request_tx.try_send(Request::Subscribe(subscribe))?; Ok(()) } @@ -319,15 +319,15 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(Request::Subscribe(subscribe))); } - self.request_tx.send_async(request).await?; + self.request_tx + .send_async(Request::Subscribe(subscribe)) + .await?; + Ok(()) } @@ -358,14 +358,12 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(Request::Subscribe(subscribe))); } - self.request_tx.try_send(request)?; + + self.request_tx.try_send(Request::Subscribe(subscribe))?; Ok(()) } @@ -608,13 +606,12 @@ impl Client { properties: Option, ) -> Result<(), ClientError> { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::Request(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(Request::Subscribe(subscribe))); } - self.client.request_tx.send(request)?; + + self.client.request_tx.send(Request::Subscribe(subscribe))?; Ok(()) } @@ -655,14 +652,12 @@ impl Client { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(Request::Subscribe(subscribe))); } - self.client.request_tx.send(request)?; + + self.client.request_tx.send(Request::Subscribe(subscribe))?; Ok(()) } diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 4167cd671..f55dbbe86 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -95,6 +95,14 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub(crate) fn has_valid_filters(&self) -> bool { + if self.filters.is_empty() { + return false; + } + + self.filters.iter().all(Filter::is_valid) + } } /// Subscription filter @@ -177,6 +185,10 @@ impl Filter { write_mqtt_string(buffer, self.path.as_str()); buffer.put_u8(options); } + + fn is_valid(&self) -> bool { + valid_filter(&self.path) + } } #[derive(Debug, Clone, PartialEq, Eq)] From c5bdd32e1fc0869ad825249b8deeef3056c68cbe Mon Sep 17 00:00:00 2001 From: Max Bruckner Date: Wed, 19 Jun 2024 16:21:29 +0200 Subject: [PATCH 02/20] Fix subscribe_many for MQTTv3.2 --- rumqttc/src/client.rs | 70 ++++++++++++--------------- rumqttc/src/mqttbytes/v4/subscribe.rs | 12 +++++ 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 15cd5f5ad..11366694d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,7 +3,7 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; -use crate::{valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::{valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -150,25 +150,23 @@ impl AsyncClient { /// Sends a MQTT Subscribe to the `EventLoop` pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new(topic, qos); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; + + self.request_tx.send_async(subscribe.into()).await?; Ok(()) } /// Attempts to send a MQTT Subscribe to the `EventLoop` pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::TryRequest(request)); + let subscribe = Subscribe::new(topic, qos); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; + + self.request_tx.try_send(subscribe.into())?; Ok(()) } @@ -177,14 +175,12 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; + + self.request_tx.send_async(subscribe.into()).await?; Ok(()) } @@ -193,14 +189,11 @@ impl AsyncClient { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; + self.request_tx.try_send(subscribe.into())?; Ok(()) } @@ -341,13 +334,12 @@ impl Client { /// Sends a MQTT Subscribe to the `EventLoop` pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new(topic, qos); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; + + self.client.request_tx.send(subscribe.into())?; Ok(()) } @@ -362,14 +354,12 @@ impl Client { where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; + + self.client.request_tx.send(subscribe.into())?; Ok(()) } diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 42ddb57b1..6b31d3738 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -93,6 +93,14 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub(crate) fn has_valid_filters(&self) -> bool { + if self.filters.is_empty() { + return false; + } + + self.filters.iter().all(SubscribeFilter::is_valid) + } } /// Subscription filter @@ -119,6 +127,10 @@ impl SubscribeFilter { write_mqtt_string(buffer, self.path.as_str()); buffer.put_u8(options); } + + fn is_valid(&self) -> bool { + valid_filter(&self.path) + } } #[derive(Debug, Clone, PartialEq, Eq)] From 3b9ff0c2cce308065fd4bb33f98801ba1839fe4e Mon Sep 17 00:00:00 2001 From: Max Bruckner Date: Wed, 19 Jun 2024 16:30:19 +0200 Subject: [PATCH 03/20] v5: Implement From for Request --- rumqttc/src/v5/client.rs | 12 ++++++------ rumqttc/src/v5/mod.rs | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index d726197eb..34aa20ce4 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -258,7 +258,7 @@ impl AsyncClient { let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::Request(Request::Subscribe(subscribe))); + return Err(ClientError::Request(subscribe.into())); } self.request_tx @@ -290,7 +290,7 @@ impl AsyncClient { let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::TryRequest(Request::Subscribe(subscribe))); + return Err(ClientError::TryRequest(subscribe.into())); } self.request_tx.try_send(Request::Subscribe(subscribe))?; @@ -321,7 +321,7 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::Request(Request::Subscribe(subscribe))); + return Err(ClientError::Request(subscribe.into())); } self.request_tx @@ -360,7 +360,7 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::TryRequest(Request::Subscribe(subscribe))); + return Err(ClientError::TryRequest(subscribe.into())); } self.request_tx.try_send(Request::Subscribe(subscribe))?; @@ -608,7 +608,7 @@ impl Client { let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::Request(Request::Subscribe(subscribe))); + return Err(ClientError::Request(subscribe.into())); } self.client.request_tx.send(Request::Subscribe(subscribe))?; @@ -654,7 +654,7 @@ impl Client { { let subscribe = Subscribe::new_many(topics, properties); if !subscribe.has_valid_filters() { - return Err(ClientError::Request(Request::Subscribe(subscribe))); + return Err(ClientError::Request(subscribe.into())); } self.client.request_tx.send(Request::Subscribe(subscribe))?; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2518a93f1..6e0e43931 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -49,6 +49,12 @@ pub enum Request { Disconnect, } +impl From for Request { + fn from(subscribe: Subscribe) -> Self { + Self::Subscribe(subscribe) + } +} + #[cfg(feature = "websocket")] type RequestModifierFn = Arc< dyn Fn(http::Request<()>) -> Pin> + Send>> From 4e51ad4fcc5420386921a4ce88bc8210473737f7 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 22 Feb 2024 15:59:58 +0000 Subject: [PATCH 04/20] feat: oneshot channel returns `Pkid` --- rumqttc/src/client.rs | 159 +++++++++++++++++------- rumqttc/src/lib.rs | 3 + rumqttc/src/mqttbytes/v4/publish.rs | 46 ++++++- rumqttc/src/mqttbytes/v4/subscribe.rs | 45 ++++++- rumqttc/src/mqttbytes/v4/unsubscribe.rs | 38 +++++- rumqttc/src/state.rs | 4 + 6 files changed, 240 insertions(+), 55 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 15cd5f5ad..3d68cece7 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,7 +3,9 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; -use crate::{valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::{ + valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, PkidPromise, Request, +}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -72,7 +74,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -80,12 +82,16 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + publish.place_pkid_tx(pkid_tx); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(()) + Ok(pkid_rx) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -95,7 +101,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -103,12 +109,16 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + publish.place_pkid_tx(pkid_tx); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -137,87 +147,123 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + publish.place_pkid_tx(pkid_tx); + let publish = Request::Publish(publish); self.request_tx.send_async(publish).await?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT Subscribe to the `EventLoop` - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); + let mut subscribe = Subscribe::new(&topic, qos); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } /// Attempts to send a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); + let mut subscribe = Subscribe::new(&topic, qos); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let mut subscribe = Subscribe::new_many(topics_iter); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let mut subscribe = Subscribe::new_many(topics_iter); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub async fn unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub fn try_unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT disconnect to the `EventLoop` @@ -292,7 +338,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -300,12 +346,16 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + publish.place_pkid_tx(pkid_tx); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(()) + Ok(pkid_rx) } pub fn try_publish( @@ -314,13 +364,12 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -340,40 +389,55 @@ impl Client { } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); + let mut subscribe = Subscribe::new(&topic, qos); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); + let mut subscribe = Subscribe::new_many(topics_iter); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -381,17 +445,20 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub fn unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + pub fn try_unsubscribe>(&self, topic: S) -> Result { + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 0c2d67fef..afefa933a 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -159,6 +159,9 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +pub type Pkid = u16; +pub type PkidPromise = tokio::sync::oneshot::Receiver; + /// Current outgoing activity on the eventloop #[derive(Debug, Clone, PartialEq, Eq)] pub enum Outgoing { diff --git a/rumqttc/src/mqttbytes/v4/publish.rs b/rumqttc/src/mqttbytes/v4/publish.rs index 37924fa9d..286c8fe3f 100644 --- a/rumqttc/src/mqttbytes/v4/publish.rs +++ b/rumqttc/src/mqttbytes/v4/publish.rs @@ -1,8 +1,10 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Publish packet -#[derive(Clone, PartialEq, Eq)] pub struct Publish { pub dup: bool, pub qos: QoS, @@ -10,8 +12,37 @@ pub struct Publish { pub topic: String, pub pkid: u16, pub payload: Bytes, + pub pkid_tx: Option>, +} + +// TODO: figure out if this is even required +impl Clone for Publish { + fn clone(&self) -> Self { + Self { + dup: self.dup, + qos: self.qos, + retain: self.retain, + topic: self.topic.clone(), + payload: self.payload.clone(), + pkid: self.pkid, + pkid_tx: None, + } + } } +impl PartialEq for Publish { + fn eq(&self, other: &Self) -> bool { + self.dup == other.dup + && self.qos == other.qos + && self.retain == other.retain + && self.topic == other.topic + && self.payload == other.payload + && self.pkid == other.pkid + } +} + +impl Eq for Publish {} + impl Publish { pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { Publish { @@ -21,6 +52,7 @@ impl Publish { pkid: 0, topic: topic.into(), payload: Bytes::from(payload.into()), + pkid_tx: None, } } @@ -32,6 +64,7 @@ impl Publish { pkid: 0, topic: topic.into(), payload, + pkid_tx: None, } } @@ -77,6 +110,7 @@ impl Publish { pkid, topic, payload: bytes, + pkid_tx: None, }; Ok(publish) @@ -106,6 +140,10 @@ impl Publish { Ok(1 + count + len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } impl fmt::Debug for Publish { @@ -165,6 +203,7 @@ mod test { topic: "a/b".to_owned(), pkid: 10, payload: Bytes::from(&payload[..]), + pkid_tx: None, } ); } @@ -201,6 +240,7 @@ mod test { topic: "a/b".to_owned(), pkid: 0, payload: Bytes::from(&[0x01, 0x02][..]), + pkid_tx: None, } ); } @@ -214,6 +254,7 @@ mod test { topic: "a/b".to_owned(), pkid: 10, payload: Bytes::from(vec![0xF1, 0xF2, 0xF3, 0xF4]), + pkid_tx: None, }; let mut buf = BytesMut::new(); @@ -248,6 +289,7 @@ mod test { topic: "a/b".to_owned(), pkid: 0, payload: Bytes::from(vec![0xE1, 0xE2, 0xE3, 0xE4]), + pkid_tx: None, }; let mut buf = BytesMut::new(); diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 42ddb57b1..407100d3e 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -1,13 +1,35 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Subscription packet -#[derive(Clone, PartialEq, Eq)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, + pub pkid_tx: Option>, } +// TODO: figure out if this is even required +impl Clone for Subscribe { + fn clone(&self) -> Self { + Self { + pkid: self.pkid, + filters: self.filters.clone(), + pkid_tx: None, + } + } +} + +impl PartialEq for Subscribe { + fn eq(&self, other: &Self) -> bool { + self.pkid == other.pkid && self.filters == other.filters + } +} + +impl Eq for Subscribe {} + impl Subscribe { pub fn new>(path: S, qos: QoS) -> Subscribe { let filter = SubscribeFilter { @@ -18,6 +40,7 @@ impl Subscribe { Subscribe { pkid: 0, filters: vec![filter], + pkid_tx: None, } } @@ -27,7 +50,11 @@ impl Subscribe { { let filters: Vec = topics.into_iter().collect(); - Subscribe { pkid: 0, filters } + Subscribe { + pkid: 0, + filters, + pkid_tx: None, + } } pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { @@ -71,7 +98,11 @@ impl Subscribe { match filters.len() { 0 => Err(Error::EmptySubscription), - _ => Ok(Subscribe { pkid, filters }), + _ => Ok(Subscribe { + pkid, + filters, + pkid_tx: None, + }), } } @@ -93,6 +124,10 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } /// Subscription filter @@ -194,6 +229,7 @@ mod test { SubscribeFilter::new("#".to_owned(), QoS::AtLeastOnce), SubscribeFilter::new("a/b/c".to_owned(), QoS::ExactlyOnce) ], + pkid_tx: None, } ); } @@ -207,6 +243,7 @@ mod test { SubscribeFilter::new("#".to_owned(), QoS::AtLeastOnce), SubscribeFilter::new("a/b/c".to_owned(), QoS::ExactlyOnce), ], + pkid_tx: None, }; let mut buf = BytesMut::new(); diff --git a/rumqttc/src/mqttbytes/v4/unsubscribe.rs b/rumqttc/src/mqttbytes/v4/unsubscribe.rs index da34fbcc6..e4b1ed18a 100644 --- a/rumqttc/src/mqttbytes/v4/unsubscribe.rs +++ b/rumqttc/src/mqttbytes/v4/unsubscribe.rs @@ -1,18 +1,42 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Unsubscribe packet -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub struct Unsubscribe { pub pkid: u16, pub topics: Vec, + pub pkid_tx: Option>, +} + +// TODO: figure out if this is even required +impl Clone for Unsubscribe { + fn clone(&self) -> Self { + Self { + pkid: self.pkid, + topics: self.topics.clone(), + pkid_tx: None, + } + } } +impl PartialEq for Unsubscribe { + fn eq(&self, other: &Self) -> bool { + self.pkid == other.pkid && self.topics == other.topics + } +} + +impl Eq for Unsubscribe {} + impl Unsubscribe { pub fn new>(topic: S) -> Unsubscribe { Unsubscribe { pkid: 0, topics: vec![topic.into()], + pkid_tx: None, } } @@ -42,7 +66,11 @@ impl Unsubscribe { topics.push(topic_filter); } - let unsubscribe = Unsubscribe { pkid, topics }; + let unsubscribe = Unsubscribe { + pkid, + topics, + pkid_tx: None, + }; Ok(unsubscribe) } @@ -58,4 +86,8 @@ impl Unsubscribe { } Ok(1 + remaining_len_bytes + remaining_len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index f7cb34841..99c760913 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -814,6 +814,7 @@ mod test { topic: "test".to_string(), pkid: 1, payload: "".into(), + pkid_tx: None, }), Some(Publish { dup: false, @@ -822,6 +823,7 @@ mod test { topic: "test".to_string(), pkid: 2, payload: "".into(), + pkid_tx: None, }), Some(Publish { dup: false, @@ -830,6 +832,7 @@ mod test { topic: "test".to_string(), pkid: 3, payload: "".into(), + pkid_tx: None, }), None, None, @@ -840,6 +843,7 @@ mod test { topic: "test".to_string(), pkid: 6, payload: "".into(), + pkid_tx: None, }), ] } From 64b2d6758f162e174b38277a573900b0de56f27b Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 22 Feb 2024 16:10:44 +0000 Subject: [PATCH 05/20] fulfill promise on handling packets --- rumqttc/src/state.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 99c760913..09ebfc17f 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -309,12 +309,19 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + // NOTE: pkid promise need not be fulfilled for QoS 0, + // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = publish.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } + if self .outgoing_pub .get(publish.pkid as usize) @@ -416,6 +423,10 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = subscription.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", @@ -435,6 +446,11 @@ impl MqttState { let pkid = self.next_pkid(); unsub.pkid = pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = unsub.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } + debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", unsub.topics, unsub.pkid From ade80152b663ee9a5a916faae64ef0910062dd46 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 22 Feb 2024 16:24:03 +0000 Subject: [PATCH 06/20] handle QoS 0 --- rumqttc/src/client.rs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 3d68cece7..23cacc8e7 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -84,7 +84,12 @@ impl AsyncClient { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - publish.place_pkid_tx(pkid_tx); + // Fulfill instantly for QoS 0 + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } let publish = Request::Publish(publish); if !valid_topic(&topic) { @@ -111,7 +116,11 @@ impl AsyncClient { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - publish.place_pkid_tx(pkid_tx); + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } let publish = Request::Publish(publish); if !valid_topic(&topic) { @@ -155,7 +164,11 @@ impl AsyncClient { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - publish.place_pkid_tx(pkid_tx); + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } let publish = Request::Publish(publish); self.request_tx.send_async(publish).await?; @@ -348,7 +361,11 @@ impl Client { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - publish.place_pkid_tx(pkid_tx); + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } let publish = Request::Publish(publish); if !valid_topic(&topic) { From ff824f7078255b111bc3c078b9d4d337a5678a9a Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 22 Feb 2024 16:46:18 +0000 Subject: [PATCH 07/20] doc: example usecase --- rumqttc/examples/pkid_promise.rs | 61 ++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 rumqttc/examples/pkid_promise.rs diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs new file mode 100644 index 000000000..de46e9b43 --- /dev/null +++ b/rumqttc/examples/pkid_promise.rs @@ -0,0 +1,61 @@ +use tokio::{ + task::{self, JoinSet}, + time, +}; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + }); + + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + return Ok(()); + } + } + } +} + +async fn requests(client: AsyncClient) { + let mut joins = JoinSet::new(); + joins.spawn( + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(), + ); + + for i in 1..=10 { + joins.spawn( + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) + .await + .unwrap(), + ); + + time::sleep(Duration::from_secs(1)).await; + } + + // TODO: maybe rewrite to showcase in-between resolutions? + while let Some(Ok(Ok(pkid))) = joins.join_next().await { + println!("Pkid: {:?}", pkid); + } +} From 59fafcdda102bb41a204089f75f00a936cbc432e Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 01:32:19 +0000 Subject: [PATCH 08/20] ci: fix missing dep feature --- rumqttc/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 2e2bae14e..943036d63 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -24,7 +24,7 @@ proxy = ["dep:async-http-proxy"] [dependencies] futures-util = { version = "0.3", default-features = false, features = ["std", "sink"] } -tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } +tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time", "sync"] } tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" log = "0.4" From 50df705f4a390b6cb3be8362e610056b6636e210 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 01:47:32 +0000 Subject: [PATCH 09/20] make example fun --- Cargo.lock | 1 + rumqttc/Cargo.toml | 1 + rumqttc/examples/pkid_promise.rs | 34 ++++++++++++++++++++------------ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 59f8430d1..1b7f50657 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2559,6 +2559,7 @@ dependencies = [ "futures-core", "futures-sink", "pin-project-lite", + "slab", "tokio", "tracing", ] diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 943036d63..f0585aae8 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -58,6 +58,7 @@ matches = "0.1" pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +tokio-util = { version = "0.7", features = ["time"] } [[example]] name = "tls" diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs index de46e9b43..3042c03a1 100644 --- a/rumqttc/examples/pkid_promise.rs +++ b/rumqttc/examples/pkid_promise.rs @@ -1,7 +1,9 @@ +use futures_util::stream::StreamExt; use tokio::{ + select, task::{self, JoinSet}, - time, }; +use tokio_util::time::DelayQueue; use rumqttc::{AsyncClient, MqttOptions, QoS}; use std::error::Error; @@ -12,7 +14,7 @@ async fn main() -> Result<(), Box> { pretty_env_logger::init(); // color_backtrace::install(); - let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); mqttoptions.set_keep_alive(Duration::from_secs(5)); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); @@ -43,19 +45,25 @@ async fn requests(client: AsyncClient) { .unwrap(), ); + let mut queue = DelayQueue::new(); for i in 1..=10 { - joins.spawn( - client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) - .await - .unwrap(), - ); - - time::sleep(Duration::from_secs(1)).await; + queue.insert(i as usize, Duration::from_secs(i)); } - // TODO: maybe rewrite to showcase in-between resolutions? - while let Some(Ok(Ok(pkid))) = joins.join_next().await { - println!("Pkid: {:?}", pkid); + loop { + select! { + Some(i) = queue.next() => { + joins.spawn( + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) + .await + .unwrap(), + ); + } + Some(Ok(Ok(pkid))) = joins.join_next() => { + println!("Pkid: {:?}", pkid); + } + else => break, + } } } From 9d7a1a98009f21e84f2ab29ecd84c74c3dfb9a88 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 02:21:07 +0000 Subject: [PATCH 10/20] implement feature for v5 --- rumqttc/src/v5/client.rs | 229 ++++++++++++++------- rumqttc/src/v5/mqttbytes/v5/publish.rs | 43 +++- rumqttc/src/v5/mqttbytes/v5/subscribe.rs | 35 +++- rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs | 33 ++- rumqttc/src/v5/state.rs | 15 ++ 5 files changed, 279 insertions(+), 76 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..79bf4af2c 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,7 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::{valid_filter, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::valid_topic; +use crate::{valid_topic, PkidPromise}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -78,7 +78,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -86,12 +86,21 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + // Fulfill instantly for QoS 0 + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(()) + Ok(pkid_rx) } pub async fn publish_with_properties( @@ -101,7 +110,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -116,7 +125,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -132,7 +141,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -140,12 +149,21 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + // Fulfill instantly for QoS 0 + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(()) + Ok(pkid_rx) } pub fn try_publish_with_properties( @@ -155,7 +173,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -169,7 +187,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -204,19 +222,28 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + // Fulfill instantly for QoS 0 + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; - Ok(()) + Ok(pkid_rx) } pub async fn publish_bytes_with_properties( @@ -226,7 +253,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -240,7 +267,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -254,16 +281,21 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let subscribe = Subscribe::new(filter, properties); + let mut subscribe = Subscribe::new(filter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + + subscribe.place_pkid_tx(pkid_tx); + let request: Request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } pub async fn subscribe_with_properties>( @@ -271,11 +303,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)).await } - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None).await } @@ -285,16 +321,20 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let subscribe = Subscribe::new(filter, properties); + let mut subscribe = Subscribe::new(filter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn try_subscribe_with_properties>( @@ -302,11 +342,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_subscribe(topic, qos, Some(properties)) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_try_subscribe(topic, qos, None) } @@ -315,34 +359,38 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let mut subscribe = Subscribe::new_many(topics_iter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -354,33 +402,37 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let mut subscribe = Subscribe::new_many(topics_iter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -392,22 +444,26 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + ) -> Result { + let mut unsubscribe = Unsubscribe::new(topic, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.request_tx.send_async(request).await?; - Ok(()) + Ok(pkid_rx) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None).await } @@ -416,22 +472,26 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + ) -> Result { + let mut unsubscribe = Unsubscribe::new(topic, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.request_tx.try_send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.handle_try_unsubscribe(topic, None) } @@ -509,7 +569,7 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -517,12 +577,21 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + // Fulfill instantly for QoS 0 + if qos == QoS::AtMostOnce { + _ = pkid_tx.send(0); + } else { + publish.place_pkid_tx(pkid_tx); + } + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(()) + Ok(pkid_rx) } pub fn publish_with_properties( @@ -532,7 +601,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -546,7 +615,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -561,7 +630,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -576,7 +645,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -606,16 +675,20 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let subscribe = Subscribe::new(filter, properties); + let mut subscribe = Subscribe::new(filter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn subscribe_with_properties>( @@ -623,11 +696,15 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)) } - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None) } @@ -637,12 +714,16 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_subscribe_with_properties(topic, qos, properties) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.client.try_subscribe(topic, qos) } @@ -651,33 +732,37 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); + let mut subscribe = Subscribe::new_many(topics_iter, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + subscribe.place_pkid_tx(pkid_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -688,7 +773,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { @@ -696,7 +781,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -708,22 +793,26 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic, properties); + ) -> Result { + let mut unsubscribe = Unsubscribe::new(topic, properties); + + let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + unsubscribe.place_pkid_tx(pkid_tx); + let request = Request::Unsubscribe(unsubscribe); self.client.request_tx.send(request)?; - Ok(()) + Ok(pkid_rx) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None) } @@ -732,12 +821,12 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/v5/mqttbytes/v5/publish.rs b/rumqttc/src/v5/mqttbytes/v5/publish.rs index 74fbee225..c787e6004 100644 --- a/rumqttc/src/v5/mqttbytes/v5/publish.rs +++ b/rumqttc/src/v5/mqttbytes/v5/publish.rs @@ -1,8 +1,11 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Publish packet -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Debug, Default)] pub struct Publish { pub dup: bool, pub qos: QoS, @@ -11,8 +14,39 @@ pub struct Publish { pub pkid: u16, pub payload: Bytes, pub properties: Option, + pub pkid_tx: Option>, +} + +// TODO: figure out if this is even required +impl Clone for Publish { + fn clone(&self) -> Self { + Self { + dup: self.dup, + qos: self.qos, + retain: self.retain, + topic: self.topic.clone(), + payload: self.payload.clone(), + pkid: self.pkid, + properties: self.properties.clone(), + pkid_tx: None, + } + } +} + +impl PartialEq for Publish { + fn eq(&self, other: &Self) -> bool { + self.dup == other.dup + && self.qos == other.qos + && self.retain == other.retain + && self.topic == other.topic + && self.payload == other.payload + && self.pkid == other.pkid + && self.properties == other.properties + } } +impl Eq for Publish {} + impl Publish { pub fn new, P: Into>( topic: T, @@ -85,6 +119,7 @@ impl Publish { topic, payload: bytes, properties, + pkid_tx: None, }; Ok(publish) @@ -120,6 +155,10 @@ impl Publish { Ok(1 + count + len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } #[derive(Debug, Clone, PartialEq, Eq, Default)] diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 4167cd671..28796fe30 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -1,14 +1,40 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Subscription packet -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Debug, Default)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, pub properties: Option, + pub pkid_tx: Option>, } +// TODO: figure out if this is even required +impl Clone for Subscribe { + fn clone(&self) -> Self { + Self { + pkid: self.pkid, + filters: self.filters.clone(), + properties: self.properties.clone(), + pkid_tx: None, + } + } +} + +impl PartialEq for Subscribe { + fn eq(&self, other: &Self) -> bool { + self.pkid == other.pkid + && self.filters == other.filters + && self.properties == other.properties + } +} + +impl Eq for Subscribe {} + impl Subscribe { pub fn new(filter: Filter, properties: Option) -> Self { Self { @@ -67,6 +93,7 @@ impl Subscribe { pkid, filters, properties, + pkid_tx: None, }), } } @@ -95,6 +122,10 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } /// Subscription filter diff --git a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs index 2b671ce39..92d1d5610 100644 --- a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs @@ -1,14 +1,38 @@ -use super::*; use bytes::{Buf, Bytes}; +use tokio::sync::oneshot::Sender; + +use super::*; +use crate::Pkid; /// Unsubscribe packet -#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[derive(Debug, Default)] pub struct Unsubscribe { pub pkid: u16, pub filters: Vec, pub properties: Option, + pub pkid_tx: Option>, +} + +// TODO: figure out if this is even required +impl Clone for Unsubscribe { + fn clone(&self) -> Self { + Self { + pkid: self.pkid, + filters: self.filters.clone(), + properties: self.properties.clone(), + pkid_tx: None, + } + } +} + +impl PartialEq for Unsubscribe { + fn eq(&self, other: &Self) -> bool { + self.pkid == other.pkid && self.filters == other.filters + } } +impl Eq for Unsubscribe {} + impl Unsubscribe { pub fn new>(filter: S, properties: Option) -> Self { Self { @@ -59,6 +83,7 @@ impl Unsubscribe { pkid, filters, properties, + pkid_tx: None, }; Ok(unsubscribe) } @@ -86,6 +111,10 @@ impl Unsubscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { + self.pkid_tx = Some(pkid_tx) + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 6f37a4719..f94100567 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -474,12 +474,19 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + // NOTE: pkid promise need not be fulfilled for QoS 0, + // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = publish.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } + if self .outgoing_pub .get(publish.pkid as usize) @@ -596,6 +603,10 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = subscription.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", @@ -615,6 +626,10 @@ impl MqttState { ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; + // Fulfill the pkid promise + if let Some(pkid_tx) = unsub.pkid_tx.take() { + _ = pkid_tx.send(pkid); + } debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", From 48af2345391a4059415b2dfb0d62dd53b456bfda Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 02:24:44 +0000 Subject: [PATCH 11/20] doc: v5 example --- rumqttc/examples/pkid_promise_v5.rs | 69 +++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 rumqttc/examples/pkid_promise_v5.rs diff --git a/rumqttc/examples/pkid_promise_v5.rs b/rumqttc/examples/pkid_promise_v5.rs new file mode 100644 index 000000000..95e64a7d1 --- /dev/null +++ b/rumqttc/examples/pkid_promise_v5.rs @@ -0,0 +1,69 @@ +use futures_util::stream::StreamExt; +use tokio::{ + select, + task::{self, JoinSet}, +}; +use tokio_util::time::DelayQueue; + +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + }); + + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + return Ok(()); + } + } + } +} + +async fn requests(client: AsyncClient) { + let mut joins = JoinSet::new(); + joins.spawn( + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(), + ); + + let mut queue = DelayQueue::new(); + for i in 1..=10 { + queue.insert(i as usize, Duration::from_secs(i)); + } + + loop { + select! { + Some(i) = queue.next() => { + joins.spawn( + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) + .await + .unwrap(), + ); + } + Some(Ok(Ok(pkid))) = joins.join_next() => { + println!("Pkid: {:?}", pkid); + } + else => break, + } + } +} From f6c2f550c50ce35dcd91a03cfbb3dda584731547 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 02:28:19 +0000 Subject: [PATCH 12/20] add entry into changelog --- rumqttc/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 34f5438cd..c36fc45bd 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -65,6 +65,7 @@ To update your code simply remove `Key::ECC()` or `Key::RSA()` from the initiali `rusttls-pemfile` to `2.0.0`, `async-tungstenite` to `0.24.0`, `ws_stream_tungstenite` to `0.12.0` and `http` to `1.0.0`. This is a breaking change as types from some of these crates are part of the public API. +- `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return a `PkidPromise` which resolves into the identifier value chosen by the `EventLoop` when handling the packet. ### Deprecated From 3facf7f918037092425c63a41c04bb23fb51b4fc Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 23 Feb 2024 22:59:40 +0530 Subject: [PATCH 13/20] Abstract true nature of `PkidPromise` --- rumqttc/examples/pkid_promise.rs | 7 ++++--- rumqttc/examples/pkid_promise_v5.rs | 7 ++++--- rumqttc/src/client.rs | 26 +++++++++++++------------- rumqttc/src/lib.rs | 24 +++++++++++++++++++++++- rumqttc/src/v5/client.rs | 26 +++++++++++++------------- 5 files changed, 57 insertions(+), 33 deletions(-) diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs index 3042c03a1..caac8d122 100644 --- a/rumqttc/examples/pkid_promise.rs +++ b/rumqttc/examples/pkid_promise.rs @@ -42,7 +42,8 @@ async fn requests(client: AsyncClient) { client .subscribe("hello/world", QoS::AtMostOnce) .await - .unwrap(), + .unwrap() + .wait_async(), ); let mut queue = DelayQueue::new(); @@ -57,10 +58,10 @@ async fn requests(client: AsyncClient) { client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) .await - .unwrap(), + .unwrap().wait_async(), ); } - Some(Ok(Ok(pkid))) = joins.join_next() => { + Some(Ok(Some(pkid))) = joins.join_next() => { println!("Pkid: {:?}", pkid); } else => break, diff --git a/rumqttc/examples/pkid_promise_v5.rs b/rumqttc/examples/pkid_promise_v5.rs index 95e64a7d1..c8cbd6422 100644 --- a/rumqttc/examples/pkid_promise_v5.rs +++ b/rumqttc/examples/pkid_promise_v5.rs @@ -42,7 +42,8 @@ async fn requests(client: AsyncClient) { client .subscribe("hello/world", QoS::AtMostOnce) .await - .unwrap(), + .unwrap() + .wait_async(), ); let mut queue = DelayQueue::new(); @@ -57,10 +58,10 @@ async fn requests(client: AsyncClient) { client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) .await - .unwrap(), + .unwrap().wait_async(), ); } - Some(Ok(Ok(pkid))) = joins.join_next() => { + Some(Ok(Some(pkid))) = joins.join_next() => { println!("Pkid: {:?}", pkid); } else => break, diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 23cacc8e7..2bb1cd6bc 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -96,7 +96,7 @@ impl AsyncClient { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -127,7 +127,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -172,7 +172,7 @@ impl AsyncClient { let publish = Request::Publish(publish); self.request_tx.send_async(publish).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -192,7 +192,7 @@ impl AsyncClient { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Attempts to send a MQTT Subscribe to the `EventLoop` @@ -212,7 +212,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` @@ -232,7 +232,7 @@ impl AsyncClient { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` @@ -252,7 +252,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT Unsubscribe to the `EventLoop` @@ -264,7 +264,7 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` @@ -276,7 +276,7 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT disconnect to the `EventLoop` @@ -372,7 +372,7 @@ impl Client { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_publish( @@ -422,7 +422,7 @@ impl Client { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -451,7 +451,7 @@ impl Client { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_subscribe_many(&self, topics: T) -> Result @@ -470,7 +470,7 @@ impl Client { let request = Request::Unsubscribe(unsubscribe); self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } /// Sends a MQTT Unsubscribe to the `EventLoop` diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index afefa933a..efb628abe 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -145,6 +145,7 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; +use tokio::sync::oneshot; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -160,7 +161,28 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub type Pkid = u16; -pub type PkidPromise = tokio::sync::oneshot::Receiver; + +/// A token through which the user can access the assigned packet id of an associated request, once it +/// is processed by the [`EventLoop`]. +pub struct PkidPromise { + inner: oneshot::Receiver, +} + +impl PkidPromise { + pub fn new(inner: oneshot::Receiver) -> Self { + Self { inner } + } + + /// Wait for the pkid to resolve by blocking the current thread + pub fn wait(self) -> Option { + self.inner.blocking_recv().ok() + } + + /// Await pkid resolution without blocking the current thread + pub async fn wait_async(self) -> Option { + self.inner.await.ok() + } +} /// Current outgoing activity on the eventloop #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 79bf4af2c..67c095e6e 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -100,7 +100,7 @@ impl AsyncClient { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub async fn publish_with_properties( @@ -163,7 +163,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_publish_with_properties( @@ -243,7 +243,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub async fn publish_bytes_with_properties( @@ -295,7 +295,7 @@ impl AsyncClient { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub async fn subscribe_with_properties>( @@ -334,7 +334,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_subscribe_with_properties>( @@ -376,7 +376,7 @@ impl AsyncClient { } self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub async fn subscribe_many_with_properties( @@ -418,7 +418,7 @@ impl AsyncClient { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_subscribe_many_with_properties( @@ -452,7 +452,7 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); self.request_tx.send_async(request).await?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub async fn unsubscribe_with_properties>( @@ -480,7 +480,7 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); self.request_tx.try_send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn try_unsubscribe_with_properties>( @@ -591,7 +591,7 @@ impl Client { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn publish_with_properties( @@ -688,7 +688,7 @@ impl Client { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn subscribe_with_properties>( @@ -748,7 +748,7 @@ impl Client { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn subscribe_many_with_properties( @@ -801,7 +801,7 @@ impl Client { let request = Request::Unsubscribe(unsubscribe); self.client.request_tx.send(request)?; - Ok(pkid_rx) + Ok(PkidPromise::new(pkid_rx)) } pub fn unsubscribe_with_properties>( From 02a5a886f8dbf39b8fe4b6f212ad309cc7bf15a5 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 24 Feb 2024 17:44:06 +0530 Subject: [PATCH 14/20] move tx out of mqttbytes --- rumqttc/src/client.rs | 84 +++++++++++----------- rumqttc/src/lib.rs | 74 +++++++++++++++++-- rumqttc/src/mqttbytes/v4/publish.rs | 46 +----------- rumqttc/src/mqttbytes/v4/subscribe.rs | 45 ++---------- rumqttc/src/mqttbytes/v4/unsubscribe.rs | 38 +--------- rumqttc/src/state.rs | 59 +++++++-------- rumqttc/src/v5/client.rs | 82 ++++++++++----------- rumqttc/src/v5/mod.rs | 52 ++++++++++++-- rumqttc/src/v5/mqttbytes/v5/publish.rs | 43 +---------- rumqttc/src/v5/mqttbytes/v5/subscribe.rs | 35 +-------- rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs | 30 +------- rumqttc/src/v5/state.rs | 57 ++++++++------- 12 files changed, 271 insertions(+), 374 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 2bb1cd6bc..be90e3870 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -85,13 +85,14 @@ impl AsyncClient { let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - if qos == QoS::AtMostOnce { + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } @@ -116,13 +117,15 @@ impl AsyncClient { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - if qos == QoS::AtMostOnce { + // Fulfill instantly for QoS 0 + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } @@ -164,13 +167,15 @@ impl AsyncClient { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - if qos == QoS::AtMostOnce { + // Fulfill instantly for QoS 0 + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); self.request_tx.send_async(publish).await?; Ok(PkidPromise::new(pkid_rx)) } @@ -182,12 +187,11 @@ impl AsyncClient { qos: QoS, ) -> Result { let topic = topic.into(); - let mut subscribe = Subscribe::new(&topic, qos); + let subscribe = Subscribe::new(&topic, qos); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } @@ -202,12 +206,11 @@ impl AsyncClient { qos: QoS, ) -> Result { let topic = topic.into(); - let mut subscribe = Subscribe::new(&topic, qos); + let subscribe = Subscribe::new(&topic, qos); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::TryRequest(request)); } @@ -222,12 +225,11 @@ impl AsyncClient { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics_iter); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } @@ -242,12 +244,11 @@ impl AsyncClient { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics_iter); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } @@ -257,24 +258,22 @@ impl AsyncClient { /// Sends a MQTT Unsubscribe to the `EventLoop` pub async fn unsubscribe>(&self, topic: S) -> Result { - let mut unsubscribe = Unsubscribe::new(topic.into()); + let unsubscribe = Unsubscribe::new(topic.into()); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.request_tx.send_async(request).await?; Ok(PkidPromise::new(pkid_rx)) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` pub fn try_unsubscribe>(&self, topic: S) -> Result { - let mut unsubscribe = Unsubscribe::new(topic.into()); + let unsubscribe = Unsubscribe::new(topic.into()); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.request_tx.try_send(request)?; Ok(PkidPromise::new(pkid_rx)) } @@ -361,13 +360,15 @@ impl Client { publish.retain = retain; let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - if qos == QoS::AtMostOnce { + // Fulfill instantly for QoS 0 + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } @@ -412,12 +413,11 @@ impl Client { qos: QoS, ) -> Result { let topic = topic.into(); - let mut subscribe = Subscribe::new(&topic, qos); + let subscribe = Subscribe::new(&topic, qos); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } @@ -441,12 +441,11 @@ impl Client { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter); + let subscribe = Subscribe::new_many(topics_iter); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } @@ -463,12 +462,11 @@ impl Client { /// Sends a MQTT Unsubscribe to the `EventLoop` pub fn unsubscribe>(&self, topic: S) -> Result { - let mut unsubscribe = Unsubscribe::new(topic.into()); + let unsubscribe = Unsubscribe::new(topic.into()); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.client.request_tx.send(request)?; Ok(PkidPromise::new(pkid_rx)) } diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index efb628abe..e50d2e7cb 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -213,37 +213,97 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), + Publish(Option>, Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), PubRel(PubRel), PingReq(PingReq), PingResp(PingResp), - Subscribe(Subscribe), + Subscribe(Option>, Subscribe), SubAck(SubAck), - Unsubscribe(Unsubscribe), + Unsubscribe(Option>, Unsubscribe), UnsubAck(UnsubAck), Disconnect(Disconnect), } +impl Clone for Request { + fn clone(&self) -> Self { + match self { + Self::Publish(_, p) => Self::Publish(None, p.clone()), + Self::PubAck(p) => Self::PubAck(p.clone()), + Self::PubRec(p) => Self::PubRec(p.clone()), + Self::PubRel(p) => Self::PubRel(p.clone()), + Self::PubComp(p) => Self::PubComp(p.clone()), + Self::Subscribe(_, p) => Self::Subscribe(None, p.clone()), + Self::SubAck(p) => Self::SubAck(p.clone()), + Self::PingReq(p) => Self::PingReq(p.clone()), + Self::PingResp(p) => Self::PingResp(p.clone()), + Self::Disconnect(p) => Self::Disconnect(p.clone()), + Self::Unsubscribe(_, p) => Self::Unsubscribe(None, p.clone()), + Self::UnsubAck(p) => Self::UnsubAck(p.clone()), + } + } +} + +impl PartialEq for Request { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Publish(_, p1), Self::Publish(_, p2)) => p1 == p2, + (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, + (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, + (Self::PubRel(p1), Self::PubRel(p2)) => p1 == p2, + (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, + (Self::Subscribe(_, p1), Self::Subscribe(_, p2)) => p1 == p2, + (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, + (Self::PingReq(p1), Self::PingReq(p2)) => p1 == p2, + (Self::PingResp(p1), Self::PingResp(p2)) => p1 == p2, + (Self::Unsubscribe(_, p1), Self::Unsubscribe(_, p2)) => p1 == p2, + (Self::UnsubAck(p1), Self::UnsubAck(p2)) => p1 == p2, + (Self::Disconnect(p1), Self::Disconnect(p2)) => p1 == p2, + _ => false, + } + } +} + +impl Eq for Request {} + +impl Request { + fn size(&self) -> usize { + match &self { + Request::Publish(_, publish) => publish.size(), + Request::PubAck(puback) => puback.size(), + Request::PubRec(pubrec) => pubrec.size(), + Request::PubComp(pubcomp) => pubcomp.size(), + Request::PubRel(pubrel) => pubrel.size(), + Request::PingReq(pingreq) => pingreq.size(), + Request::PingResp(pingresp) => pingresp.size(), + Request::Subscribe(_, subscribe) => subscribe.size(), + Request::SubAck(suback) => suback.size(), + Request::Unsubscribe(_, unsubscribe) => unsubscribe.size(), + Request::UnsubAck(unsuback) => unsuback.size(), + Request::Disconnect(disconn) => disconn.size(), + } + } +} + impl From for Request { fn from(publish: Publish) -> Request { - Request::Publish(publish) + Request::Publish(None, publish) } } impl From for Request { fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) + Request::Subscribe(None, subscribe) } } impl From for Request { fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) + Request::Unsubscribe(None, unsubscribe) } } diff --git a/rumqttc/src/mqttbytes/v4/publish.rs b/rumqttc/src/mqttbytes/v4/publish.rs index 286c8fe3f..36530b957 100644 --- a/rumqttc/src/mqttbytes/v4/publish.rs +++ b/rumqttc/src/mqttbytes/v4/publish.rs @@ -1,10 +1,8 @@ -use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; - use super::*; -use crate::Pkid; +use bytes::{Buf, Bytes}; /// Publish packet +#[derive(Clone, Eq, PartialEq)] pub struct Publish { pub dup: bool, pub qos: QoS, @@ -12,37 +10,8 @@ pub struct Publish { pub topic: String, pub pkid: u16, pub payload: Bytes, - pub pkid_tx: Option>, -} - -// TODO: figure out if this is even required -impl Clone for Publish { - fn clone(&self) -> Self { - Self { - dup: self.dup, - qos: self.qos, - retain: self.retain, - topic: self.topic.clone(), - payload: self.payload.clone(), - pkid: self.pkid, - pkid_tx: None, - } - } } -impl PartialEq for Publish { - fn eq(&self, other: &Self) -> bool { - self.dup == other.dup - && self.qos == other.qos - && self.retain == other.retain - && self.topic == other.topic - && self.payload == other.payload - && self.pkid == other.pkid - } -} - -impl Eq for Publish {} - impl Publish { pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { Publish { @@ -52,7 +21,6 @@ impl Publish { pkid: 0, topic: topic.into(), payload: Bytes::from(payload.into()), - pkid_tx: None, } } @@ -64,7 +32,6 @@ impl Publish { pkid: 0, topic: topic.into(), payload, - pkid_tx: None, } } @@ -110,7 +77,6 @@ impl Publish { pkid, topic, payload: bytes, - pkid_tx: None, }; Ok(publish) @@ -140,10 +106,6 @@ impl Publish { Ok(1 + count + len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } impl fmt::Debug for Publish { @@ -203,7 +165,6 @@ mod test { topic: "a/b".to_owned(), pkid: 10, payload: Bytes::from(&payload[..]), - pkid_tx: None, } ); } @@ -240,7 +201,6 @@ mod test { topic: "a/b".to_owned(), pkid: 0, payload: Bytes::from(&[0x01, 0x02][..]), - pkid_tx: None, } ); } @@ -254,7 +214,6 @@ mod test { topic: "a/b".to_owned(), pkid: 10, payload: Bytes::from(vec![0xF1, 0xF2, 0xF3, 0xF4]), - pkid_tx: None, }; let mut buf = BytesMut::new(); @@ -289,7 +248,6 @@ mod test { topic: "a/b".to_owned(), pkid: 0, payload: Bytes::from(vec![0xE1, 0xE2, 0xE3, 0xE4]), - pkid_tx: None, }; let mut buf = BytesMut::new(); diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 407100d3e..8bc6497a1 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -1,35 +1,13 @@ -use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; - use super::*; -use crate::Pkid; +use bytes::{Buf, Bytes}; /// Subscription packet +#[derive(Clone, Eq, PartialEq)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, - pub pkid_tx: Option>, -} - -// TODO: figure out if this is even required -impl Clone for Subscribe { - fn clone(&self) -> Self { - Self { - pkid: self.pkid, - filters: self.filters.clone(), - pkid_tx: None, - } - } -} - -impl PartialEq for Subscribe { - fn eq(&self, other: &Self) -> bool { - self.pkid == other.pkid && self.filters == other.filters - } } -impl Eq for Subscribe {} - impl Subscribe { pub fn new>(path: S, qos: QoS) -> Subscribe { let filter = SubscribeFilter { @@ -40,7 +18,6 @@ impl Subscribe { Subscribe { pkid: 0, filters: vec![filter], - pkid_tx: None, } } @@ -50,11 +27,7 @@ impl Subscribe { { let filters: Vec = topics.into_iter().collect(); - Subscribe { - pkid: 0, - filters, - pkid_tx: None, - } + Subscribe { pkid: 0, filters } } pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { @@ -98,11 +71,7 @@ impl Subscribe { match filters.len() { 0 => Err(Error::EmptySubscription), - _ => Ok(Subscribe { - pkid, - filters, - pkid_tx: None, - }), + _ => Ok(Subscribe { pkid, filters }), } } @@ -124,10 +93,6 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } /// Subscription filter @@ -229,7 +194,6 @@ mod test { SubscribeFilter::new("#".to_owned(), QoS::AtLeastOnce), SubscribeFilter::new("a/b/c".to_owned(), QoS::ExactlyOnce) ], - pkid_tx: None, } ); } @@ -243,7 +207,6 @@ mod test { SubscribeFilter::new("#".to_owned(), QoS::AtLeastOnce), SubscribeFilter::new("a/b/c".to_owned(), QoS::ExactlyOnce), ], - pkid_tx: None, }; let mut buf = BytesMut::new(); diff --git a/rumqttc/src/mqttbytes/v4/unsubscribe.rs b/rumqttc/src/mqttbytes/v4/unsubscribe.rs index e4b1ed18a..8263752e3 100644 --- a/rumqttc/src/mqttbytes/v4/unsubscribe.rs +++ b/rumqttc/src/mqttbytes/v4/unsubscribe.rs @@ -1,42 +1,18 @@ -use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; - use super::*; -use crate::Pkid; +use bytes::{Buf, Bytes}; /// Unsubscribe packet -#[derive(Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Unsubscribe { pub pkid: u16, pub topics: Vec, - pub pkid_tx: Option>, -} - -// TODO: figure out if this is even required -impl Clone for Unsubscribe { - fn clone(&self) -> Self { - Self { - pkid: self.pkid, - topics: self.topics.clone(), - pkid_tx: None, - } - } } -impl PartialEq for Unsubscribe { - fn eq(&self, other: &Self) -> bool { - self.pkid == other.pkid && self.topics == other.topics - } -} - -impl Eq for Unsubscribe {} - impl Unsubscribe { pub fn new>(topic: S) -> Unsubscribe { Unsubscribe { pkid: 0, topics: vec![topic.into()], - pkid_tx: None, } } @@ -66,11 +42,7 @@ impl Unsubscribe { topics.push(topic_filter); } - let unsubscribe = Unsubscribe { - pkid, - topics, - pkid_tx: None, - }; + let unsubscribe = Unsubscribe { pkid, topics }; Ok(unsubscribe) } @@ -86,8 +58,4 @@ impl Unsubscribe { } Ok(1 + remaining_len_bytes + remaining_len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 09ebfc17f..238f98a60 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,10 +1,11 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::{Event, Incoming, Outgoing, Pkid, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; use std::collections::VecDeque; use std::{io, time::Instant}; +use tokio::sync::oneshot; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -108,7 +109,7 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { - let request = Request::Publish(publish); + let request = Request::Publish(None, publish); pending.push(request); } } @@ -140,10 +141,10 @@ impl MqttState { request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq(_) => self.outgoing_ping()?, Request::Disconnect(_) => self.outgoing_disconnect()?, Request::PubAck(puback) => self.outgoing_puback(puback)?, @@ -308,7 +309,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + pkid_tx: Option>, + ) -> Result, StateError> { // NOTE: pkid promise need not be fulfilled for QoS 0, // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { @@ -318,7 +323,7 @@ impl MqttState { let pkid = publish.pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = publish.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -416,6 +421,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + pkid_tx: Option>, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -424,7 +430,7 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = subscription.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -442,12 +448,13 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + pkid_tx: Option>, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = unsub.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -571,7 +578,7 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -579,12 +586,12 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -592,12 +599,12 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -683,8 +690,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + mqtt.outgoing_publish(publish1, None).unwrap(); + mqtt.outgoing_publish(publish2, None).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); @@ -716,8 +723,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let _publish_out = mqtt.outgoing_publish(publish1, None); + let _publish_out = mqtt.outgoing_publish(publish2, None); mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -735,7 +742,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); + let packet = mqtt.outgoing_publish(publish, None).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -777,7 +784,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); @@ -791,7 +798,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + mqtt.handle_outgoing_packet(Request::Publish(None, publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -830,7 +837,6 @@ mod test { topic: "test".to_string(), pkid: 1, payload: "".into(), - pkid_tx: None, }), Some(Publish { dup: false, @@ -839,7 +845,6 @@ mod test { topic: "test".to_string(), pkid: 2, payload: "".into(), - pkid_tx: None, }), Some(Publish { dup: false, @@ -848,7 +853,6 @@ mod test { topic: "test".to_string(), pkid: 3, payload: "".into(), - pkid_tx: None, }), None, None, @@ -859,7 +863,6 @@ mod test { topic: "test".to_string(), pkid: 6, payload: "".into(), - pkid_tx: None, }), ] } @@ -869,7 +872,7 @@ mod test { let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(_, publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -881,7 +884,7 @@ mod test { let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(_, publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -893,7 +896,7 @@ mod test { let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(_, publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 67c095e6e..6841ed0c4 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -89,13 +89,14 @@ impl AsyncClient { let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - if qos == QoS::AtMostOnce { + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } @@ -152,13 +153,14 @@ impl AsyncClient { let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - if qos == QoS::AtMostOnce { + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } @@ -232,13 +234,14 @@ impl AsyncClient { let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - if qos == QoS::AtMostOnce { + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } @@ -284,13 +287,11 @@ impl AsyncClient { ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let mut subscribe = Subscribe::new(filter, properties); + let subscribe = Subscribe::new(filter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - - let request: Request = Request::Subscribe(subscribe); + let request: Request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } @@ -324,12 +325,11 @@ impl AsyncClient { ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let mut subscribe = Subscribe::new(filter, properties); + let subscribe = Subscribe::new(filter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_filter_valid { return Err(ClientError::TryRequest(request)); } @@ -365,12 +365,11 @@ impl AsyncClient { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics_iter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } @@ -408,12 +407,11 @@ impl AsyncClient { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics_iter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } @@ -445,12 +443,11 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result { - let mut unsubscribe = Unsubscribe::new(topic, properties); + let unsubscribe = Unsubscribe::new(topic, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.request_tx.send_async(request).await?; Ok(PkidPromise::new(pkid_rx)) } @@ -473,12 +470,11 @@ impl AsyncClient { topic: S, properties: Option, ) -> Result { - let mut unsubscribe = Unsubscribe::new(topic, properties); + let unsubscribe = Unsubscribe::new(topic, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.request_tx.try_send(request)?; Ok(PkidPromise::new(pkid_rx)) } @@ -580,13 +576,14 @@ impl Client { let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - if qos == QoS::AtMostOnce { + let pkid_tx = if qos == QoS::AtMostOnce { _ = pkid_tx.send(0); + None } else { - publish.place_pkid_tx(pkid_tx); - } + Some(pkid_tx) + }; - let publish = Request::Publish(publish); + let publish = Request::Publish(pkid_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } @@ -678,12 +675,11 @@ impl Client { ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); - let mut subscribe = Subscribe::new(filter, properties); + let subscribe = Subscribe::new(filter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } @@ -738,12 +734,11 @@ impl Client { { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let mut subscribe = Subscribe::new_many(topics_iter, properties); + let subscribe = Subscribe::new_many(topics_iter, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - subscribe.place_pkid_tx(pkid_tx); - let request = Request::Subscribe(subscribe); + let request = Request::Subscribe(Some(pkid_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } @@ -794,12 +789,11 @@ impl Client { topic: S, properties: Option, ) -> Result { - let mut unsubscribe = Unsubscribe::new(topic, properties); + let unsubscribe = Unsubscribe::new(topic, properties); let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - unsubscribe.place_pkid_tx(pkid_tx); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); self.client.request_tx.send(request)?; Ok(PkidPromise::new(pkid_rx)) } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2518a93f1..d4dbf6b0b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -7,6 +7,7 @@ use std::{ pin::Pin, sync::Arc, }; +use tokio::sync::oneshot; mod client; mod eventloop; @@ -14,8 +15,8 @@ mod framed; pub mod mqttbytes; mod state; -use crate::Outgoing; use crate::{NetworkOptions, Transport}; +use crate::{Outgoing, Pkid}; use mqttbytes::v5::*; @@ -33,22 +34,63 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), + Publish(Option>, Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), PubRel(PubRel), PingReq, PingResp, - Subscribe(Subscribe), + Subscribe(Option>, Subscribe), SubAck(SubAck), - Unsubscribe(Unsubscribe), + Unsubscribe(Option>, Unsubscribe), UnsubAck(UnsubAck), Disconnect, } +impl Clone for Request { + fn clone(&self) -> Self { + match self { + Self::Publish(_, p) => Self::Publish(None, p.clone()), + Self::PubAck(p) => Self::PubAck(p.clone()), + Self::PubRec(p) => Self::PubRec(p.clone()), + Self::PubRel(p) => Self::PubRel(p.clone()), + Self::PubComp(p) => Self::PubComp(p.clone()), + Self::Subscribe(_, p) => Self::Subscribe(None, p.clone()), + Self::SubAck(p) => Self::SubAck(p.clone()), + Self::PingReq => Self::PingReq, + Self::PingResp => Self::PingResp, + Self::Disconnect => Self::Disconnect, + Self::Unsubscribe(_, p) => Self::Unsubscribe(None, p.clone()), + Self::UnsubAck(p) => Self::UnsubAck(p.clone()), + } + } +} + +impl PartialEq for Request { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Publish(_, p1), Self::Publish(_, p2)) => p1 == p2, + (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, + (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, + (Self::PubRel(p1), Self::PubRel(p2)) => p1 == p2, + (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, + (Self::Subscribe(_, p1), Self::Subscribe(_, p2)) => p1 == p2, + (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, + (Self::PingReq, Self::PingReq) + | (Self::PingResp, Self::PingResp) + | (Self::Disconnect, Self::Disconnect) => true, + (Self::Unsubscribe(_, p1), Self::Unsubscribe(_, p2)) => p1 == p2, + (Self::UnsubAck(p1), Self::UnsubAck(p2)) => p1 == p2, + _ => false, + } + } +} + +impl Eq for Request {} + #[cfg(feature = "websocket")] type RequestModifierFn = Arc< dyn Fn(http::Request<()>) -> Pin> + Send>> diff --git a/rumqttc/src/v5/mqttbytes/v5/publish.rs b/rumqttc/src/v5/mqttbytes/v5/publish.rs index c787e6004..6699e4b20 100644 --- a/rumqttc/src/v5/mqttbytes/v5/publish.rs +++ b/rumqttc/src/v5/mqttbytes/v5/publish.rs @@ -1,11 +1,8 @@ -use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; - use super::*; -use crate::Pkid; +use bytes::{Buf, Bytes}; /// Publish packet -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Publish { pub dup: bool, pub qos: QoS, @@ -14,39 +11,8 @@ pub struct Publish { pub pkid: u16, pub payload: Bytes, pub properties: Option, - pub pkid_tx: Option>, -} - -// TODO: figure out if this is even required -impl Clone for Publish { - fn clone(&self) -> Self { - Self { - dup: self.dup, - qos: self.qos, - retain: self.retain, - topic: self.topic.clone(), - payload: self.payload.clone(), - pkid: self.pkid, - properties: self.properties.clone(), - pkid_tx: None, - } - } -} - -impl PartialEq for Publish { - fn eq(&self, other: &Self) -> bool { - self.dup == other.dup - && self.qos == other.qos - && self.retain == other.retain - && self.topic == other.topic - && self.payload == other.payload - && self.pkid == other.pkid - && self.properties == other.properties - } } -impl Eq for Publish {} - impl Publish { pub fn new, P: Into>( topic: T, @@ -119,7 +85,6 @@ impl Publish { topic, payload: bytes, properties, - pkid_tx: None, }; Ok(publish) @@ -155,10 +120,6 @@ impl Publish { Ok(1 + count + len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } #[derive(Debug, Clone, PartialEq, Eq, Default)] diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 28796fe30..8b5667b6c 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -1,40 +1,14 @@ -use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; - use super::*; -use crate::Pkid; +use bytes::{Buf, Bytes}; /// Subscription packet -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, pub properties: Option, - pub pkid_tx: Option>, -} - -// TODO: figure out if this is even required -impl Clone for Subscribe { - fn clone(&self) -> Self { - Self { - pkid: self.pkid, - filters: self.filters.clone(), - properties: self.properties.clone(), - pkid_tx: None, - } - } -} - -impl PartialEq for Subscribe { - fn eq(&self, other: &Self) -> bool { - self.pkid == other.pkid - && self.filters == other.filters - && self.properties == other.properties - } } -impl Eq for Subscribe {} - impl Subscribe { pub fn new(filter: Filter, properties: Option) -> Self { Self { @@ -93,7 +67,6 @@ impl Subscribe { pkid, filters, properties, - pkid_tx: None, }), } } @@ -122,10 +95,6 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } /// Subscription filter diff --git a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs index 92d1d5610..5e357b96e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs @@ -1,38 +1,15 @@ use bytes::{Buf, Bytes}; -use tokio::sync::oneshot::Sender; use super::*; -use crate::Pkid; /// Unsubscribe packet -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Unsubscribe { pub pkid: u16, pub filters: Vec, pub properties: Option, - pub pkid_tx: Option>, } -// TODO: figure out if this is even required -impl Clone for Unsubscribe { - fn clone(&self) -> Self { - Self { - pkid: self.pkid, - filters: self.filters.clone(), - properties: self.properties.clone(), - pkid_tx: None, - } - } -} - -impl PartialEq for Unsubscribe { - fn eq(&self, other: &Self) -> bool { - self.pkid == other.pkid && self.filters == other.filters - } -} - -impl Eq for Unsubscribe {} - impl Unsubscribe { pub fn new>(filter: S, properties: Option) -> Self { Self { @@ -83,7 +60,6 @@ impl Unsubscribe { pkid, filters, properties, - pkid_tx: None, }; Ok(unsubscribe) } @@ -111,10 +87,6 @@ impl Unsubscribe { Ok(1 + remaining_len_bytes + remaining_len) } - - pub fn place_pkid_tx(&mut self, pkid_tx: Sender) { - self.pkid_tx = Some(pkid_tx) - } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index f94100567..e2ae42625 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,3 +1,5 @@ +use crate::Pkid; + use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, @@ -11,6 +13,7 @@ use bytes::Bytes; use fixedbitset::FixedBitSet; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; +use tokio::sync::oneshot; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -158,7 +161,7 @@ impl MqttState { // remove and collect pending publishes for publish in self.outgoing_pub.iter_mut() { if let Some(publish) = publish.take() { - let request = Request::Publish(publish); + let request = Request::Publish(None, publish); pending.push(request); } } @@ -190,10 +193,10 @@ impl MqttState { request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? @@ -473,7 +476,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + pkid_tx: Option>, + ) -> Result, StateError> { // NOTE: pkid promise need not be fulfilled for QoS 0, // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { @@ -483,7 +490,7 @@ impl MqttState { let pkid = publish.pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = publish.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -596,6 +603,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + pkid_tx: Option>, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -604,7 +612,7 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = subscription.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -623,11 +631,12 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + pkid_tx: Option>, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; // Fulfill the pkid promise - if let Some(pkid_tx) = unsub.pkid_tx.take() { + if let Some(pkid_tx) = pkid_tx { _ = pkid_tx.send(pkid); } @@ -754,7 +763,7 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -762,12 +771,12 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -775,12 +784,12 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -792,17 +801,17 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); @@ -812,7 +821,7 @@ mod test { assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -896,8 +905,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + mqtt.outgoing_publish(publish1, None).unwrap(); + mqtt.outgoing_publish(publish2, None).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); @@ -931,8 +940,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let _publish_out = mqtt.outgoing_publish(publish1, None); + let _publish_out = mqtt.outgoing_publish(publish2, None); mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -950,7 +959,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish).unwrap().unwrap() { + match mqtt.outgoing_publish(publish, None).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -990,7 +999,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) @@ -1005,7 +1014,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + mqtt.handle_outgoing_packet(Request::Publish(None, publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap(); From 22276a41c099ba4056ef2fb368f852dbef2c792f Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 24 Feb 2024 19:36:12 +0530 Subject: [PATCH 15/20] `PkidError` when no pkid --- rumqttc/examples/pkid_promise.rs | 2 +- rumqttc/examples/pkid_promise_v5.rs | 2 +- rumqttc/src/lib.rs | 21 +++++++++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs index caac8d122..0fefd093e 100644 --- a/rumqttc/examples/pkid_promise.rs +++ b/rumqttc/examples/pkid_promise.rs @@ -61,7 +61,7 @@ async fn requests(client: AsyncClient) { .unwrap().wait_async(), ); } - Some(Ok(Some(pkid))) = joins.join_next() => { + Some(Ok(Ok(pkid))) = joins.join_next() => { println!("Pkid: {:?}", pkid); } else => break, diff --git a/rumqttc/examples/pkid_promise_v5.rs b/rumqttc/examples/pkid_promise_v5.rs index c8cbd6422..e4bd3d31c 100644 --- a/rumqttc/examples/pkid_promise_v5.rs +++ b/rumqttc/examples/pkid_promise_v5.rs @@ -61,7 +61,7 @@ async fn requests(client: AsyncClient) { .unwrap().wait_async(), ); } - Some(Ok(Some(pkid))) = joins.join_next() => { + Some(Ok(Ok(pkid))) = joins.join_next() => { println!("Pkid: {:?}", pkid); } else => break, diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index e50d2e7cb..4932cacb3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -168,19 +168,32 @@ pub struct PkidPromise { inner: oneshot::Receiver, } +#[derive(Debug, thiserror::Error)] +pub enum PkidError { + #[error("Eventloop dropped Sender: {0}")] + Recv(#[from] oneshot::error::RecvError), +} + impl PkidPromise { pub fn new(inner: oneshot::Receiver) -> Self { Self { inner } } /// Wait for the pkid to resolve by blocking the current thread - pub fn wait(self) -> Option { - self.inner.blocking_recv().ok() + /// + /// # Panics + /// Panics if called in an async context + pub fn wait(self) -> Result { + let pkid = self.inner.blocking_recv()?; + + Ok(pkid) } /// Await pkid resolution without blocking the current thread - pub async fn wait_async(self) -> Option { - self.inner.await.ok() + pub async fn wait_async(self) -> Result { + let pkid = self.inner.await?; + + Ok(pkid) } } From c057e766fcc506753af3ad01fcedb564dcb8c33a Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 25 Feb 2024 09:46:21 +0000 Subject: [PATCH 16/20] don't expose `oneshot` --- rumqttc/src/lib.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 4932cacb3..6c587c2cf 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -170,8 +170,14 @@ pub struct PkidPromise { #[derive(Debug, thiserror::Error)] pub enum PkidError { - #[error("Eventloop dropped Sender: {0}")] - Recv(#[from] oneshot::error::RecvError), + #[error("Eventloop dropped Sender")] + Recv, +} + +impl From for PkidError { + fn from(_: oneshot::error::RecvError) -> Self { + Self::Recv + } } impl PkidPromise { From e453f0c2a30294f39b7898121ed5103ccb0edbda Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 9 Apr 2024 14:12:43 +0530 Subject: [PATCH 17/20] poc: acknowledgement await with `NoticeFuture` --- rumqttc/examples/ack_notif.rs | 65 +++++++ rumqttc/src/client.rs | 173 ++++++++---------- rumqttc/src/lib.rs | 104 ++++------- rumqttc/src/state.rs | 320 +++++++++++++++++++++------------- rumqttc/src/v5/client.rs | 217 ++++++++++------------- rumqttc/src/v5/eventloop.rs | 2 +- rumqttc/src/v5/mod.rs | 15 +- rumqttc/src/v5/state.rs | 226 ++++++++++++++---------- 8 files changed, 611 insertions(+), 511 deletions(-) create mode 100644 rumqttc/examples/ack_notif.rs diff --git a/rumqttc/examples/ack_notif.rs b/rumqttc/examples/ack_notif.rs new file mode 100644 index 000000000..bb98599a4 --- /dev/null +++ b/rumqttc/examples/ack_notif.rs @@ -0,0 +1,65 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + + // Publish and spawn wait for notification + let mut set = JoinSet::new(); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + while let Some(res) = set.join_next().await { + println!("Acknoledged = {:?}", res?); + } + + Ok(()) +} diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index be90e3870..adfc7c1c7 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -4,7 +4,8 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; use crate::{ - valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, PkidPromise, Request, + valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, NoticeFuture, + NoticeTx, Request, }; use bytes::Bytes; @@ -74,7 +75,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result where S: Into, V: Into>, @@ -83,21 +84,15 @@ impl AsyncClient { let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -107,7 +102,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result where S: Into, V: Into>, @@ -116,21 +111,22 @@ impl AsyncClient { let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); + let notice_tx = if qos == QoS::AtMostOnce { + notice_tx.success(); None } else { - Some(pkid_tx) + Some(notice_tx) }; - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -159,25 +155,19 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result where S: Into, { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); self.request_tx.send_async(publish).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -185,18 +175,17 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Attempts to send a MQTT Subscribe to the `EventLoop` @@ -204,78 +193,78 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT disconnect to the `EventLoop` @@ -350,7 +339,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result where S: Into, V: Into>, @@ -359,21 +348,15 @@ impl Client { let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn try_publish( @@ -382,7 +365,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result where S: Into, V: Into>, @@ -411,18 +394,17 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -430,30 +412,29 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(Some(notice_tx), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -461,18 +442,18 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result { + pub fn unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 6c587c2cf..beab9ffe0 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -160,46 +160,57 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; -pub type Pkid = u16; - -/// A token through which the user can access the assigned packet id of an associated request, once it -/// is processed by the [`EventLoop`]. -pub struct PkidPromise { - inner: oneshot::Receiver, -} +use v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason}; #[derive(Debug, thiserror::Error)] -pub enum PkidError { +pub enum NoticeError { #[error("Eventloop dropped Sender")] Recv, + #[error(" v4 Subscription Failure Reason Code: {0:?}")] + V4Subscribe(SubscribeReasonCode), + #[error(" v5 Subscription Failure Reason Code: {0:?}")] + V5Subscribe(V5SubscribeReasonCode), + #[error(" v5 Unsubscription Failure Reason: {0:?}")] + V5Unsubscribe(UnsubAckReason), } -impl From for PkidError { +impl From for NoticeError { fn from(_: oneshot::error::RecvError) -> Self { Self::Recv } } -impl PkidPromise { - pub fn new(inner: oneshot::Receiver) -> Self { - Self { inner } - } +type NoticeResult = Result<(), NoticeError>; + +/// A token through which the user is notified of the publish/subscribe/unsubscribe packet being acked by the broker. +#[derive(Debug)] +pub struct NoticeFuture(oneshot::Receiver); - /// Wait for the pkid to resolve by blocking the current thread +impl NoticeFuture { + /// Wait for broker to acknowledge by blocking the current thread /// /// # Panics /// Panics if called in an async context - pub fn wait(self) -> Result { - let pkid = self.inner.blocking_recv()?; + pub fn wait(self) -> NoticeResult { + self.0.blocking_recv()? + } - Ok(pkid) + /// Await the packet acknowledgement from broker, without blocking the current thread + pub async fn wait_async(self) -> NoticeResult { + self.0.await? } +} - /// Await pkid resolution without blocking the current thread - pub async fn wait_async(self) -> Result { - let pkid = self.inner.await?; +#[derive(Debug)] +pub struct NoticeTx(oneshot::Sender); + +impl NoticeTx { + fn success(self) { + _ = self.0.send(Ok(())); + } - Ok(pkid) + fn error(self, e: NoticeError) { + _ = self.0.send(Err(e)); } } @@ -234,61 +245,20 @@ pub enum Outgoing { /// handled one by one. #[derive(Debug)] pub enum Request { - Publish(Option>, Publish), + Publish(Option, Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), - PubRel(PubRel), + PubRel(Option, PubRel), PingReq(PingReq), PingResp(PingResp), - Subscribe(Option>, Subscribe), + Subscribe(Option, Subscribe), SubAck(SubAck), - Unsubscribe(Option>, Unsubscribe), + Unsubscribe(Option, Unsubscribe), UnsubAck(UnsubAck), Disconnect(Disconnect), } -impl Clone for Request { - fn clone(&self) -> Self { - match self { - Self::Publish(_, p) => Self::Publish(None, p.clone()), - Self::PubAck(p) => Self::PubAck(p.clone()), - Self::PubRec(p) => Self::PubRec(p.clone()), - Self::PubRel(p) => Self::PubRel(p.clone()), - Self::PubComp(p) => Self::PubComp(p.clone()), - Self::Subscribe(_, p) => Self::Subscribe(None, p.clone()), - Self::SubAck(p) => Self::SubAck(p.clone()), - Self::PingReq(p) => Self::PingReq(p.clone()), - Self::PingResp(p) => Self::PingResp(p.clone()), - Self::Disconnect(p) => Self::Disconnect(p.clone()), - Self::Unsubscribe(_, p) => Self::Unsubscribe(None, p.clone()), - Self::UnsubAck(p) => Self::UnsubAck(p.clone()), - } - } -} - -impl PartialEq for Request { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Publish(_, p1), Self::Publish(_, p2)) => p1 == p2, - (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, - (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, - (Self::PubRel(p1), Self::PubRel(p2)) => p1 == p2, - (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, - (Self::Subscribe(_, p1), Self::Subscribe(_, p2)) => p1 == p2, - (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, - (Self::PingReq(p1), Self::PingReq(p2)) => p1 == p2, - (Self::PingResp(p1), Self::PingResp(p2)) => p1 == p2, - (Self::Unsubscribe(_, p1), Self::Unsubscribe(_, p2)) => p1 == p2, - (Self::UnsubAck(p1), Self::UnsubAck(p2)) => p1 == p2, - (Self::Disconnect(p1), Self::Disconnect(p2)) => p1 == p2, - _ => false, - } - } -} - -impl Eq for Request {} - impl Request { fn size(&self) -> usize { match &self { @@ -296,7 +266,7 @@ impl Request { Request::PubAck(puback) => puback.size(), Request::PubRec(pubrec) => pubrec.size(), Request::PubComp(pubcomp) => pubcomp.size(), - Request::PubRel(pubrel) => pubrel.size(), + Request::PubRel(_, pubrel) => pubrel.size(), Request::PingReq(pingreq) => pingreq.size(), Request::PingResp(pingresp) => pingresp.size(), Request::Subscribe(_, subscribe) => subscribe.size(), diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 238f98a60..1c48f467c 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,11 +1,10 @@ -use crate::{Event, Incoming, Outgoing, Pkid, Request}; +use crate::{Event, Incoming, NoticeError, NoticeTx, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; -use tokio::sync::oneshot; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -33,6 +32,8 @@ pub enum StateError { Deserialization(#[from] mqttbytes::Error), #[error("Connection closed by peer abruptly")] ConnectionAborted, + #[error("Subscribe failed with reason '{reason:?}' ")] + SubFail { reason: SubscribeReasonCode }, } /// State of the mqtt connection. @@ -41,7 +42,7 @@ pub enum StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -62,13 +63,17 @@ pub struct MqttState { /// Maximum number of allowed inflight pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: Vec>, + pub(crate) outgoing_pub: HashMap)>, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: FixedBitSet, + pub(crate) outgoing_rel: HashMap>, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, + + outgoing_sub: HashMap>, + outgoing_unsub: HashMap>, + /// Last collision due to broker not acking in order - pub collision: Option, + pub collision: Option<(Publish, Option)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -89,10 +94,12 @@ impl MqttState { last_puback: 0, inflight: 0, max_inflight, + outgoing_pub: HashMap::new(), + outgoing_rel: HashMap::new(), // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: vec![None; max_inflight as usize + 1], - outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1), incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1), + outgoing_sub: HashMap::new(), + outgoing_unsub: HashMap::new(), collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), @@ -103,20 +110,15 @@ impl MqttState { /// Returns inflight outgoing packets and clears internal queues pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); - let (first_half, second_half) = self - .outgoing_pub - .split_at_mut(self.last_puback as usize + 1); - for publish in second_half.iter_mut().chain(first_half) { - if let Some(publish) = publish.take() { - let request = Request::Publish(None, publish); - pending.push(request); - } + for (_, (publish, tx)) in self.outgoing_pub.drain() { + let request = Request::Publish(tx, publish); + pending.push(request); } // remove and collect pending releases - for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16)); + for (pkid, tx) in self.outgoing_rel.drain() { + let request = Request::PubRel(tx, PubRel::new(pkid)); pending.push(request); } self.outgoing_rel.clear(); @@ -142,7 +144,7 @@ impl MqttState { ) -> Result, StateError> { let packet = match request { Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::PubRel(tx, pubrel) => self.outgoing_pubrel(pubrel, tx)?, Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq(_) => self.outgoing_ping()?, @@ -169,8 +171,8 @@ impl MqttState { let outgoing = match &packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, - Incoming::SubAck(_suback) => self.handle_incoming_suback()?, - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, @@ -185,11 +187,56 @@ impl MqttState { Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: &SubAck) -> Result, StateError> { + if suback.pkid > self.max_inflight { + error!("Unsolicited suback packet: {:?}", suback.pkid); + return Err(StateError::Unsolicited(suback.pkid)); + } + + let tx = self + .outgoing_sub + .remove(&suback.pkid) + .ok_or(StateError::Unsolicited(suback.pkid))?; + + for reason in suback.return_codes.iter() { + match reason { + SubscribeReasonCode::Success(qos) => { + debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); + } + _ => { + if let Some(tx) = tx { + tx.error(NoticeError::V4Subscribe(*reason)) + } + + return Err(StateError::SubFail { reason: *reason }); + } + } + } + + if let Some(tx) = tx { + tx.success() + } + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &UnsubAck, + ) -> Result, StateError> { + if unsuback.pkid > self.max_inflight { + error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); + return Err(StateError::Unsolicited(unsuback.pkid)); + } + + if let Some(tx) = self + .outgoing_sub + .remove(&unsuback.pkid) + .ok_or(StateError::Unsolicited(unsuback.pkid))? + { + tx.success() + } + Ok(None) } @@ -221,21 +268,25 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - self.last_puback = puback.pkid; - - if publish.take().is_none() { + if puback.pkid > self.max_inflight { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } + if let (_, Some(tx)) = self + .outgoing_pub + .remove(&puback.pkid) + .ok_or(StateError::Unsolicited(puback.pkid))? + { + tx.success() + } + + self.last_puback = puback.pkid; + self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|publish| { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(publish.pkid, (publish.clone(), tx)); self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); @@ -249,18 +300,18 @@ impl MqttState { } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + if pubrec.pkid > self.max_inflight { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } + let (_, tx) = self + .outgoing_pub + .remove(&pubrec.pkid) + .ok_or(StateError::Unsolicited(pubrec.pkid))?; + // NOTE: Inflight - 1 for qos2 in comp - self.outgoing_rel.insert(pubrec.pkid as usize); + self.outgoing_rel.insert(pubrec.pkid, tx); let pubrel = PubRel { pkid: pubrec.pkid }; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -283,14 +334,22 @@ impl MqttState { } fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + if pubcomp.pkid > self.max_inflight { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } + if let Some(tx) = self + .outgoing_rel + .remove(&pubcomp.pkid) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + { + tx.success() + } - self.outgoing_rel.set(pubcomp.pkid as usize, false); self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|publish| { + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(pubcomp.pkid, (publish.clone(), tx)); let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -312,29 +371,18 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { - // NOTE: pkid promise need not be fulfilled for QoS 0, - // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } - if self - .outgoing_pub - .get(publish.pkid as usize) - .ok_or(StateError::Unsolicited(publish.pkid))? - .is_some() - { + if self.outgoing_pub.get(&publish.pkid).is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, notice_tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -342,9 +390,11 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - }; + } else if let Some(tx) = notice_tx { + tx.success() + } debug!( "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", @@ -359,8 +409,12 @@ impl MqttState { Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { - let pubrel = self.save_pubrel(pubrel)?; + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + notice_tx: Option, + ) -> Result, StateError> { + let pubrel = self.save_pubrel(pubrel, notice_tx)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); @@ -421,7 +475,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -429,16 +483,13 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", subscription.filters, subscription.pkid ); + self.outgoing_sub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); @@ -448,21 +499,17 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } - debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", unsub.topics, unsub.pkid ); + self.outgoing_unsub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); @@ -478,8 +525,8 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -488,7 +535,11 @@ impl MqttState { None } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + fn save_pubrel( + &mut self, + mut pubrel: PubRel, + notice_tx: Option, + ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { @@ -498,7 +549,7 @@ impl MqttState { _ => pubrel, }; - self.outgoing_rel.insert(pubrel.pkid as usize); + self.outgoing_rel.insert(pubrel.pkid, notice_tx); self.inflight += 1; Ok(pubrel) } @@ -525,9 +576,11 @@ impl MqttState { #[cfg(test)] mod test { + use std::collections::HashMap; + use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; - use crate::mqttbytes::*; + use crate::{mqttbytes::*, NoticeTx}; use crate::{Event, Incoming, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { @@ -700,8 +753,8 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); + assert!(mqtt.outgoing_pub.get(&1).is_none()); + assert!(mqtt.outgoing_pub.get(&2).is_none()); } #[test] @@ -730,11 +783,11 @@ mod test { assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); + let (backup, _) = mqtt.outgoing_pub.get(&1).unwrap(); + assert_eq!(backup.pkid, 1); - // check if the qos2 element's release pkid is 2 - assert!(mqtt.outgoing_rel.contains(2)); + // check if the qos2 element's release pkik has been set + assert!(mqtt.outgoing_rel.get(&2).is_some()); } #[test] @@ -827,44 +880,67 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> Vec> { - vec![ - None, - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 1, - payload: "".into(), - }), - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 2, - payload: "".into(), - }), - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 3, - payload: "".into(), - }), - None, - None, - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 6, - payload: "".into(), - }), - ] + fn build_outgoing_pub() -> HashMap)> { + let mut outgoing_pub = HashMap::new(); + outgoing_pub.insert( + 2, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 1, + payload: "".into(), + }, + None, + ), + ); + outgoing_pub.insert( + 3, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 2, + payload: "".into(), + }, + None, + ), + ); + outgoing_pub.insert( + 4, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 3, + payload: "".into(), + }, + None, + ), + ); + + outgoing_pub.insert( + 7, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 6, + payload: "".into(), + }, + None, + ), + ); + + outgoing_pub } mqtt.outgoing_pub = build_outgoing_pub(); diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 6841ed0c4..7b1e39ca6 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,7 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::{valid_filter, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::{valid_topic, PkidPromise}; +use crate::{valid_topic, NoticeFuture, NoticeTx}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -78,7 +78,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -87,21 +87,16 @@ impl AsyncClient { let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub async fn publish_with_properties( @@ -111,7 +106,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -126,7 +121,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -142,7 +137,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -151,21 +146,16 @@ impl AsyncClient { let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn try_publish_with_properties( @@ -175,7 +165,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -189,7 +179,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -224,7 +214,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result + ) -> Result where S: Into, { @@ -232,21 +222,16 @@ impl AsyncClient { let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.send_async(publish).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub async fn publish_bytes_with_properties( @@ -256,7 +241,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result + ) -> Result where S: Into, { @@ -270,7 +255,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result where S: Into, { @@ -284,19 +269,17 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request: Request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub async fn subscribe_with_properties>( @@ -304,7 +287,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)).await } @@ -312,7 +295,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { self.handle_subscribe(topic, qos, None).await } @@ -322,19 +305,17 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_filter_valid { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn try_subscribe_with_properties>( @@ -342,7 +323,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result { self.handle_try_subscribe(topic, qos, Some(properties)) } @@ -350,7 +331,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { self.handle_try_subscribe(topic, qos, None) } @@ -359,37 +340,35 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -401,36 +380,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -442,25 +419,28 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); self.request_tx.send_async(request).await?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result { self.handle_unsubscribe(topic, None).await } @@ -469,25 +449,25 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); self.request_tx.try_send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.handle_try_unsubscribe(topic, None) } @@ -565,7 +545,7 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -574,21 +554,16 @@ impl Client { let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); // Fulfill instantly for QoS 0 - let pkid_tx = if qos == QoS::AtMostOnce { - _ = pkid_tx.send(0); - None - } else { - Some(pkid_tx) - }; + let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(pkid_tx, publish); + let publish = Request::Publish(notice_tx, publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn publish_with_properties( @@ -598,7 +573,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -612,7 +587,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -627,7 +602,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -642,7 +617,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result where S: Into, P: Into, @@ -672,19 +647,17 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result { + ) -> Result { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn subscribe_with_properties>( @@ -692,7 +665,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)) } @@ -700,7 +673,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { self.handle_subscribe(topic, qos, None) } @@ -710,7 +683,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result { self.client .try_subscribe_with_properties(topic, qos, properties) } @@ -719,7 +692,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result { self.client.try_subscribe(topic, qos) } @@ -728,36 +701,34 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result + ) -> Result where T: IntoIterator, { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); - - let request = Request::Subscribe(Some(pkid_tx), subscribe); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -768,7 +739,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result where T: IntoIterator, { @@ -776,7 +747,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -788,25 +759,25 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); - let (pkid_tx, pkid_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Unsubscribe(Some(pkid_tx), unsubscribe); + let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); self.client.request_tx.send(request)?; - Ok(PkidPromise::new(pkid_rx)) + Ok(NoticeFuture(notice_rx)) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result { + pub fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None) } @@ -815,12 +786,12 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index cd0568ada..5c4444135 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -77,7 +77,7 @@ pub struct EventLoop { /// Requests handle to send requests pub(crate) requests_tx: Sender, /// Pending packets from last session - pub pending: VecDeque, + pub(crate) pending: VecDeque, /// Network connection to the broker network: Option, /// Keep alive time diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d4dbf6b0b..a07bbb2e3 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -7,7 +7,6 @@ use std::{ pin::Pin, sync::Arc, }; -use tokio::sync::oneshot; mod client; mod eventloop; @@ -16,7 +15,7 @@ pub mod mqttbytes; mod state; use crate::{NetworkOptions, Transport}; -use crate::{Outgoing, Pkid}; +use crate::{NoticeTx, Outgoing}; use mqttbytes::v5::*; @@ -36,16 +35,16 @@ pub type Incoming = Packet; /// handled one by one. #[derive(Debug)] pub enum Request { - Publish(Option>, Publish), + Publish(Option, Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), - PubRel(PubRel), + PubRel(Option, PubRel), PingReq, PingResp, - Subscribe(Option>, Subscribe), + Subscribe(Option, Subscribe), SubAck(SubAck), - Unsubscribe(Option>, Unsubscribe), + Unsubscribe(Option, Unsubscribe), UnsubAck(UnsubAck), Disconnect, } @@ -56,7 +55,7 @@ impl Clone for Request { Self::Publish(_, p) => Self::Publish(None, p.clone()), Self::PubAck(p) => Self::PubAck(p.clone()), Self::PubRec(p) => Self::PubRec(p.clone()), - Self::PubRel(p) => Self::PubRel(p.clone()), + Self::PubRel(_, p) => Self::PubRel(None, p.clone()), Self::PubComp(p) => Self::PubComp(p.clone()), Self::Subscribe(_, p) => Self::Subscribe(None, p.clone()), Self::SubAck(p) => Self::SubAck(p.clone()), @@ -75,7 +74,7 @@ impl PartialEq for Request { (Self::Publish(_, p1), Self::Publish(_, p2)) => p1 == p2, (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, - (Self::PubRel(p1), Self::PubRel(p2)) => p1 == p2, + (Self::PubRel(_, p1), Self::PubRel(_, p2)) => p1 == p2, (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, (Self::Subscribe(_, p1), Self::Subscribe(_, p2)) => p1 == p2, (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index e2ae42625..b324d4718 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,4 +1,4 @@ -use crate::Pkid; +use crate::{NoticeError, NoticeTx}; use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, @@ -13,7 +13,6 @@ use bytes::Bytes; use fixedbitset::FixedBitSet; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; -use tokio::sync::oneshot; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -89,7 +88,7 @@ impl From for StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -106,13 +105,17 @@ pub struct MqttState { /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: Vec>, + pub(crate) outgoing_pub: HashMap)>, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: FixedBitSet, + pub(crate) outgoing_rel: HashMap>, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, + + outgoing_sub: HashMap>, + outgoing_unsub: HashMap>, + /// Last collision due to broker not acking in order - pub collision: Option, + pub(crate) collision: Option<(Publish, Option)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -139,10 +142,12 @@ impl MqttState { last_outgoing: Instant::now(), last_pkid: 0, inflight: 0, + outgoing_pub: HashMap::new(), + outgoing_rel: HashMap::new(), // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: vec![None; max_inflight as usize + 1], - outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1), incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1), + outgoing_sub: HashMap::new(), + outgoing_unsub: HashMap::new(), collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), @@ -159,16 +164,14 @@ impl MqttState { pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes - for publish in self.outgoing_pub.iter_mut() { - if let Some(publish) = publish.take() { - let request = Request::Publish(None, publish); - pending.push(request); - } + for (_, (publish, tx)) in self.outgoing_pub.drain() { + let request = Request::Publish(tx, publish); + pending.push(request); } // remove and collect pending releases - for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16, None)); + for (pkid, tx) in self.outgoing_rel.drain() { + let request = Request::PubRel(tx, PubRel::new(pkid, None)); pending.push(request); } self.outgoing_rel.clear(); @@ -194,7 +197,7 @@ impl MqttState { ) -> Result, StateError> { let packet = match request { Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::PubRel(tx, pubrel) => self.outgoing_pubrel(pubrel, tx)?, Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq => self.outgoing_ping()?, @@ -250,14 +253,35 @@ impl MqttState { &mut self, suback: &mut SubAck, ) -> Result, StateError> { + if suback.pkid > self.max_outgoing_inflight { + error!("Unsolicited suback packet: {:?}", suback.pkid); + return Err(StateError::Unsolicited(suback.pkid)); + } + + let tx = self + .outgoing_sub + .remove(&suback.pkid) + .ok_or(StateError::Unsolicited(suback.pkid))?; + for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); } - _ => return Err(StateError::SubFail { reason: *reason }), + _ => { + if let Some(tx) = tx { + tx.error(NoticeError::V5Subscribe(*reason)) + } + + return Err(StateError::SubFail { reason: *reason }); + } } } + + if let Some(tx) = tx { + tx.success() + } + Ok(None) } @@ -265,11 +289,30 @@ impl MqttState { &mut self, unsuback: &mut UnsubAck, ) -> Result, StateError> { + if unsuback.pkid > self.max_outgoing_inflight { + error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); + return Err(StateError::Unsolicited(unsuback.pkid)); + } + + let tx = self + .outgoing_unsub + .remove(&unsuback.pkid) + .ok_or(StateError::Unsolicited(unsuback.pkid))?; + for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { + if let Some(tx) = tx { + tx.error(NoticeError::V5Unsubscribe(*reason)) + } + return Err(StateError::UnsubFail { reason: *reason }); } } + + if let Some(tx) = tx { + tx.success() + } + Ok(None) } @@ -362,16 +405,19 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + if puback.pkid > self.max_outgoing_inflight { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } + if let (_, Some(tx)) = self + .outgoing_pub + .remove(&puback.pkid) + .ok_or(StateError::Unsolicited(puback.pkid))? + { + tx.success() + } + self.inflight -= 1; if puback.reason != PubAckReason::Success @@ -382,32 +428,32 @@ impl MqttState { }); } - if let Some(publish) = self.check_collision(puback.pkid) { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(publish.pkid, (publish.clone(), tx)); self.inflight += 1; - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - return Ok(Some(Packet::Publish(publish))); - } + Packet::Publish(publish) + }); - Ok(None) + Ok(packet) } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + if pubrec.pkid > self.max_outgoing_inflight { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } + let (_, tx) = self + .outgoing_pub + .remove(&pubrec.pkid) + .ok_or(StateError::Unsolicited(pubrec.pkid))?; + if pubrec.reason != PubRecReason::Success && pubrec.reason != PubRecReason::NoMatchingSubscribers { @@ -417,7 +463,7 @@ impl MqttState { } // NOTE: Inflight - 1 for qos2 in comp - self.outgoing_rel.insert(pubrec.pkid as usize); + self.outgoing_rel.insert(pubrec.pkid, tx); let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -444,29 +490,30 @@ impl MqttState { } fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - let outgoing = self.check_collision(pubcomp.pkid).map(|publish| { - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - - Packet::Publish(publish) - }); - - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + if pubcomp.pkid > self.max_outgoing_inflight { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - self.outgoing_rel.set(pubcomp.pkid as usize, false); - - if pubcomp.reason != PubCompReason::Success { - return Err(StateError::PubCompFail { - reason: pubcomp.reason, - }); + if let Some(tx) = self + .outgoing_rel + .remove(&pubcomp.pkid) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + { + tx.success() } self.inflight -= 1; - Ok(outgoing) + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(pubcomp.pkid, (publish.clone(), tx)); + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + + Packet::Publish(publish) + }); + + Ok(packet) } fn handle_incoming_pingresp(&mut self) -> Result, StateError> { @@ -479,29 +526,18 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { - // NOTE: pkid promise need not be fulfilled for QoS 0, - // user should know this but still handled in Client. if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } - if self - .outgoing_pub - .get(publish.pkid as usize) - .ok_or(StateError::Unsolicited(publish.pkid))? - .is_some() - { + if self.outgoing_pub.get(&publish.pkid).is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, notice_tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -509,9 +545,11 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - }; + } else if let Some(tx) = notice_tx { + tx.success() + } debug!( "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", @@ -541,8 +579,12 @@ impl MqttState { Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { - let pubrel = self.save_pubrel(pubrel)?; + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + notice_tx: Option, + ) -> Result, StateError> { + let pubrel = self.save_pubrel(pubrel, notice_tx)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); @@ -603,7 +645,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -611,17 +653,13 @@ impl MqttState { let pkid = self.next_pkid(); subscription.pkid = pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", subscription.filters, subscription.pkid ); - let pkid = subscription.pkid; + self.outgoing_sub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); @@ -631,21 +669,17 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - pkid_tx: Option>, + notice_tx: Option, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; - // Fulfill the pkid promise - if let Some(pkid_tx) = pkid_tx { - _ = pkid_tx.send(pkid); - } debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", unsub.filters, unsub.pkid ); - let pkid = unsub.pkid; + self.outgoing_unsub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); @@ -663,8 +697,8 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -673,7 +707,11 @@ impl MqttState { None } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + fn save_pubrel( + &mut self, + mut pubrel: PubRel, + notice_tx: Option, + ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { @@ -683,7 +721,7 @@ impl MqttState { _ => pubrel, }; - self.outgoing_rel.insert(pubrel.pkid as usize); + self.outgoing_rel.insert(pubrel.pkid, notice_tx); self.inflight += 1; Ok(pubrel) } @@ -915,8 +953,8 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); + assert!(mqtt.outgoing_pub.get(&1).is_none()); + assert!(mqtt.outgoing_pub.get(&2).is_none()); } #[test] @@ -947,11 +985,11 @@ mod test { assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); + let (backup, _) = mqtt.outgoing_pub.get(&1).unwrap(); + assert_eq!(backup.pkid, 1); - // check if the qos2 element's release pkid is 2 - assert!(mqtt.outgoing_rel.contains(2)); + // check if the qos2 element's release pkik has been set + assert!(mqtt.outgoing_rel.get(&2).is_some()); } #[test] From 16c8115122ec585368c676106af8761465923f48 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 28 May 2024 00:56:45 +0530 Subject: [PATCH 18/20] refactor: allow `Request::clone` --- rumqttc/src/client.rs | 88 ++++++++++++++++++---------------- rumqttc/src/eventloop.rs | 21 +++++---- rumqttc/src/lib.rs | 81 ++++++-------------------------- rumqttc/src/notice.rs | 58 +++++++++++++++++++++++ rumqttc/src/state.rs | 38 ++++++++------- rumqttc/src/v5/client.rs | 94 +++++++++++++++++++++---------------- rumqttc/src/v5/eventloop.rs | 21 +++++---- rumqttc/src/v5/mod.rs | 53 +++------------------ rumqttc/src/v5/state.rs | 21 +++++---- 9 files changed, 232 insertions(+), 243 deletions(-) create mode 100644 rumqttc/src/notice.rs diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index adfc7c1c7..a1b9faee1 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -23,15 +23,15 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From, Request)>> for ClientError { + fn from(e: SendError<(Option, Request)>) -> Self { + Self::Request(e.into_inner().1) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From, Request)>> for ClientError { + fn from(e: TrySendError<(Option, Request)>) -> Self { + Self::TryRequest(e.into_inner().1) } } @@ -44,7 +44,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(Option, Request)>, } impl AsyncClient { @@ -64,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(Option, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -87,11 +87,11 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx.send_async(publish).await?; + self.request_tx.send_async((notice_tx, publish)).await?; Ok(NoticeFuture(notice_rx)) } @@ -121,11 +121,11 @@ impl AsyncClient { Some(notice_tx) }; - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send(publish)?; + self.request_tx.try_send((notice_tx, publish))?; Ok(NoticeFuture(notice_rx)) } @@ -134,7 +134,7 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((None, ack)).await?; } Ok(()) } @@ -143,7 +143,7 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((None, ack))?; } Ok(()) } @@ -165,8 +165,8 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); - self.request_tx.send_async(publish).await?; + let publish = Request::Publish(publish); + self.request_tx.send_async((notice_tx, publish)).await?; Ok(NoticeFuture(notice_rx)) } @@ -180,11 +180,13 @@ impl AsyncClient { let subscribe = Subscribe::new(&topic, qos); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; + self.request_tx + .send_async((Some(notice_tx), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -198,11 +200,11 @@ impl AsyncClient { let subscribe = Subscribe::new(&topic, qos); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(request)?; + self.request_tx.try_send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -216,11 +218,13 @@ impl AsyncClient { let subscribe = Subscribe::new_many(topics_iter); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; + self.request_tx + .send_async((Some(notice_tx), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -234,11 +238,11 @@ impl AsyncClient { let subscribe = Subscribe::new_many(topics_iter); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(request)?; + self.request_tx.try_send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -251,8 +255,10 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); - self.request_tx.send_async(request).await?; + let request = Request::Unsubscribe(unsubscribe); + self.request_tx + .send_async((Some(notice_tx), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -262,22 +268,22 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); - self.request_tx.try_send(request)?; + let request = Request::Unsubscribe(unsubscribe); + self.request_tx.try_send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } /// Sends a MQTT disconnect to the `EventLoop` pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect(Disconnect); - self.request_tx.send_async(request).await?; + self.request_tx.send_async((None, request)).await?; Ok(()) } /// Attempts to send a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect(Disconnect); - self.request_tx.try_send(request)?; + self.request_tx.try_send((None, request))?; Ok(()) } } @@ -326,7 +332,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(Option, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -351,11 +357,11 @@ impl Client { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send(publish)?; + self.client.request_tx.send((notice_tx, publish))?; Ok(NoticeFuture(notice_rx)) } @@ -378,7 +384,7 @@ impl Client { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((None, ack))?; } Ok(()) } @@ -399,11 +405,11 @@ impl Client { let subscribe = Subscribe::new(&topic, qos); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } - self.client.request_tx.send(request)?; + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -426,11 +432,11 @@ impl Client { let subscribe = Subscribe::new_many(topics_iter); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Subscribe(Some(notice_tx), subscribe); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.client.request_tx.send(request)?; + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -447,8 +453,8 @@ impl Client { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let notice_tx = NoticeTx(notice_tx); - let request = Request::Unsubscribe(Some(notice_tx), unsubscribe); - self.client.request_tx.send(request)?; + let request = Request::Unsubscribe(unsubscribe); + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -460,7 +466,7 @@ impl Client { /// Sends a MQTT disconnect to the `EventLoop` pub fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect(Disconnect); - self.client.request_tx.send(request)?; + self.client.request_tx.send((None, request))?; Ok(()) } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index d31690d99..7fef45ff6 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,3 +1,4 @@ +use crate::notice::NoticeTx; use crate::{framed::Network, Transport}; use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; @@ -75,11 +76,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(Option, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(Option, Request)>, /// Pending packets from last session - pub pending: VecDeque, + pub pending: VecDeque<(Option, Request)>, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -132,7 +133,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(_, request)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -241,8 +242,8 @@ impl EventLoop { &self.requests_rx, self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((tx, request)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -260,7 +261,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + if let Some(outgoing) = self.state.handle_outgoing_packet(None, Request::PingReq(PingReq))? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -282,10 +283,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(Option, Request)>, + rx: &Receiver<(Option, Request)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(Option, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index beab9ffe0..31a32b571 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -109,6 +109,7 @@ mod client; mod eventloop; mod framed; pub mod mqttbytes; +mod notice; mod state; pub mod v5; @@ -140,6 +141,8 @@ pub use client::{ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; +use notice::NoticeTx; +pub use notice::{NoticeError, NoticeFuture}; #[cfg(feature = "use-rustls")] use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; @@ -160,60 +163,6 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; -use v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason}; - -#[derive(Debug, thiserror::Error)] -pub enum NoticeError { - #[error("Eventloop dropped Sender")] - Recv, - #[error(" v4 Subscription Failure Reason Code: {0:?}")] - V4Subscribe(SubscribeReasonCode), - #[error(" v5 Subscription Failure Reason Code: {0:?}")] - V5Subscribe(V5SubscribeReasonCode), - #[error(" v5 Unsubscription Failure Reason: {0:?}")] - V5Unsubscribe(UnsubAckReason), -} - -impl From for NoticeError { - fn from(_: oneshot::error::RecvError) -> Self { - Self::Recv - } -} - -type NoticeResult = Result<(), NoticeError>; - -/// A token through which the user is notified of the publish/subscribe/unsubscribe packet being acked by the broker. -#[derive(Debug)] -pub struct NoticeFuture(oneshot::Receiver); - -impl NoticeFuture { - /// Wait for broker to acknowledge by blocking the current thread - /// - /// # Panics - /// Panics if called in an async context - pub fn wait(self) -> NoticeResult { - self.0.blocking_recv()? - } - - /// Await the packet acknowledgement from broker, without blocking the current thread - pub async fn wait_async(self) -> NoticeResult { - self.0.await? - } -} - -#[derive(Debug)] -pub struct NoticeTx(oneshot::Sender); - -impl NoticeTx { - fn success(self) { - _ = self.0.send(Ok(())); - } - - fn error(self, e: NoticeError) { - _ = self.0.send(Err(e)); - } -} - /// Current outgoing activity on the eventloop #[derive(Debug, Clone, PartialEq, Eq)] pub enum Outgoing { @@ -243,18 +192,18 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Request { - Publish(Option, Publish), + Publish(Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), - PubRel(Option, PubRel), + PubRel(PubRel), PingReq(PingReq), PingResp(PingResp), - Subscribe(Option, Subscribe), + Subscribe(Subscribe), SubAck(SubAck), - Unsubscribe(Option, Unsubscribe), + Unsubscribe(Unsubscribe), UnsubAck(UnsubAck), Disconnect(Disconnect), } @@ -262,16 +211,16 @@ pub enum Request { impl Request { fn size(&self) -> usize { match &self { - Request::Publish(_, publish) => publish.size(), + Request::Publish(publish) => publish.size(), Request::PubAck(puback) => puback.size(), Request::PubRec(pubrec) => pubrec.size(), Request::PubComp(pubcomp) => pubcomp.size(), - Request::PubRel(_, pubrel) => pubrel.size(), + Request::PubRel(pubrel) => pubrel.size(), Request::PingReq(pingreq) => pingreq.size(), Request::PingResp(pingresp) => pingresp.size(), - Request::Subscribe(_, subscribe) => subscribe.size(), + Request::Subscribe(subscribe) => subscribe.size(), Request::SubAck(suback) => suback.size(), - Request::Unsubscribe(_, unsubscribe) => unsubscribe.size(), + Request::Unsubscribe(unsubscribe) => unsubscribe.size(), Request::UnsubAck(unsuback) => unsuback.size(), Request::Disconnect(disconn) => disconn.size(), } @@ -280,19 +229,19 @@ impl Request { impl From for Request { fn from(publish: Publish) -> Request { - Request::Publish(None, publish) + Request::Publish(publish) } } impl From for Request { fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(None, subscribe) + Request::Subscribe(subscribe) } } impl From for Request { fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(None, unsubscribe) + Request::Unsubscribe(unsubscribe) } } diff --git a/rumqttc/src/notice.rs b/rumqttc/src/notice.rs new file mode 100644 index 000000000..2b3eb0ebd --- /dev/null +++ b/rumqttc/src/notice.rs @@ -0,0 +1,58 @@ +use tokio::sync::oneshot; + +use crate::{ + v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason}, + SubscribeReasonCode, +}; + +#[derive(Debug, thiserror::Error)] +pub enum NoticeError { + #[error("Eventloop dropped Sender")] + Recv, + #[error(" v4 Subscription Failure Reason Code: {0:?}")] + V4Subscribe(SubscribeReasonCode), + #[error(" v5 Subscription Failure Reason Code: {0:?}")] + V5Subscribe(V5SubscribeReasonCode), + #[error(" v5 Unsubscription Failure Reason: {0:?}")] + V5Unsubscribe(UnsubAckReason), +} + +impl From for NoticeError { + fn from(_: oneshot::error::RecvError) -> Self { + Self::Recv + } +} + +type NoticeResult = Result<(), NoticeError>; + +/// A token through which the user is notified of the publish/subscribe/unsubscribe packet being acked by the broker. +#[derive(Debug)] +pub struct NoticeFuture(pub(crate) oneshot::Receiver); + +impl NoticeFuture { + /// Wait for broker to acknowledge by blocking the current thread + /// + /// # Panics + /// Panics if called in an async context + pub fn wait(self) -> NoticeResult { + self.0.blocking_recv()? + } + + /// Await the packet acknowledgement from broker, without blocking the current thread + pub async fn wait_async(self) -> NoticeResult { + self.0.await? + } +} + +#[derive(Debug)] +pub struct NoticeTx(pub(crate) oneshot::Sender); + +impl NoticeTx { + pub(crate) fn success(self) { + _ = self.0.send(Ok(())); + } + + pub(crate) fn error(self, e: NoticeError) { + _ = self.0.send(Err(e)); + } +} diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 1c48f467c..77b962ee4 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,4 +1,5 @@ -use crate::{Event, Incoming, NoticeError, NoticeTx, Outgoing, Request}; +use crate::notice::NoticeTx; +use crate::{Event, Incoming, NoticeError, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; @@ -73,7 +74,7 @@ pub struct MqttState { outgoing_unsub: HashMap>, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Option)>, + pub(crate) collision: Option<(Publish, Option)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -108,18 +109,18 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(Option, Request)> { let mut pending = Vec::with_capacity(100); for (_, (publish, tx)) in self.outgoing_pub.drain() { - let request = Request::Publish(tx, publish); - pending.push(request); + let request = Request::Publish(publish); + pending.push((tx, request)); } // remove and collect pending releases for (pkid, tx) in self.outgoing_rel.drain() { - let request = Request::PubRel(tx, PubRel::new(pkid)); - pending.push(request); + let request = Request::PubRel(PubRel::new(pkid)); + pending.push((tx, request)); } self.outgoing_rel.clear(); @@ -140,13 +141,14 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, + tx: Option, request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(tx, pubrel) => self.outgoing_pubrel(pubrel, tx)?, - Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, tx)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq(_) => self.outgoing_ping()?, Request::Disconnect(_) => self.outgoing_disconnect()?, Request::PubAck(puback) => self.outgoing_puback(puback)?, @@ -851,7 +853,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(None, publish)) + mqtt.handle_outgoing_packet(None, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -947,8 +949,8 @@ mod test { mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; - for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(_, publish) = req { + for ((_, req), idx) in requests.iter().zip(res) { + if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -959,8 +961,8 @@ mod test { mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; - for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(_, publish) = req { + for ((_, req), idx) in requests.iter().zip(res) { + if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -971,8 +973,8 @@ mod test { mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; - for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(_, publish) = req { + for ((_, req), idx) in requests.iter().zip(res) { + if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { unreachable!() diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 7b1e39ca6..292619672 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -25,15 +25,15 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From, Request)>> for ClientError { + fn from(e: SendError<(Option, Request)>) -> Self { + Self::Request(e.into_inner().1) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From, Request)>> for ClientError { + fn from(e: TrySendError<(Option, Request)>) -> Self { + Self::TryRequest(e.into_inner().1) } } @@ -46,7 +46,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(Option, Request)>, } impl AsyncClient { @@ -66,7 +66,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(Option, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -91,11 +91,11 @@ impl AsyncClient { // Fulfill instantly for QoS 0 let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx.send_async(publish).await?; + self.request_tx.send_async((notice_tx, publish)).await?; Ok(NoticeFuture(notice_rx)) } @@ -150,11 +150,11 @@ impl AsyncClient { // Fulfill instantly for QoS 0 let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send(publish)?; + self.request_tx.try_send((notice_tx, publish))?; Ok(NoticeFuture(notice_rx)) } @@ -192,7 +192,7 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((None, ack)).await?; } Ok(()) } @@ -201,7 +201,7 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((None, ack))?; } Ok(()) } @@ -226,11 +226,11 @@ impl AsyncClient { // Fulfill instantly for QoS 0 let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.send_async(publish).await?; + self.request_tx.send_async((notice_tx, publish)).await?; Ok(NoticeFuture(notice_rx)) } @@ -274,11 +274,13 @@ impl AsyncClient { let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; + self.request_tx + .send_async((Some(NoticeTx(notice_tx)), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -310,11 +312,12 @@ impl AsyncClient { let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(request)?; + self.request_tx + .try_send((Some(NoticeTx(notice_tx)), request))?; Ok(NoticeFuture(notice_rx)) } @@ -348,12 +351,14 @@ impl AsyncClient { let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.request_tx.send_async(request).await?; + self.request_tx + .send_async((Some(NoticeTx(notice_tx)), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -388,11 +393,12 @@ impl AsyncClient { let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(request)?; + self.request_tx + .try_send((Some(NoticeTx(notice_tx)), request))?; Ok(NoticeFuture(notice_rx)) } @@ -424,8 +430,10 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); - self.request_tx.send_async(request).await?; + let request = Request::Unsubscribe(unsubscribe); + self.request_tx + .send_async((Some(NoticeTx(notice_tx)), request)) + .await?; Ok(NoticeFuture(notice_rx)) } @@ -454,8 +462,9 @@ impl AsyncClient { let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); - self.request_tx.try_send(request)?; + let request = Request::Unsubscribe(unsubscribe); + self.request_tx + .try_send((Some(NoticeTx(notice_tx)), request))?; Ok(NoticeFuture(notice_rx)) } @@ -474,14 +483,14 @@ impl AsyncClient { /// Sends a MQTT disconnect to the `EventLoop` pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.send_async(request).await?; + self.request_tx.send_async((None, request)).await?; Ok(()) } /// Attempts to send a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.try_send(request)?; + self.request_tx.try_send((None, request))?; Ok(()) } } @@ -531,7 +540,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(Option, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -558,11 +567,11 @@ impl Client { // Fulfill instantly for QoS 0 let notice_tx = Some(NoticeTx(notice_tx)); - let publish = Request::Publish(notice_tx, publish); + let publish = Request::Publish(publish); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send(publish)?; + self.client.request_tx.send((notice_tx, publish))?; Ok(NoticeFuture(notice_rx)) } @@ -630,7 +639,7 @@ impl Client { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((None, ack))?; } Ok(()) } @@ -652,11 +661,12 @@ impl Client { let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } - self.client.request_tx.send(request)?; + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -709,11 +719,12 @@ impl Client { let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let request = Request::Subscribe(Some(NoticeTx(notice_tx)), subscribe); + let notice_tx = NoticeTx(notice_tx); + let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.client.request_tx.send(request)?; + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -763,9 +774,10 @@ impl Client { let unsubscribe = Unsubscribe::new(topic, properties); let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let notice_tx = NoticeTx(notice_tx); - let request = Request::Unsubscribe(Some(NoticeTx(notice_tx)), unsubscribe); - self.client.request_tx.send(request)?; + let request = Request::Unsubscribe(unsubscribe); + self.client.request_tx.send((Some(notice_tx), request))?; Ok(NoticeFuture(notice_rx)) } @@ -798,7 +810,7 @@ impl Client { /// Sends a MQTT disconnect to the `EventLoop` pub fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.client.request_tx.send(request)?; + self.client.request_tx.send((None, request))?; Ok(()) } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 5c4444135..66efa84e0 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,6 +3,7 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; +use crate::notice::NoticeTx; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -73,11 +74,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(Option, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(Option, Request)>, /// Pending packets from last session - pub(crate) pending: VecDeque, + pub(crate) pending: VecDeque<(Option, Request)>, /// Network connection to the broker network: Option, /// Keep alive time @@ -128,7 +129,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(_, request)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -223,8 +224,8 @@ impl EventLoop { &self.requests_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((tx, request)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? { network.write(outgoing).await?; } network.flush().await?; @@ -245,7 +246,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(None, Request::PingReq)? { network.write(outgoing).await?; } network.flush().await?; @@ -255,10 +256,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(Option, Request)>, + rx: &Receiver<(Option, Request)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(Option, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index a07bbb2e3..fb83847dc 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,8 @@ mod framed; pub mod mqttbytes; mod state; +use crate::Outgoing; use crate::{NetworkOptions, Transport}; -use crate::{NoticeTx, Outgoing}; use mqttbytes::v5::*; @@ -33,63 +33,22 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Request { - Publish(Option, Publish), + Publish(Publish), PubAck(PubAck), PubRec(PubRec), PubComp(PubComp), - PubRel(Option, PubRel), + PubRel(PubRel), PingReq, PingResp, - Subscribe(Option, Subscribe), + Subscribe(Subscribe), SubAck(SubAck), - Unsubscribe(Option, Unsubscribe), + Unsubscribe(Unsubscribe), UnsubAck(UnsubAck), Disconnect, } -impl Clone for Request { - fn clone(&self) -> Self { - match self { - Self::Publish(_, p) => Self::Publish(None, p.clone()), - Self::PubAck(p) => Self::PubAck(p.clone()), - Self::PubRec(p) => Self::PubRec(p.clone()), - Self::PubRel(_, p) => Self::PubRel(None, p.clone()), - Self::PubComp(p) => Self::PubComp(p.clone()), - Self::Subscribe(_, p) => Self::Subscribe(None, p.clone()), - Self::SubAck(p) => Self::SubAck(p.clone()), - Self::PingReq => Self::PingReq, - Self::PingResp => Self::PingResp, - Self::Disconnect => Self::Disconnect, - Self::Unsubscribe(_, p) => Self::Unsubscribe(None, p.clone()), - Self::UnsubAck(p) => Self::UnsubAck(p.clone()), - } - } -} - -impl PartialEq for Request { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Publish(_, p1), Self::Publish(_, p2)) => p1 == p2, - (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, - (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, - (Self::PubRel(_, p1), Self::PubRel(_, p2)) => p1 == p2, - (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, - (Self::Subscribe(_, p1), Self::Subscribe(_, p2)) => p1 == p2, - (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, - (Self::PingReq, Self::PingReq) - | (Self::PingResp, Self::PingResp) - | (Self::Disconnect, Self::Disconnect) => true, - (Self::Unsubscribe(_, p1), Self::Unsubscribe(_, p2)) => p1 == p2, - (Self::UnsubAck(p1), Self::UnsubAck(p2)) => p1 == p2, - _ => false, - } - } -} - -impl Eq for Request {} - #[cfg(feature = "websocket")] type RequestModifierFn = Arc< dyn Fn(http::Request<()>) -> Pin> + Send>> diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index b324d4718..ccfdaa1a9 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -161,18 +161,18 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(Option, Request)> { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes for (_, (publish, tx)) in self.outgoing_pub.drain() { - let request = Request::Publish(tx, publish); - pending.push(request); + let request = Request::Publish(publish); + pending.push((tx, request)); } // remove and collect pending releases for (pkid, tx) in self.outgoing_rel.drain() { - let request = Request::PubRel(tx, PubRel::new(pkid, None)); - pending.push(request); + let request = Request::PubRel(PubRel::new(pkid, None)); + pending.push((tx, request)); } self.outgoing_rel.clear(); @@ -193,13 +193,14 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, + tx: Option, request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(tx, publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(tx, pubrel) => self.outgoing_pubrel(pubrel, tx)?, - Request::Subscribe(tx, subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(tx, unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, tx)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? @@ -1052,7 +1053,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(None, publish)) + mqtt.handle_outgoing_packet(None, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap(); From d5698fea27545e3c3f13084f865a0f8b4c344a36 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 28 May 2024 02:54:04 +0530 Subject: [PATCH 19/20] refactor: `NoticeFuture::new` --- rumqttc/src/client.rs | 169 +++++++++++++++++++++------------------ rumqttc/src/notice.rs | 6 ++ rumqttc/src/v5/client.rs | 156 +++++++++++++++++++----------------- 3 files changed, 177 insertions(+), 154 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index a1b9faee1..58688d5d4 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -83,16 +83,17 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = Some(NoticeTx(notice_tx)); - - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx.send_async((notice_tx, publish)).await?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -110,23 +111,17 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; + let request = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::TryRequest(request)); + } - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); + let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = if qos == QoS::AtMostOnce { - notice_tx.success(); - None - } else { - Some(notice_tx) - }; + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); - } - self.request_tx.try_send((notice_tx, publish))?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -161,13 +156,14 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; + let request = Request::Publish(publish); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = Some(NoticeTx(notice_tx)); + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; - let publish = Request::Publish(publish); - self.request_tx.send_async((notice_tx, publish)).await?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -178,16 +174,16 @@ impl AsyncClient { ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } - self.request_tx - .send_async((Some(notice_tx), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT Subscribe to the `EventLoop` @@ -198,14 +194,16 @@ impl AsyncClient { ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` @@ -216,16 +214,17 @@ impl AsyncClient { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.request_tx - .send_async((Some(notice_tx), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = Some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` @@ -236,14 +235,17 @@ impl AsyncClient { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT Unsubscribe to the `EventLoop` @@ -252,25 +254,26 @@ impl AsyncClient { topic: S, ) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Unsubscribe(unsubscribe); - self.request_tx - .send_async((Some(notice_tx), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` pub fn try_unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` @@ -353,16 +356,17 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = Some(NoticeTx(notice_tx)); - - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.client.request_tx.send((notice_tx, publish))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } pub fn try_publish( @@ -403,14 +407,16 @@ impl Client { ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !valid_filter(&topic) { return Err(ClientError::Request(request)); } - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -430,14 +436,16 @@ impl Client { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } pub fn try_subscribe_many(&self, topics: T) -> Result @@ -450,12 +458,13 @@ impl Client { /// Sends a MQTT Unsubscribe to the `EventLoop` pub fn unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); - - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT Unsubscribe to the `EventLoop` diff --git a/rumqttc/src/notice.rs b/rumqttc/src/notice.rs index 2b3eb0ebd..3896e5c4d 100644 --- a/rumqttc/src/notice.rs +++ b/rumqttc/src/notice.rs @@ -48,6 +48,12 @@ impl NoticeFuture { pub struct NoticeTx(pub(crate) oneshot::Sender); impl NoticeTx { + pub fn new() -> (Self, NoticeFuture) { + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + + (NoticeTx(notice_tx), NoticeFuture(notice_rx)) + } + pub(crate) fn success(self) { _ = self.0.send(Ok(())); } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 292619672..5edeba93c 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -86,17 +86,17 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + let request = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::Request(request)); + } - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(NoticeTx(notice_tx)); + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); - } - self.request_tx.send_async((notice_tx, publish)).await?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub async fn publish_with_properties( @@ -145,17 +145,17 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + let request = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::TryRequest(request)); + } - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(NoticeTx(notice_tx)); + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); - } - self.request_tx.try_send((notice_tx, publish))?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub fn try_publish_with_properties( @@ -221,17 +221,17 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + let request = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::TryRequest(request)); + } - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(NoticeTx(notice_tx)); + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); - } - self.request_tx.send_async((notice_tx, publish)).await?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub async fn publish_bytes_with_properties( @@ -273,15 +273,16 @@ impl AsyncClient { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } - self.request_tx - .send_async((Some(NoticeTx(notice_tx)), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } pub async fn subscribe_with_properties>( @@ -311,14 +312,16 @@ impl AsyncClient { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((Some(NoticeTx(notice_tx)), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn try_subscribe_with_properties>( @@ -350,16 +353,16 @@ impl AsyncClient { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.request_tx - .send_async((Some(NoticeTx(notice_tx)), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } pub async fn subscribe_many_with_properties( @@ -392,14 +395,16 @@ impl AsyncClient { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((Some(NoticeTx(notice_tx)), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn try_subscribe_many_with_properties( @@ -427,14 +432,13 @@ impl AsyncClient { properties: Option, ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); + let request = Request::Unsubscribe(unsubscribe); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; - let request = Request::Unsubscribe(unsubscribe); - self.request_tx - .send_async((Some(NoticeTx(notice_tx)), request)) - .await?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub async fn unsubscribe_with_properties>( @@ -459,13 +463,13 @@ impl AsyncClient { properties: Option, ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); + let request = Request::Unsubscribe(unsubscribe); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.request_tx.try_send((notice_tx, request))?; - let request = Request::Unsubscribe(unsubscribe); - self.request_tx - .try_send((Some(NoticeTx(notice_tx)), request))?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub fn try_unsubscribe_with_properties>( @@ -562,17 +566,17 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; + let request = Request::Publish(publish); + if !valid_topic(&topic) { + return Err(ClientError::Request(request)); + } - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(NoticeTx(notice_tx)); + let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); + self.client.request_tx.send((notice_tx, request))?; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); - } - self.client.request_tx.send((notice_tx, publish))?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub fn publish_with_properties( @@ -660,14 +664,16 @@ impl Client { let filter = Filter::new(topic, qos); let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !is_filter_valid { return Err(ClientError::Request(request)); } - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn subscribe_with_properties>( @@ -718,14 +724,16 @@ impl Client { let mut topics_iter = topics.into_iter(); let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); let subscribe = Subscribe::new_many(topics_iter, properties); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); let request = Request::Subscribe(subscribe); if !is_valid_filters { return Err(ClientError::Request(request)); } - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn subscribe_many_with_properties( @@ -772,13 +780,13 @@ impl Client { properties: Option, ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); + let request = Request::Unsubscribe(unsubscribe); - let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); - let notice_tx = NoticeTx(notice_tx); + let (notice_tx, future) = NoticeTx::new(); + let notice_tx = Some(notice_tx); + self.client.request_tx.try_send((notice_tx, request))?; - let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((Some(notice_tx), request))?; - Ok(NoticeFuture(notice_rx)) + Ok(future) } pub fn unsubscribe_with_properties>( From 4d99e8c8bf21de7a3798a39d27ca2f9cab8e6127 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 28 May 2024 13:05:32 +0530 Subject: [PATCH 20/20] refactor: not an `Option` --- rumqttc/src/client.rs | 83 +++++++++++++------------- rumqttc/src/eventloop.rs | 15 ++--- rumqttc/src/state.rs | 114 +++++++++++++++++------------------ rumqttc/src/v5/client.rs | 93 ++++++++++++++--------------- rumqttc/src/v5/eventloop.rs | 15 ++--- rumqttc/src/v5/state.rs | 116 ++++++++++++++++++------------------ 6 files changed, 217 insertions(+), 219 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 58688d5d4..359605c38 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -23,14 +23,14 @@ pub enum ClientError { TryRequest(Request), } -impl From, Request)>> for ClientError { - fn from(e: SendError<(Option, Request)>) -> Self { +impl From> for ClientError { + fn from(e: SendError<(NoticeTx, Request)>) -> Self { Self::Request(e.into_inner().1) } } -impl From, Request)>> for ClientError { - fn from(e: TrySendError<(Option, Request)>) -> Self { +impl From> for ClientError { + fn from(e: TrySendError<(NoticeTx, Request)>) -> Self { Self::TryRequest(e.into_inner().1) } } @@ -44,7 +44,7 @@ impl From, Request)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Option, Request)>, + request_tx: Sender<(NoticeTx, Request)>, } impl AsyncClient { @@ -64,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Option, Request)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(NoticeTx, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -89,8 +89,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -117,30 +115,33 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.send_async((None, ack)).await?; + self.request_tx.send_async((notice_tx, ack)).await?; } - Ok(()) + + Ok(future) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.try_send((None, ack))?; + self.request_tx.try_send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT Publish to the `EventLoop` @@ -159,8 +160,6 @@ impl AsyncClient { let request = Request::Publish(publish); let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -180,7 +179,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -200,7 +198,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -221,7 +218,6 @@ impl AsyncClient { let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -242,7 +238,6 @@ impl AsyncClient { let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -258,7 +253,6 @@ impl AsyncClient { let (notice_tx, future) = NoticeTx::new(); // Fulfill instantly for QoS 0 - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -270,24 +264,29 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.request_tx.send_async((None, request)).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.request_tx.try_send((None, request))?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } } @@ -335,7 +334,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Option, Request)>) -> Client { + pub fn from_sender(request_tx: Sender<(NoticeTx, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -362,8 +361,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.client.request_tx.send((notice_tx, request))?; Ok(future) @@ -384,13 +381,15 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.client.request_tx.send((None, ack))?; + self.client.request_tx.send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -413,7 +412,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.send((notice_tx, request))?; Ok(future) @@ -442,7 +440,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.send((notice_tx, request))?; Ok(future) @@ -461,7 +458,6 @@ impl Client { let request = Request::Unsubscribe(unsubscribe); let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.send((notice_tx, request))?; Ok(future) @@ -473,10 +469,13 @@ impl Client { } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.client.request_tx.send((None, request))?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 7fef45ff6..bac532ca6 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -76,11 +76,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Option, Request)>, + requests_rx: Receiver<(NoticeTx, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Option, Request)>, + pub(crate) requests_tx: Sender<(NoticeTx, Request)>, /// Pending packets from last session - pub pending: VecDeque<(Option, Request)>, + pub pending: VecDeque<(NoticeTx, Request)>, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -261,7 +261,8 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(None, Request::PingReq(PingReq))? { + let (tx, _) = NoticeTx::new(); + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq(PingReq))? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -283,10 +284,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Option, Request)>, - rx: &Receiver<(Option, Request)>, + pending: &mut VecDeque<(NoticeTx, Request)>, + rx: &Receiver<(NoticeTx, Request)>, pending_throttle: Duration, - ) -> Result<(Option, Request), ConnectionError> { + ) -> Result<(NoticeTx, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 77b962ee4..739b77345 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -64,17 +64,17 @@ pub struct MqttState { /// Maximum number of allowed inflight pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: HashMap)>, + pub(crate) outgoing_pub: HashMap, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: HashMap>, + pub(crate) outgoing_rel: HashMap, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, - outgoing_sub: HashMap>, - outgoing_unsub: HashMap>, + outgoing_sub: HashMap, + outgoing_unsub: HashMap, /// Last collision due to broker not acking in order - pub(crate) collision: Option<(Publish, Option)>, + pub(crate) collision: Option<(Publish, NoticeTx)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -109,7 +109,7 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Option, Request)> { + pub fn clean(&mut self) -> Vec<(NoticeTx, Request)> { let mut pending = Vec::with_capacity(100); for (_, (publish, tx)) in self.outgoing_pub.drain() { @@ -141,7 +141,7 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, - tx: Option, + tx: NoticeTx, request: Request, ) -> Result, StateError> { let packet = match request { @@ -206,18 +206,12 @@ impl MqttState { debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); } _ => { - if let Some(tx) = tx { - tx.error(NoticeError::V4Subscribe(*reason)) - } - + tx.error(NoticeError::V4Subscribe(*reason)); return Err(StateError::SubFail { reason: *reason }); } } } - - if let Some(tx) = tx { - tx.success() - } + tx.success(); Ok(None) } @@ -230,14 +224,10 @@ impl MqttState { error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); return Err(StateError::Unsolicited(unsuback.pkid)); } - - if let Some(tx) = self - .outgoing_sub + self.outgoing_sub .remove(&unsuback.pkid) .ok_or(StateError::Unsolicited(unsuback.pkid))? - { - tx.success() - } + .success(); Ok(None) } @@ -275,13 +265,11 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } - if let (_, Some(tx)) = self + let (_, tx) = self .outgoing_pub .remove(&puback.pkid) - .ok_or(StateError::Unsolicited(puback.pkid))? - { - tx.success() - } + .ok_or(StateError::Unsolicited(puback.pkid))?; + tx.success(); self.last_puback = puback.pkid; @@ -340,13 +328,10 @@ impl MqttState { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - if let Some(tx) = self - .outgoing_rel + self.outgoing_rel .remove(&pubcomp.pkid) .ok_or(StateError::Unsolicited(pubcomp.pkid))? - { - tx.success() - } + .success(); self.inflight -= 1; let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { @@ -373,7 +358,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -394,8 +379,8 @@ impl MqttState { // packet yet. This error is possible only when broker isn't acking sequentially self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - } else if let Some(tx) = notice_tx { - tx.success() + } else { + notice_tx.success() } debug!( @@ -414,7 +399,7 @@ impl MqttState { fn outgoing_pubrel( &mut self, pubrel: PubRel, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel, notice_tx)?; @@ -477,7 +462,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -501,7 +486,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -527,7 +512,7 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, NoticeTx)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -540,7 +525,7 @@ impl MqttState { fn save_pubrel( &mut self, mut pubrel: PubRel, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets @@ -633,7 +618,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -641,12 +627,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -654,12 +642,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -745,8 +735,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish1, tx).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish2, tx).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); @@ -778,8 +770,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish1, tx); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish2, tx); mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -797,7 +791,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish, None).unwrap().unwrap(); + let (tx, _) = NoticeTx::new(); + let packet = mqtt.outgoing_publish(publish, tx).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -839,7 +834,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); @@ -853,7 +849,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(None, Request::Publish(publish)) + let (tx, _) = NoticeTx::new(); + mqtt.handle_outgoing_packet(tx, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -882,8 +879,9 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> HashMap)> { + fn build_outgoing_pub() -> HashMap { let mut outgoing_pub = HashMap::new(); + let (tx, _) = NoticeTx::new(); outgoing_pub.insert( 2, ( @@ -895,9 +893,10 @@ mod test { pkid: 1, payload: "".into(), }, - None, + tx, ), ); + let (tx, _) = NoticeTx::new(); outgoing_pub.insert( 3, ( @@ -909,9 +908,10 @@ mod test { pkid: 2, payload: "".into(), }, - None, + tx, ), ); + let (tx, _) = NoticeTx::new(); outgoing_pub.insert( 4, ( @@ -923,10 +923,10 @@ mod test { pkid: 3, payload: "".into(), }, - None, + tx, ), ); - + let (tx, _) = NoticeTx::new(); outgoing_pub.insert( 7, ( @@ -938,7 +938,7 @@ mod test { pkid: 6, payload: "".into(), }, - None, + tx, ), ); diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 5edeba93c..bfb8ebbc1 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -25,14 +25,14 @@ pub enum ClientError { TryRequest(Request), } -impl From, Request)>> for ClientError { - fn from(e: SendError<(Option, Request)>) -> Self { +impl From> for ClientError { + fn from(e: SendError<(NoticeTx, Request)>) -> Self { Self::Request(e.into_inner().1) } } -impl From, Request)>> for ClientError { - fn from(e: TrySendError<(Option, Request)>) -> Self { +impl From> for ClientError { + fn from(e: TrySendError<(NoticeTx, Request)>) -> Self { Self::TryRequest(e.into_inner().1) } } @@ -46,7 +46,7 @@ impl From, Request)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Option, Request)>, + request_tx: Sender<(NoticeTx, Request)>, } impl AsyncClient { @@ -66,7 +66,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Option, Request)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(NoticeTx, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -92,8 +92,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -151,8 +149,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -188,22 +184,27 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.send_async((None, ack)).await?; + self.request_tx.send_async((notice_tx, ack)).await?; } - Ok(()) + + Ok(future) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.try_send((None, ack))?; + self.request_tx.try_send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT Publish to the `EventLoop` @@ -227,8 +228,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -279,7 +278,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -318,7 +316,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -359,7 +356,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.send_async((notice_tx, request)).await?; Ok(future) @@ -401,7 +397,6 @@ impl AsyncClient { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -435,7 +430,6 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -466,7 +460,6 @@ impl AsyncClient { let request = Request::Unsubscribe(unsubscribe); let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -485,17 +478,23 @@ impl AsyncClient { } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&self) -> Result { let request = Request::Disconnect; - self.request_tx.send_async((None, request)).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result { let request = Request::Disconnect; - self.request_tx.try_send((None, request))?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } } @@ -544,7 +543,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Option, Request)>) -> Client { + pub fn from_sender(request_tx: Sender<(NoticeTx, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -572,8 +571,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - // Fulfill instantly for QoS 0 - let notice_tx = (qos == QoS::AtMostOnce).then_some(notice_tx); self.client.request_tx.send((notice_tx, request))?; Ok(future) @@ -639,19 +636,20 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.client.request_tx.send((None, ack))?; + self.client.request_tx.send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + pub fn try_ack(&self, publish: &Publish) -> Result { + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -670,7 +668,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -730,7 +727,6 @@ impl Client { } let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -783,7 +779,6 @@ impl Client { let request = Request::Unsubscribe(unsubscribe); let (notice_tx, future) = NoticeTx::new(); - let notice_tx = Some(notice_tx); self.client.request_tx.try_send((notice_tx, request))?; Ok(future) @@ -816,16 +811,18 @@ impl Client { } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { let request = Request::Disconnect; - self.client.request_tx.send((None, request))?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + pub fn try_disconnect(&self) -> Result { + self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 66efa84e0..3ad1d4874 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -74,11 +74,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Option, Request)>, + requests_rx: Receiver<(NoticeTx, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Option, Request)>, + pub(crate) requests_tx: Sender<(NoticeTx, Request)>, /// Pending packets from last session - pub(crate) pending: VecDeque<(Option, Request)>, + pub(crate) pending: VecDeque<(NoticeTx, Request)>, /// Network connection to the broker network: Option, /// Keep alive time @@ -246,7 +246,8 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(None, Request::PingReq)? { + let (tx, _) = NoticeTx::new(); + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq)? { network.write(outgoing).await?; } network.flush().await?; @@ -256,10 +257,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Option, Request)>, - rx: &Receiver<(Option, Request)>, + pending: &mut VecDeque<(NoticeTx, Request)>, + rx: &Receiver<(NoticeTx, Request)>, pending_throttle: Duration, - ) -> Result<(Option, Request), ConnectionError> { + ) -> Result<(NoticeTx, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index ccfdaa1a9..d0d196909 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -105,17 +105,17 @@ pub struct MqttState { /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: HashMap)>, + pub(crate) outgoing_pub: HashMap, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: HashMap>, + pub(crate) outgoing_rel: HashMap, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, - outgoing_sub: HashMap>, - outgoing_unsub: HashMap>, + outgoing_sub: HashMap, + outgoing_unsub: HashMap, /// Last collision due to broker not acking in order - pub(crate) collision: Option<(Publish, Option)>, + pub(crate) collision: Option<(Publish, NoticeTx)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -161,7 +161,7 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Option, Request)> { + pub fn clean(&mut self) -> Vec<(NoticeTx, Request)> { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes for (_, (publish, tx)) in self.outgoing_pub.drain() { @@ -193,7 +193,7 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, - tx: Option, + tx: NoticeTx, request: Request, ) -> Result, StateError> { let packet = match request { @@ -270,18 +270,12 @@ impl MqttState { debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); } _ => { - if let Some(tx) = tx { - tx.error(NoticeError::V5Subscribe(*reason)) - } - + tx.error(NoticeError::V5Subscribe(*reason)); return Err(StateError::SubFail { reason: *reason }); } } } - - if let Some(tx) = tx { - tx.success() - } + tx.success(); Ok(None) } @@ -302,17 +296,11 @@ impl MqttState { for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { - if let Some(tx) = tx { - tx.error(NoticeError::V5Unsubscribe(*reason)) - } - + tx.error(NoticeError::V5Unsubscribe(*reason)); return Err(StateError::UnsubFail { reason: *reason }); } } - - if let Some(tx) = tx { - tx.success() - } + tx.success(); Ok(None) } @@ -411,13 +399,11 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } - if let (_, Some(tx)) = self + let (_, tx) = self .outgoing_pub .remove(&puback.pkid) - .ok_or(StateError::Unsolicited(puback.pkid))? - { - tx.success() - } + .ok_or(StateError::Unsolicited(puback.pkid))?; + tx.success(); self.inflight -= 1; @@ -495,13 +481,10 @@ impl MqttState { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - if let Some(tx) = self - .outgoing_rel + self.outgoing_rel .remove(&pubcomp.pkid) .ok_or(StateError::Unsolicited(pubcomp.pkid))? - { - tx.success() - } + .success(); self.inflight -= 1; let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { @@ -527,7 +510,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -548,8 +531,8 @@ impl MqttState { // packet yet. This error is possible only when broker isn't acking sequentially self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - } else if let Some(tx) = notice_tx { - tx.success() + } else { + notice_tx.success() } debug!( @@ -560,7 +543,6 @@ impl MqttState { ); let pkid = publish.pkid; - if let Some(props) = &publish.properties { if let Some(alias) = props.topic_alias { if alias > self.broker_topic_alias_max { @@ -583,7 +565,7 @@ impl MqttState { fn outgoing_pubrel( &mut self, pubrel: PubRel, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel, notice_tx)?; @@ -646,7 +628,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -670,7 +652,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -698,7 +680,7 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, NoticeTx)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -711,7 +693,7 @@ impl MqttState { fn save_pubrel( &mut self, mut pubrel: PubRel, - notice_tx: Option, + notice_tx: NoticeTx, ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets @@ -749,6 +731,8 @@ impl MqttState { #[cfg(test)] mod test { + use crate::notice::NoticeTx; + use super::mqttbytes::v5::*; use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; @@ -802,7 +786,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -810,12 +795,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -823,12 +810,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -840,17 +829,20 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); @@ -860,7 +852,8 @@ mod test { assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -944,8 +937,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish1, tx).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish2, tx).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); @@ -979,8 +974,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish1, tx); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish2, tx); mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -998,7 +995,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish, None).unwrap().unwrap() { + let (tx, _) = NoticeTx::new(); + match mqtt.outgoing_publish(publish, tx).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -1038,7 +1036,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) @@ -1053,7 +1052,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(None, Request::Publish(publish)) + let (tx, _) = NoticeTx::new(); + mqtt.handle_outgoing_packet(tx, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap();