diff --git a/Cargo.lock b/Cargo.lock index 59f8430d1..1b7f50657 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2559,6 +2559,7 @@ dependencies = [ "futures-core", "futures-sink", "pin-project-lite", + "slab", "tokio", "tracing", ] diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 34f5438cd..c36fc45bd 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -65,6 +65,7 @@ To update your code simply remove `Key::ECC()` or `Key::RSA()` from the initiali `rusttls-pemfile` to `2.0.0`, `async-tungstenite` to `0.24.0`, `ws_stream_tungstenite` to `0.12.0` and `http` to `1.0.0`. This is a breaking change as types from some of these crates are part of the public API. +- `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return a `PkidPromise` which resolves into the identifier value chosen by the `EventLoop` when handling the packet. ### Deprecated diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 2e2bae14e..f0585aae8 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -24,7 +24,7 @@ proxy = ["dep:async-http-proxy"] [dependencies] futures-util = { version = "0.3", default-features = false, features = ["std", "sink"] } -tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } +tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time", "sync"] } tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" log = "0.4" @@ -58,6 +58,7 @@ matches = "0.1" pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +tokio-util = { version = "0.7", features = ["time"] } [[example]] name = "tls" diff --git a/rumqttc/examples/ack_notif.rs b/rumqttc/examples/ack_notif.rs new file mode 100644 index 000000000..bb98599a4 --- /dev/null +++ b/rumqttc/examples/ack_notif.rs @@ -0,0 +1,65 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + + // Publish and spawn wait for notification + let mut set = JoinSet::new(); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 1024]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + while let Some(res) = set.join_next().await { + println!("Acknoledged = {:?}", res?); + } + + Ok(()) +} diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs new file mode 100644 index 000000000..0fefd093e --- /dev/null +++ b/rumqttc/examples/pkid_promise.rs @@ -0,0 +1,70 @@ +use futures_util::stream::StreamExt; +use tokio::{ + select, + task::{self, JoinSet}, +}; +use tokio_util::time::DelayQueue; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + }); + + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + return Ok(()); + } + } + } +} + +async fn requests(client: AsyncClient) { + let mut joins = JoinSet::new(); + joins.spawn( + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .wait_async(), + ); + + let mut queue = DelayQueue::new(); + for i in 1..=10 { + queue.insert(i as usize, Duration::from_secs(i)); + } + + loop { + select! { + Some(i) = queue.next() => { + joins.spawn( + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) + .await + .unwrap().wait_async(), + ); + } + Some(Ok(Ok(pkid))) = joins.join_next() => { + println!("Pkid: {:?}", pkid); + } + else => break, + } + } +} diff --git a/rumqttc/examples/pkid_promise_v5.rs b/rumqttc/examples/pkid_promise_v5.rs new file mode 100644 index 000000000..e4bd3d31c --- /dev/null +++ b/rumqttc/examples/pkid_promise_v5.rs @@ -0,0 +1,70 @@ +use futures_util::stream::StreamExt; +use tokio::{ + select, + task::{self, JoinSet}, +}; +use tokio_util::time::DelayQueue; + +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + }); + + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + return Ok(()); + } + } + } +} + +async fn requests(client: AsyncClient) { + let mut joins = JoinSet::new(); + joins.spawn( + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .wait_async(), + ); + + let mut queue = DelayQueue::new(); + for i in 1..=10 { + queue.insert(i as usize, Duration::from_secs(i)); + } + + loop { + select! { + Some(i) = queue.next() => { + joins.spawn( + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) + .await + .unwrap().wait_async(), + ); + } + Some(Ok(Ok(pkid))) = joins.join_next() => { + println!("Pkid: {:?}", pkid); + } + else => break, + } + } +} diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 15cd5f5ad..c69cfcf52 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,7 +3,10 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; -use crate::{valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::{ + valid_topic, ConnectionError, Event, EventLoop, MqttOptions, NoticeFuture, + NoticeTx, Request, +}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -20,15 +23,15 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From> for ClientError { + fn from(e: SendError<(NoticeTx, Request)>) -> Self { + Self::Request(e.into_inner().1) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From> for ClientError { + fn from(e: TrySendError<(NoticeTx, Request)>) -> Self { + Self::TryRequest(e.into_inner().1) } } @@ -41,7 +44,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(NoticeTx, Request)>, } impl AsyncClient { @@ -61,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(NoticeTx, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -72,7 +75,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -80,12 +83,15 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx.send_async(publish).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -95,7 +101,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -103,31 +109,39 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(publish)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((notice_tx, ack)).await?; } - Ok(()) + + Ok(future) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT Publish to the `EventLoop` @@ -137,101 +151,132 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx.send_async(publish).await?; - Ok(()) + let request = Request::Publish(publish); + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Sends a MQTT Subscribe to the `EventLoop` - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::Request(request)); + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + let subscribe = Subscribe::new(topic, qos); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, subscribe.into())).await?; + + Ok(future) } /// Attempts to send a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let topic = topic.into(); - let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::TryRequest(request)); + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + let subscribe = Subscribe::new(topic, qos); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + self.request_tx.send_async((notice_tx, subscribe.into())).await?; + + Ok(future) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + self.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + // Fulfill instantly for QoS 0 + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } } @@ -279,7 +324,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(NoticeTx, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -292,7 +337,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, @@ -300,12 +345,15 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.client.request_tx.send(publish)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } pub fn try_publish( @@ -314,23 +362,24 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -340,40 +389,49 @@ impl Client { } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { let topic = topic.into(); let subscribe = Subscribe::new(&topic, qos); - let request = Request::Subscribe(subscribe); - if !valid_filter(&topic) { - return Err(ClientError::Request(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, subscribe.into()))?; + + Ok(future) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, subscribe.into()))?; + + Ok(future) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -381,24 +439,29 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + pub fn try_unsubscribe>(&self, topic: S) -> Result { + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { let request = Request::Disconnect(Disconnect); - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index d31690d99..bac532ca6 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,3 +1,4 @@ +use crate::notice::NoticeTx; use crate::{framed::Network, Transport}; use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; @@ -75,11 +76,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(NoticeTx, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(NoticeTx, Request)>, /// Pending packets from last session - pub pending: VecDeque, + pub pending: VecDeque<(NoticeTx, Request)>, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -132,7 +133,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(_, request)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -241,8 +242,8 @@ impl EventLoop { &self.requests_rx, self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((tx, request)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -260,7 +261,8 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + let (tx, _) = NoticeTx::new(); + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq(PingReq))? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -282,10 +284,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(NoticeTx, Request)>, + rx: &Receiver<(NoticeTx, Request)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(NoticeTx, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 0c2d67fef..0803fbd8e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -109,6 +109,7 @@ mod client; mod eventloop; mod framed; pub mod mqttbytes; +mod notice; mod state; pub mod v5; @@ -140,6 +141,8 @@ pub use client::{ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; +use notice::NoticeTx; +pub use notice::{NoticeError, NoticeFuture}; #[cfg(feature = "use-rustls")] use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; @@ -188,7 +191,7 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Request { Publish(Publish), PubAck(PubAck), @@ -204,6 +207,25 @@ pub enum Request { Disconnect(Disconnect), } +impl Request { + fn size(&self) -> usize { + match &self { + Request::Publish(publish) => publish.size(), + Request::PubAck(puback) => puback.size(), + Request::PubRec(pubrec) => pubrec.size(), + Request::PubComp(pubcomp) => pubcomp.size(), + Request::PubRel(pubrel) => pubrel.size(), + Request::PingReq(pingreq) => pingreq.size(), + Request::PingResp(pingresp) => pingresp.size(), + Request::Subscribe(subscribe) => subscribe.size(), + Request::SubAck(suback) => suback.size(), + Request::Unsubscribe(unsubscribe) => unsubscribe.size(), + Request::UnsubAck(unsuback) => unsuback.size(), + Request::Disconnect(disconn) => disconn.size(), + } + } +} + impl From for Request { fn from(publish: Publish) -> Request { Request::Publish(publish) diff --git a/rumqttc/src/mqttbytes/v4/publish.rs b/rumqttc/src/mqttbytes/v4/publish.rs index 37924fa9d..36530b957 100644 --- a/rumqttc/src/mqttbytes/v4/publish.rs +++ b/rumqttc/src/mqttbytes/v4/publish.rs @@ -2,7 +2,7 @@ use super::*; use bytes::{Buf, Bytes}; /// Publish packet -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, Eq, PartialEq)] pub struct Publish { pub dup: bool, pub qos: QoS, diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 42ddb57b1..836311729 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -2,7 +2,7 @@ use super::*; use bytes::{Buf, Bytes}; /// Subscription packet -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, Eq, PartialEq)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, @@ -93,6 +93,14 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub(crate) fn has_valid_filters(&self) -> bool { + if self.filters.is_empty() { + return false; + } + + self.filters.iter().all(SubscribeFilter::is_valid) + } } /// Subscription filter @@ -119,6 +127,10 @@ impl SubscribeFilter { write_mqtt_string(buffer, self.path.as_str()); buffer.put_u8(options); } + + fn is_valid(&self) -> bool { + valid_filter(&self.path) + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/mqttbytes/v4/unsubscribe.rs b/rumqttc/src/mqttbytes/v4/unsubscribe.rs index da34fbcc6..8263752e3 100644 --- a/rumqttc/src/mqttbytes/v4/unsubscribe.rs +++ b/rumqttc/src/mqttbytes/v4/unsubscribe.rs @@ -2,7 +2,7 @@ use super::*; use bytes::{Buf, Bytes}; /// Unsubscribe packet -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Unsubscribe { pub pkid: u16, pub topics: Vec, diff --git a/rumqttc/src/notice.rs b/rumqttc/src/notice.rs new file mode 100644 index 000000000..3896e5c4d --- /dev/null +++ b/rumqttc/src/notice.rs @@ -0,0 +1,64 @@ +use tokio::sync::oneshot; + +use crate::{ + v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason}, + SubscribeReasonCode, +}; + +#[derive(Debug, thiserror::Error)] +pub enum NoticeError { + #[error("Eventloop dropped Sender")] + Recv, + #[error(" v4 Subscription Failure Reason Code: {0:?}")] + V4Subscribe(SubscribeReasonCode), + #[error(" v5 Subscription Failure Reason Code: {0:?}")] + V5Subscribe(V5SubscribeReasonCode), + #[error(" v5 Unsubscription Failure Reason: {0:?}")] + V5Unsubscribe(UnsubAckReason), +} + +impl From for NoticeError { + fn from(_: oneshot::error::RecvError) -> Self { + Self::Recv + } +} + +type NoticeResult = Result<(), NoticeError>; + +/// A token through which the user is notified of the publish/subscribe/unsubscribe packet being acked by the broker. +#[derive(Debug)] +pub struct NoticeFuture(pub(crate) oneshot::Receiver); + +impl NoticeFuture { + /// Wait for broker to acknowledge by blocking the current thread + /// + /// # Panics + /// Panics if called in an async context + pub fn wait(self) -> NoticeResult { + self.0.blocking_recv()? + } + + /// Await the packet acknowledgement from broker, without blocking the current thread + pub async fn wait_async(self) -> NoticeResult { + self.0.await? + } +} + +#[derive(Debug)] +pub struct NoticeTx(pub(crate) oneshot::Sender); + +impl NoticeTx { + pub fn new() -> (Self, NoticeFuture) { + let (notice_tx, notice_rx) = tokio::sync::oneshot::channel(); + + (NoticeTx(notice_tx), NoticeFuture(notice_rx)) + } + + pub(crate) fn success(self) { + _ = self.0.send(Ok(())); + } + + pub(crate) fn error(self, e: NoticeError) { + _ = self.0.send(Err(e)); + } +} diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index f7cb34841..739b77345 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,9 +1,10 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::notice::NoticeTx; +use crate::{Event, Incoming, NoticeError, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; /// Errors during state handling @@ -32,6 +33,8 @@ pub enum StateError { Deserialization(#[from] mqttbytes::Error), #[error("Connection closed by peer abruptly")] ConnectionAborted, + #[error("Subscribe failed with reason '{reason:?}' ")] + SubFail { reason: SubscribeReasonCode }, } /// State of the mqtt connection. @@ -40,7 +43,7 @@ pub enum StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -61,13 +64,17 @@ pub struct MqttState { /// Maximum number of allowed inflight pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: Vec>, + pub(crate) outgoing_pub: HashMap, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: FixedBitSet, + pub(crate) outgoing_rel: HashMap, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, + + outgoing_sub: HashMap, + outgoing_unsub: HashMap, + /// Last collision due to broker not acking in order - pub collision: Option, + pub(crate) collision: Option<(Publish, NoticeTx)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -88,10 +95,12 @@ impl MqttState { last_puback: 0, inflight: 0, max_inflight, + outgoing_pub: HashMap::new(), + outgoing_rel: HashMap::new(), // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: vec![None; max_inflight as usize + 1], - outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1), incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1), + outgoing_sub: HashMap::new(), + outgoing_unsub: HashMap::new(), collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), @@ -100,23 +109,18 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(NoticeTx, Request)> { let mut pending = Vec::with_capacity(100); - let (first_half, second_half) = self - .outgoing_pub - .split_at_mut(self.last_puback as usize + 1); - for publish in second_half.iter_mut().chain(first_half) { - if let Some(publish) = publish.take() { - let request = Request::Publish(publish); - pending.push(request); - } + for (_, (publish, tx)) in self.outgoing_pub.drain() { + let request = Request::Publish(publish); + pending.push((tx, request)); } // remove and collect pending releases - for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16)); - pending.push(request); + for (pkid, tx) in self.outgoing_rel.drain() { + let request = Request::PubRel(PubRel::new(pkid)); + pending.push((tx, request)); } self.outgoing_rel.clear(); @@ -137,13 +141,14 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, + tx: NoticeTx, request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, tx)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq(_) => self.outgoing_ping()?, Request::Disconnect(_) => self.outgoing_disconnect()?, Request::PubAck(puback) => self.outgoing_puback(puback)?, @@ -168,8 +173,8 @@ impl MqttState { let outgoing = match &packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, - Incoming::SubAck(_suback) => self.handle_incoming_suback()?, - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, @@ -184,11 +189,46 @@ impl MqttState { Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: &SubAck) -> Result, StateError> { + if suback.pkid > self.max_inflight { + error!("Unsolicited suback packet: {:?}", suback.pkid); + return Err(StateError::Unsolicited(suback.pkid)); + } + + let tx = self + .outgoing_sub + .remove(&suback.pkid) + .ok_or(StateError::Unsolicited(suback.pkid))?; + + for reason in suback.return_codes.iter() { + match reason { + SubscribeReasonCode::Success(qos) => { + debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); + } + _ => { + tx.error(NoticeError::V4Subscribe(*reason)); + return Err(StateError::SubFail { reason: *reason }); + } + } + } + tx.success(); + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &UnsubAck, + ) -> Result, StateError> { + if unsuback.pkid > self.max_inflight { + error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); + return Err(StateError::Unsolicited(unsuback.pkid)); + } + self.outgoing_sub + .remove(&unsuback.pkid) + .ok_or(StateError::Unsolicited(unsuback.pkid))? + .success(); + Ok(None) } @@ -220,21 +260,23 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self + if puback.pkid > self.max_inflight { + error!("Unsolicited puback packet: {:?}", puback.pkid); + return Err(StateError::Unsolicited(puback.pkid)); + } + + let (_, tx) = self .outgoing_pub - .get_mut(puback.pkid as usize) + .remove(&puback.pkid) .ok_or(StateError::Unsolicited(puback.pkid))?; + tx.success(); self.last_puback = puback.pkid; - if publish.take().is_none() { - error!("Unsolicited puback packet: {:?}", puback.pkid); - return Err(StateError::Unsolicited(puback.pkid)); - } - self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|publish| { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(publish.pkid, (publish.clone(), tx)); self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); @@ -248,18 +290,18 @@ impl MqttState { } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + if pubrec.pkid > self.max_inflight { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } + let (_, tx) = self + .outgoing_pub + .remove(&pubrec.pkid) + .ok_or(StateError::Unsolicited(pubrec.pkid))?; + // NOTE: Inflight - 1 for qos2 in comp - self.outgoing_rel.insert(pubrec.pkid as usize); + self.outgoing_rel.insert(pubrec.pkid, tx); let pubrel = PubRel { pkid: pubrec.pkid }; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -282,14 +324,19 @@ impl MqttState { } fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + if pubcomp.pkid > self.max_inflight { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } + self.outgoing_rel + .remove(&pubcomp.pkid) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + .success(); - self.outgoing_rel.set(pubcomp.pkid as usize, false); self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|publish| { + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(pubcomp.pkid, (publish.clone(), tx)); let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -308,21 +355,21 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + notice_tx: NoticeTx, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; - if self - .outgoing_pub - .get(publish.pkid as usize) - .ok_or(StateError::Unsolicited(publish.pkid))? - .is_some() - { + + if self.outgoing_pub.get(&publish.pkid).is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, notice_tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -330,9 +377,11 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - }; + } else { + notice_tx.success() + } debug!( "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", @@ -347,8 +396,12 @@ impl MqttState { Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { - let pubrel = self.save_pubrel(pubrel)?; + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + notice_tx: NoticeTx, + ) -> Result, StateError> { + let pubrel = self.save_pubrel(pubrel, notice_tx)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); @@ -409,6 +462,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + notice_tx: NoticeTx, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -422,6 +476,7 @@ impl MqttState { subscription.filters, subscription.pkid ); + self.outgoing_sub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); @@ -431,6 +486,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + notice_tx: NoticeTx, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -440,6 +496,7 @@ impl MqttState { unsub.topics, unsub.pkid ); + self.outgoing_unsub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); @@ -455,8 +512,8 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, NoticeTx)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -465,7 +522,11 @@ impl MqttState { None } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + fn save_pubrel( + &mut self, + mut pubrel: PubRel, + notice_tx: NoticeTx, + ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { @@ -475,7 +536,7 @@ impl MqttState { _ => pubrel, }; - self.outgoing_rel.insert(pubrel.pkid as usize); + self.outgoing_rel.insert(pubrel.pkid, notice_tx); self.inflight += 1; Ok(pubrel) } @@ -502,9 +563,11 @@ impl MqttState { #[cfg(test)] mod test { + use std::collections::HashMap; + use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; - use crate::mqttbytes::*; + use crate::{mqttbytes::*, NoticeTx}; use crate::{Event, Incoming, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { @@ -555,7 +618,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -563,12 +627,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -576,12 +642,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -667,8 +735,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish1, tx).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish2, tx).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); @@ -677,8 +747,8 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); + assert!(mqtt.outgoing_pub.get(&1).is_none()); + assert!(mqtt.outgoing_pub.get(&2).is_none()); } #[test] @@ -700,18 +770,20 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish1, tx); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish2, tx); mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); + let (backup, _) = mqtt.outgoing_pub.get(&1).unwrap(); + assert_eq!(backup.pkid, 1); - // check if the qos2 element's release pkid is 2 - assert!(mqtt.outgoing_rel.contains(2)); + // check if the qos2 element's release pkik has been set + assert!(mqtt.outgoing_rel.get(&2).is_some()); } #[test] @@ -719,7 +791,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); + let (tx, _) = NoticeTx::new(); + let packet = mqtt.outgoing_publish(publish, tx).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -761,7 +834,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); @@ -775,7 +849,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + let (tx, _) = NoticeTx::new(); + mqtt.handle_outgoing_packet(tx, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -804,51 +879,77 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> Vec> { - vec![ - None, - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 1, - payload: "".into(), - }), - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 2, - payload: "".into(), - }), - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 3, - payload: "".into(), - }), - None, - None, - Some(Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic: "test".to_string(), - pkid: 6, - payload: "".into(), - }), - ] + fn build_outgoing_pub() -> HashMap { + let mut outgoing_pub = HashMap::new(); + let (tx, _) = NoticeTx::new(); + outgoing_pub.insert( + 2, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 1, + payload: "".into(), + }, + tx, + ), + ); + let (tx, _) = NoticeTx::new(); + outgoing_pub.insert( + 3, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 2, + payload: "".into(), + }, + tx, + ), + ); + let (tx, _) = NoticeTx::new(); + outgoing_pub.insert( + 4, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 3, + payload: "".into(), + }, + tx, + ), + ); + let (tx, _) = NoticeTx::new(); + outgoing_pub.insert( + 7, + ( + Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test".to_string(), + pkid: 6, + payload: "".into(), + }, + tx, + ), + ); + + outgoing_pub } mqtt.outgoing_pub = build_outgoing_pub(); mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; - for (req, idx) in requests.iter().zip(res) { + for ((_, req), idx) in requests.iter().zip(res) { if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { @@ -860,7 +961,7 @@ mod test { mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; - for (req, idx) in requests.iter().zip(res) { + for ((_, req), idx) in requests.iter().zip(res) { if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { @@ -872,7 +973,7 @@ mod test { mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; - for (req, idx) in requests.iter().zip(res) { + for ((_, req), idx) in requests.iter().zip(res) { if let Request::Publish(publish) = req { assert_eq!(publish.pkid, idx); } else { diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..01a412df0 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -6,9 +6,9 @@ use super::mqttbytes::v5::{ Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; -use super::mqttbytes::{valid_filter, QoS}; +use super::mqttbytes::QoS; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::valid_topic; +use crate::{valid_topic, NoticeFuture, NoticeTx}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -25,15 +25,15 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From> for ClientError { + fn from(e: SendError<(NoticeTx, Request)>) -> Self { + Self::Request(e.into_inner().1) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From> for ClientError { + fn from(e: TrySendError<(NoticeTx, Request)>) -> Self { + Self::TryRequest(e.into_inner().1) } } @@ -46,7 +46,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(NoticeTx, Request)>, } impl AsyncClient { @@ -66,7 +66,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(NoticeTx, Request)>) -> AsyncClient { AsyncClient { request_tx } } @@ -78,7 +78,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -86,12 +86,15 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx.send_async(publish).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } pub async fn publish_with_properties( @@ -101,7 +104,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -116,7 +119,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -132,7 +135,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -140,12 +143,15 @@ impl AsyncClient { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(publish)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn try_publish_with_properties( @@ -155,7 +161,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -169,7 +175,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -178,22 +184,27 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((notice_tx, ack)).await?; } - Ok(()) + + Ok(future) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT Publish to the `EventLoop` @@ -204,19 +215,22 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.send_async(publish).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } pub async fn publish_bytes_with_properties( @@ -226,7 +240,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -240,7 +254,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -254,16 +268,17 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request: Request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::Request(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, subscribe.into())).await?; + + Ok(future) } pub async fn subscribe_with_properties>( @@ -271,11 +286,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)).await } - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None).await } @@ -285,16 +304,17 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::TryRequest(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } pub fn try_subscribe_with_properties>( @@ -302,11 +322,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_subscribe(topic, qos, Some(properties)) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_try_subscribe(topic, qos, None) } @@ -315,34 +339,33 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.request_tx.send_async(request).await?; - Ok(()) + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, subscribe.into())).await?; + + Ok(future) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -354,33 +377,33 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::TryRequest(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -392,22 +415,28 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result { self.handle_unsubscribe(topic, None).await } @@ -416,37 +445,46 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.handle_try_unsubscribe(topic, None) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&self) -> Result { let request = Request::Disconnect; - self.request_tx.send_async(request).await?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.send_async((notice_tx, request)).await?; + + Ok(future) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result { let request = Request::Disconnect; - self.request_tx.try_send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.request_tx.try_send((notice_tx, request))?; + + Ok(future) } } @@ -495,7 +533,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(NoticeTx, Request)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -509,7 +547,7 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -517,12 +555,15 @@ impl Client { let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.client.request_tx.send(publish)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } pub fn publish_with_properties( @@ -532,7 +573,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -546,7 +587,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -561,7 +602,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -576,7 +617,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -585,19 +626,20 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn ack(&self, publish: &Publish) -> Result { let ack = get_ack_req(publish); + let (notice_tx, future) = NoticeTx::new(); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((notice_tx, ack))?; } - Ok(()) + + Ok(future) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + pub fn try_ack(&self, publish: &Publish) -> Result { + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -606,16 +648,17 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let filter = Filter::new(topic, qos); - let is_filter_valid = valid_filter(&filter.path); let subscribe = Subscribe::new(filter, properties); - let request = Request::Subscribe(subscribe); - if !is_filter_valid { - return Err(ClientError::Request(request)); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } pub fn subscribe_with_properties>( @@ -623,11 +666,15 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)) } - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None) } @@ -637,12 +684,16 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_subscribe_with_properties(topic, qos, properties) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.client.try_subscribe(topic, qos) } @@ -651,33 +702,33 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { - let mut topics_iter = topics.into_iter(); - let is_valid_filters = topics_iter.all(|filter| valid_filter(&filter.path)); - let subscribe = Subscribe::new_many(topics_iter, properties); - let request = Request::Subscribe(subscribe); - if !is_valid_filters { - return Err(ClientError::Request(request)); + let subscribe = Subscribe::new_many(topics, properties); + if !subscribe.has_valid_filters() { + return Err(ClientError::Request(subscribe.into())); } - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.try_send((notice_tx, subscribe.into()))?; + + Ok(future) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -688,7 +739,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { @@ -696,7 +747,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -708,22 +759,25 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.try_send((notice_tx, request))?; + + Ok(future) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None) } @@ -732,26 +786,28 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { let request = Request::Disconnect; - self.client.request_tx.send(request)?; - Ok(()) + + let (notice_tx, future) = NoticeTx::new(); + self.client.request_tx.send((notice_tx, request))?; + + Ok(future) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + pub fn try_disconnect(&self) -> Result { + self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index cd0568ada..3ad1d4874 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,6 +3,7 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; +use crate::notice::NoticeTx; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -73,11 +74,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(NoticeTx, Request)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(NoticeTx, Request)>, /// Pending packets from last session - pub pending: VecDeque, + pub(crate) pending: VecDeque<(NoticeTx, Request)>, /// Network connection to the broker network: Option, /// Keep alive time @@ -128,7 +129,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(_, request)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -223,8 +224,8 @@ impl EventLoop { &self.requests_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((tx, request)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? { network.write(outgoing).await?; } network.flush().await?; @@ -245,7 +246,8 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + let (tx, _) = NoticeTx::new(); + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq)? { network.write(outgoing).await?; } network.flush().await?; @@ -255,10 +257,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(NoticeTx, Request)>, + rx: &Receiver<(NoticeTx, Request)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(NoticeTx, Request), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2518a93f1..e393c9536 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -33,7 +33,7 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Request { Publish(Publish), PubAck(PubAck), @@ -49,6 +49,12 @@ pub enum Request { Disconnect, } +impl From for Request { + fn from(subscribe: Subscribe) -> Self { + Self::Subscribe(subscribe) + } +} + #[cfg(feature = "websocket")] type RequestModifierFn = Arc< dyn Fn(http::Request<()>) -> Pin> + Send>> diff --git a/rumqttc/src/v5/mqttbytes/v5/publish.rs b/rumqttc/src/v5/mqttbytes/v5/publish.rs index 74fbee225..6699e4b20 100644 --- a/rumqttc/src/v5/mqttbytes/v5/publish.rs +++ b/rumqttc/src/v5/mqttbytes/v5/publish.rs @@ -2,7 +2,7 @@ use super::*; use bytes::{Buf, Bytes}; /// Publish packet -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Publish { pub dup: bool, pub qos: QoS, diff --git a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs index 4167cd671..72da0a764 100644 --- a/rumqttc/src/v5/mqttbytes/v5/subscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/subscribe.rs @@ -2,7 +2,7 @@ use super::*; use bytes::{Buf, Bytes}; /// Subscription packet -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Subscribe { pub pkid: u16, pub filters: Vec, @@ -95,6 +95,14 @@ impl Subscribe { Ok(1 + remaining_len_bytes + remaining_len) } + + pub(crate) fn has_valid_filters(&self) -> bool { + if self.filters.is_empty() { + return false; + } + + self.filters.iter().all(Filter::is_valid) + } } /// Subscription filter @@ -177,6 +185,10 @@ impl Filter { write_mqtt_string(buffer, self.path.as_str()); buffer.put_u8(options); } + + fn is_valid(&self) -> bool { + valid_filter(&self.path) + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs index 2b671ce39..5e357b96e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs +++ b/rumqttc/src/v5/mqttbytes/v5/unsubscribe.rs @@ -1,8 +1,9 @@ -use super::*; use bytes::{Buf, Bytes}; +use super::*; + /// Unsubscribe packet -#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Unsubscribe { pub pkid: u16, pub filters: Vec, diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 6f37a4719..d0d196909 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,3 +1,5 @@ +use crate::{NoticeError, NoticeTx}; + use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, @@ -86,7 +88,7 @@ impl From for StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -103,13 +105,17 @@ pub struct MqttState { /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: Vec>, + pub(crate) outgoing_pub: HashMap, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: FixedBitSet, + pub(crate) outgoing_rel: HashMap, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, + + outgoing_sub: HashMap, + outgoing_unsub: HashMap, + /// Last collision due to broker not acking in order - pub collision: Option, + pub(crate) collision: Option<(Publish, NoticeTx)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -136,10 +142,12 @@ impl MqttState { last_outgoing: Instant::now(), last_pkid: 0, inflight: 0, + outgoing_pub: HashMap::new(), + outgoing_rel: HashMap::new(), // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: vec![None; max_inflight as usize + 1], - outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1), incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1), + outgoing_sub: HashMap::new(), + outgoing_unsub: HashMap::new(), collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), @@ -153,20 +161,18 @@ impl MqttState { } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(NoticeTx, Request)> { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes - for publish in self.outgoing_pub.iter_mut() { - if let Some(publish) = publish.take() { - let request = Request::Publish(publish); - pending.push(request); - } + for (_, (publish, tx)) in self.outgoing_pub.drain() { + let request = Request::Publish(publish); + pending.push((tx, request)); } // remove and collect pending releases - for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16, None)); - pending.push(request); + for (pkid, tx) in self.outgoing_rel.drain() { + let request = Request::PubRel(PubRel::new(pkid, None)); + pending.push((tx, request)); } self.outgoing_rel.clear(); @@ -187,13 +193,14 @@ impl MqttState { /// be put on to the network by the eventloop pub fn handle_outgoing_packet( &mut self, + tx: NoticeTx, request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, tx)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? @@ -247,14 +254,29 @@ impl MqttState { &mut self, suback: &mut SubAck, ) -> Result, StateError> { + if suback.pkid > self.max_outgoing_inflight { + error!("Unsolicited suback packet: {:?}", suback.pkid); + return Err(StateError::Unsolicited(suback.pkid)); + } + + let tx = self + .outgoing_sub + .remove(&suback.pkid) + .ok_or(StateError::Unsolicited(suback.pkid))?; + for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos); } - _ => return Err(StateError::SubFail { reason: *reason }), + _ => { + tx.error(NoticeError::V5Subscribe(*reason)); + return Err(StateError::SubFail { reason: *reason }); + } } } + tx.success(); + Ok(None) } @@ -262,11 +284,24 @@ impl MqttState { &mut self, unsuback: &mut UnsubAck, ) -> Result, StateError> { + if unsuback.pkid > self.max_outgoing_inflight { + error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); + return Err(StateError::Unsolicited(unsuback.pkid)); + } + + let tx = self + .outgoing_unsub + .remove(&unsuback.pkid) + .ok_or(StateError::Unsolicited(unsuback.pkid))?; + for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { + tx.error(NoticeError::V5Unsubscribe(*reason)); return Err(StateError::UnsubFail { reason: *reason }); } } + tx.success(); + Ok(None) } @@ -359,16 +394,17 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + if puback.pkid > self.max_outgoing_inflight { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } + let (_, tx) = self + .outgoing_pub + .remove(&puback.pkid) + .ok_or(StateError::Unsolicited(puback.pkid))?; + tx.success(); + self.inflight -= 1; if puback.reason != PubAckReason::Success @@ -379,32 +415,32 @@ impl MqttState { }); } - if let Some(publish) = self.check_collision(puback.pkid) { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(publish.pkid, (publish.clone(), tx)); self.inflight += 1; - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - return Ok(Some(Packet::Publish(publish))); - } + Packet::Publish(publish) + }); - Ok(None) + Ok(packet) } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + if pubrec.pkid > self.max_outgoing_inflight { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } + let (_, tx) = self + .outgoing_pub + .remove(&pubrec.pkid) + .ok_or(StateError::Unsolicited(pubrec.pkid))?; + if pubrec.reason != PubRecReason::Success && pubrec.reason != PubRecReason::NoMatchingSubscribers { @@ -414,7 +450,7 @@ impl MqttState { } // NOTE: Inflight - 1 for qos2 in comp - self.outgoing_rel.insert(pubrec.pkid as usize); + self.outgoing_rel.insert(pubrec.pkid, tx); let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -441,29 +477,27 @@ impl MqttState { } fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - let outgoing = self.check_collision(pubcomp.pkid).map(|publish| { - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); + if pubcomp.pkid > self.max_outgoing_inflight { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + return Err(StateError::Unsolicited(pubcomp.pkid)); + } + self.outgoing_rel + .remove(&pubcomp.pkid) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + .success(); + + self.inflight -= 1; + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { + self.outgoing_pub + .insert(pubcomp.pkid, (publish.clone(), tx)); + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; Packet::Publish(publish) }); - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - return Err(StateError::Unsolicited(pubcomp.pkid)); - } - self.outgoing_rel.set(pubcomp.pkid as usize, false); - - if pubcomp.reason != PubCompReason::Success { - return Err(StateError::PubCompFail { - reason: pubcomp.reason, - }); - } - - self.inflight -= 1; - Ok(outgoing) + Ok(packet) } fn handle_incoming_pingresp(&mut self) -> Result, StateError> { @@ -473,21 +507,21 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + notice_tx: NoticeTx, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); } let pkid = publish.pkid; - if self - .outgoing_pub - .get(publish.pkid as usize) - .ok_or(StateError::Unsolicited(publish.pkid))? - .is_some() - { + + if self.outgoing_pub.get(&publish.pkid).is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, notice_tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -495,9 +529,11 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); self.inflight += 1; - }; + } else { + notice_tx.success() + } debug!( "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", @@ -507,7 +543,6 @@ impl MqttState { ); let pkid = publish.pkid; - if let Some(props) = &publish.properties { if let Some(alias) = props.topic_alias { if alias > self.broker_topic_alias_max { @@ -527,8 +562,12 @@ impl MqttState { Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { - let pubrel = self.save_pubrel(pubrel)?; + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + notice_tx: NoticeTx, + ) -> Result, StateError> { + let pubrel = self.save_pubrel(pubrel, notice_tx)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); @@ -589,6 +628,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + notice_tx: NoticeTx, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -602,7 +642,7 @@ impl MqttState { subscription.filters, subscription.pkid ); - let pkid = subscription.pkid; + self.outgoing_sub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); @@ -612,6 +652,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + notice_tx: NoticeTx, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -621,7 +662,7 @@ impl MqttState { unsub.filters, unsub.pkid ); - let pkid = unsub.pkid; + self.outgoing_unsub.insert(pkid, notice_tx); let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); @@ -639,8 +680,8 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, NoticeTx)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -649,7 +690,11 @@ impl MqttState { None } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + fn save_pubrel( + &mut self, + mut pubrel: PubRel, + notice_tx: NoticeTx, + ) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { @@ -659,7 +704,7 @@ impl MqttState { _ => pubrel, }; - self.outgoing_rel.insert(pubrel.pkid as usize); + self.outgoing_rel.insert(pubrel.pkid, notice_tx); self.inflight += 1; Ok(pubrel) } @@ -686,6 +731,8 @@ impl MqttState { #[cfg(test)] mod test { + use crate::notice::NoticeTx; + use super::mqttbytes::v5::*; use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; @@ -739,7 +786,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -747,12 +795,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -760,12 +810,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -777,17 +829,20 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); @@ -797,7 +852,8 @@ mod test { assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish.clone(), tx).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -881,8 +937,10 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish1, tx).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish2, tx).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); @@ -891,8 +949,8 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); + assert!(mqtt.outgoing_pub.get(&1).is_none()); + assert!(mqtt.outgoing_pub.get(&2).is_none()); } #[test] @@ -916,18 +974,20 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish1, tx); + let (tx, _) = NoticeTx::new(); + let _publish_out = mqtt.outgoing_publish(publish2, tx); mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); + let (backup, _) = mqtt.outgoing_pub.get(&1).unwrap(); + assert_eq!(backup.pkid, 1); - // check if the qos2 element's release pkid is 2 - assert!(mqtt.outgoing_rel.contains(2)); + // check if the qos2 element's release pkik has been set + assert!(mqtt.outgoing_rel.get(&2).is_some()); } #[test] @@ -935,7 +995,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish).unwrap().unwrap() { + let (tx, _) = NoticeTx::new(); + match mqtt.outgoing_publish(publish, tx).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -975,7 +1036,8 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_publish(publish, tx).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) @@ -990,7 +1052,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + let (tx, _) = NoticeTx::new(); + mqtt.handle_outgoing_packet(tx, Request::Publish(publish)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap();