diff --git a/CHANGELOG.md b/CHANGELOG.md index ac84c6b0..462472f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ `FunctionCode`. - Added `ExceptionCode::Custom`. - Removed `TryFrom` and `#[repr(u8)]` for `Exception`. +- Removed `FunctionCode::Disconnect`. +- Client: Added new `disconnect()` method to trait that returns + `std::io::Result`. ## v0.14.1 (2024-09-10) @@ -32,7 +35,7 @@ ### Breaking Changes -- Add `FunctionCode::Disconnect`. +- Added `FunctionCode::Disconnect`. ## v0.13.0 (2024-06-23, yanked) diff --git a/Cargo.toml b/Cargo.toml index 854dbf25..402f9e73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ log = "0.4.20" smallvec = { version = "1.13.1", optional = true, default-features = false } socket2 = { version = "0.5.5", optional = true, default-features = false } thiserror = "1.0.58" -tokio = { version = "1.35.1", default-features = false } +tokio = { version = "1.35.1", default-features = false, features = ["io-util"] } # Disable default-features to exclude unused dependency on libudev tokio-serial = { version = "5.4.4", optional = true, default-features = false } tokio-util = { version = "0.7.10", optional = true, default-features = false, features = [ diff --git a/examples/rtu-client.rs b/examples/rtu-client.rs index 80ffe062..bf896542 100644 --- a/examples/rtu-client.rs +++ b/examples/rtu-client.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { println!("Sensor value is: {rsp:?}"); println!("Disconnecting"); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs index 48783355..8f2adf76 100644 --- a/examples/tcp-client.rs +++ b/examples/tcp-client.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box> { println!("The coupler ID is '{id}'"); println!("Disconnecting"); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/examples/tls-client.rs b/examples/tls-client.rs index 6508abb3..7ad8602c 100644 --- a/examples/tls-client.rs +++ b/examples/tls-client.rs @@ -117,7 +117,7 @@ async fn main() -> Result<(), Box> { println!("Reading Holding Registers"); let data = ctx.read_holding_registers(40000, 68).await?; println!("Holding Registers Data is '{:?}'", data); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 40141371..651f6125 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -7,7 +7,7 @@ use std::{borrow::Cow, fmt::Debug, io}; use async_trait::async_trait; -use crate::{frame::*, slave::*, Error, Result}; +use crate::{frame::*, slave::*, Result}; #[cfg(feature = "rtu")] pub mod rtu; @@ -21,11 +21,22 @@ pub mod sync; /// Transport independent asynchronous client trait #[async_trait] pub trait Client: SlaveContext + Send + Debug { - /// Invoke a _Modbus_ function + /// Invokes a _Modbus_ function. async fn call(&mut self, request: Request<'_>) -> Result; + + /// Disconnects the client. + /// + /// Permanently disconnects the client by shutting down the + /// underlying stream in a graceful manner (`AsyncDrop`). + /// + /// Dropping the client without explicitly disconnecting it + /// beforehand should also work and free all resources. The + /// actual behavior might depend on the underlying transport + /// protocol (RTU/TCP) that is used by the client. + async fn disconnect(&mut self) -> io::Result<()>; } -/// Asynchronous Modbus reader +/// Asynchronous _Modbus_ reader #[async_trait] pub trait Reader: Client { /// Read multiple coils (0x01) @@ -83,22 +94,6 @@ pub struct Context { client: Box, } -impl Context { - /// Disconnect the client - pub async fn disconnect(&mut self) -> Result<()> { - // Disconnecting is expected to fail! - let res = self.client.call(Request::Disconnect).await; - match res { - Ok(_) => unreachable!(), - Err(Error::Transport(err)) => match err.kind() { - io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => Ok(Ok(())), - _ => Err(Error::Transport(err)), - }, - Err(err) => Err(err), - } - } -} - impl From> for Context { fn from(client: Box) -> Self { Self { client } @@ -116,6 +111,10 @@ impl Client for Context { async fn call(&mut self, request: Request<'_>) -> Result { self.client.call(request).await } + + async fn disconnect(&mut self) -> io::Result<()> { + self.client.disconnect().await + } } impl SlaveContext for Context { @@ -319,10 +318,10 @@ impl Writer for Context { #[cfg(test)] mod tests { - use crate::Result; + use crate::{Error, Result}; use super::*; - use std::sync::Mutex; + use std::{io, sync::Mutex}; #[derive(Default, Debug)] pub(crate) struct ClientMock { @@ -358,6 +357,10 @@ mod tests { Err(err) => Err(err), } } + + async fn disconnect(&mut self) -> io::Result<()> { + Ok(()) + } } impl SlaveContext for ClientMock { diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 5dd67b23..41f3a721 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -103,7 +103,6 @@ impl<'a> TryFrom> for Bytes { data.put_u8(*d); } } - Disconnect => unreachable!(), } Ok(data.freeze()) } @@ -460,7 +459,6 @@ fn request_byte_count(req: &Request<'_>) -> usize { MaskWriteRegister(_, _, _) => 7, ReadWriteMultipleRegisters(_, _, _, ref data) => 10 + data.len() * 2, Custom(_, ref data) => 1 + data.len(), - Disconnect => unreachable!(), } } diff --git a/src/codec/rtu.rs b/src/codec/rtu.rs index 7585aa26..95ff7616 100644 --- a/src/codec/rtu.rs +++ b/src/codec/rtu.rs @@ -315,13 +315,7 @@ impl Decoder for ServerCodec { // to transmission errors, because the frame's bytes // have already been verified with the CRC. RequestPdu::try_from(pdu_data) - .map(|pdu| { - Some(RequestAdu { - hdr, - pdu, - disconnect: false, - }) - }) + .map(|pdu| Some(RequestAdu { hdr, pdu })) .map_err(|err| { // Unrecoverable error log::error!("Failed to decode request PDU: {}", err); @@ -334,21 +328,7 @@ impl<'a> Encoder> for ClientCodec { type Error = Error; fn encode(&mut self, adu: RequestAdu<'a>, buf: &mut BytesMut) -> Result<()> { - let RequestAdu { - hdr, - pdu, - disconnect, - } = adu; - if disconnect { - // The disconnect happens implicitly after letting this request - // fail by returning an error. This will drop the attached - // transport, e.g. for closing a stale, exclusive connection - // to a serial port before trying to reconnect. - return Err(Error::new( - ErrorKind::NotConnected, - "Disconnecting - not an error", - )); - } + let RequestAdu { hdr, pdu } = adu; let pdu_data: Bytes = pdu.try_into()?; buf.reserve(pdu_data.len() + 3); buf.put_u8(hdr.slave_id); @@ -741,11 +721,7 @@ mod tests { let pdu = req.into(); let slave_id = 0x01; let hdr = Header { slave_id }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; codec.encode(adu, &mut buf).unwrap(); assert_eq!( @@ -761,11 +737,7 @@ mod tests { let pdu = req.into(); let slave_id = 0x01; let hdr = Header { slave_id }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; let mut buf = BytesMut::with_capacity(40); #[allow(unsafe_code)] unsafe { diff --git a/src/codec/tcp.rs b/src/codec/tcp.rs index 6defd8c4..0f86c978 100644 --- a/src/codec/tcp.rs +++ b/src/codec/tcp.rs @@ -125,11 +125,7 @@ impl Decoder for ServerCodec { fn decode(&mut self, buf: &mut BytesMut) -> Result>> { if let Some((hdr, pdu_data)) = self.decoder.decode(buf)? { let pdu = RequestPdu::try_from(pdu_data)?; - Ok(Some(RequestAdu { - hdr, - pdu, - disconnect: false, - })) + Ok(Some(RequestAdu { hdr, pdu })) } else { Ok(None) } @@ -140,20 +136,7 @@ impl<'a> Encoder> for ClientCodec { type Error = Error; fn encode(&mut self, adu: RequestAdu<'a>, buf: &mut BytesMut) -> Result<()> { - let RequestAdu { - hdr, - pdu, - disconnect, - } = adu; - if disconnect { - // The disconnect happens implicitly after letting this request - // fail by returning an error. This will drop the attached - // transport, e.g. for terminating an open connection. - return Err(Error::new( - ErrorKind::NotConnected, - "Disconnecting - not an error", - )); - } + let RequestAdu { hdr, pdu } = adu; let pdu_data: Bytes = pdu.try_into()?; buf.reserve(pdu_data.len() + 7); buf.put_u16(hdr.transaction_id); @@ -288,11 +271,7 @@ mod tests { transaction_id: TRANSACTION_ID, unit_id: UNIT_ID, }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; codec.encode(adu, &mut buf).unwrap(); // header assert_eq!(buf[0], TRANSACTION_ID_HI); @@ -316,11 +295,7 @@ mod tests { transaction_id: TRANSACTION_ID, unit_id: UNIT_ID, }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; let mut buf = BytesMut::with_capacity(40); #[allow(unsafe_code)] unsafe { diff --git a/src/frame/mod.rs b/src/frame/mod.rs index c4599f2f..8a8e113d 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -79,8 +79,6 @@ pub enum FunctionCode { /// Custom Modbus Function Code. Custom(u8), - - Disconnect, } impl FunctionCode { @@ -111,11 +109,7 @@ impl FunctionCode { } } - /// Get the [`u8`] value of the current [`FunctionCode`]. - /// - /// # Panics - /// - /// Panics on [`Disconnect`](Self::Disconnect) which has no corresponding Modbus function code. + /// Gets the [`u8`] value of the current [`FunctionCode`]. #[must_use] pub const fn value(self) -> u8 { match self { @@ -139,7 +133,6 @@ impl FunctionCode { Self::ReadFifoQueue => 0x18, Self::EncapsulatedInterfaceTransport => 0x2B, Self::Custom(code) => code, - Self::Disconnect => unreachable!(), } } } @@ -235,17 +228,6 @@ pub enum Request<'a> { /// The first parameter is the Modbus function code. /// The second parameter is the raw bytes of the request. Custom(u8, Cow<'a, [u8]>), - - /// A poison pill for stopping the client service and to release - /// the underlying transport, e.g. for disconnecting from an - /// exclusively used serial port. - /// - /// This is an ugly workaround, because `tokio-proto` does not - /// provide other means to gracefully shut down the `ClientService`. - /// Otherwise the bound transport is never freed as long as the - /// executor is active, even when dropping the Modbus client - /// context. - Disconnect, } impl<'a> Request<'a> { @@ -275,7 +257,6 @@ impl<'a> Request<'a> { ReadWriteMultipleRegisters(addr, qty, write_addr, Cow::Owned(words.into_owned())) } Custom(func, bytes) => Custom(func, Cow::Owned(bytes.into_owned())), - Disconnect => Disconnect, } } @@ -304,8 +285,6 @@ impl<'a> Request<'a> { ReadWriteMultipleRegisters(_, _, _, _) => FunctionCode::ReadWriteMultipleRegisters, Custom(code, _) => FunctionCode::Custom(*code), - - Disconnect => unreachable!(), } } } diff --git a/src/frame/rtu.rs b/src/frame/rtu.rs index b41b36ef..07b6989b 100644 --- a/src/frame/rtu.rs +++ b/src/frame/rtu.rs @@ -14,7 +14,6 @@ pub(crate) struct Header { pub struct RequestAdu<'a> { pub(crate) hdr: Header, pub(crate) pdu: RequestPdu<'a>, - pub(crate) disconnect: bool, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/frame/tcp.rs b/src/frame/tcp.rs index 078630d3..4c7b7f85 100644 --- a/src/frame/tcp.rs +++ b/src/frame/tcp.rs @@ -16,7 +16,6 @@ pub(crate) struct Header { pub struct RequestAdu<'a> { pub(crate) hdr: Header, pub(crate) pdu: RequestPdu<'a>, - pub(crate) disconnect: bool, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/service/rtu.rs b/src/service/rtu.rs index 5e91c03e..41a7016a 100644 --- a/src/service/rtu.rs +++ b/src/service/rtu.rs @@ -4,7 +4,7 @@ use std::{fmt, io}; use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tokio_util::codec::Framed; use crate::{ @@ -19,7 +19,7 @@ use super::verify_response_header; /// Modbus RTU client #[derive(Debug)] pub(crate) struct Client { - framed: Framed, + framed: Option>, slave_id: SlaveId, } @@ -30,37 +30,42 @@ where pub(crate) fn new(transport: T, slave: Slave) -> Self { let framed = Framed::new(transport, codec::rtu::ClientCodec::default()); let slave_id = slave.into(); - Self { framed, slave_id } + Self { + slave_id, + framed: Some(framed), + } + } + + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) } - fn next_request_adu<'a, R>(&self, req: R, disconnect: bool) -> RequestAdu<'a> + fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> where R: Into>, { let slave_id = self.slave_id; let hdr = Header { slave_id }; let pdu = req.into(); - RequestAdu { - hdr, - pdu, - disconnect, - } + RequestAdu { hdr, pdu } } async fn call(&mut self, req: Request<'_>) -> Result { - let req_function_code = if matches!(req, Request::Disconnect) { - None - } else { - Some(req.function_code()) - }; - let req_adu = self.next_request_adu(req, req_function_code.is_none()); + log::debug!("Call {:?}", req); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); let req_hdr = req_adu.hdr; - self.framed.read_buffer_mut().clear(); - self.framed.send(req_adu).await?; + let framed = self.framed()?; + + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; - let res_adu = self - .framed + let res_adu = framed .next() .await .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; @@ -75,9 +80,6 @@ where return Err(ProtocolError::HeaderMismatch { message, result }.into()); } - debug_assert!(req_function_code.is_some()); - let req_function_code = req_function_code.unwrap(); - // Match function codes of request and response. let rsp_function_code = match &result { Ok(response) => response.function_code(), @@ -98,6 +100,24 @@ where }| exception, )) } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) + } } impl SlaveContext for Client { @@ -114,6 +134,10 @@ where async fn call(&mut self, req: Request<'_>) -> Result { self.call(req).await } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } } #[cfg(test)] diff --git a/src/service/tcp.rs b/src/service/tcp.rs index a61f07b7..a4a478ac 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -7,7 +7,7 @@ use std::{ }; use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tokio_util::codec::Framed; use crate::{ @@ -23,7 +23,7 @@ const INITIAL_TRANSACTION_ID: TransactionId = 0; /// Modbus TCP client #[derive(Debug)] pub(crate) struct Client { - framed: Framed, + framed: Option>, unit_id: UnitId, transaction_id: AtomicU16, } @@ -37,7 +37,7 @@ where let unit_id: UnitId = slave.into(); let transaction_id = AtomicU16::new(INITIAL_TRANSACTION_ID); Self { - framed, + framed: Some(framed), unit_id, transaction_id, } @@ -58,35 +58,36 @@ where } } - fn next_request_adu<'a, R>(&self, req: R, disconnect: bool) -> RequestAdu<'a> + fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> where R: Into>, { RequestAdu { hdr: self.next_request_hdr(self.unit_id), pdu: req.into(), - disconnect, } } + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) + } + pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { log::debug!("Call {:?}", req); - let req_function_code = if matches!(req, Request::Disconnect) { - None - } else { - Some(req.function_code()) - }; - let req_adu = self.next_request_adu(req, req_function_code.is_none()); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); let req_hdr = req_adu.hdr; - self.framed.read_buffer_mut().clear(); - self.framed.send(req_adu).await?; + let framed = self.framed()?; - let res_adu = self - .framed - .next() - .await - .ok_or_else(io::Error::last_os_error)??; + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; + + let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; let ResponseAdu { hdr: res_hdr, pdu: res_pdu, @@ -98,9 +99,6 @@ where return Err(ProtocolError::HeaderMismatch { message, result }.into()); } - debug_assert!(req_function_code.is_some()); - let req_function_code = req_function_code.unwrap(); - // Match function codes of request and response. let rsp_function_code = match &result { Ok(response) => response.function_code(), @@ -121,6 +119,24 @@ where }| exception, )) } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) + } } impl SlaveContext for Client { @@ -135,7 +151,11 @@ where T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, { async fn call(&mut self, req: Request<'_>) -> Result { - Client::call(self, req).await + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await } }