From 3d37fa3d4c10454f01cb64ceb2580f05a0cf0ce4 Mon Sep 17 00:00:00 2001 From: Uwe Klotz Date: Wed, 28 Aug 2024 02:02:50 +0200 Subject: [PATCH] Remove obsolete FunctionCode::Disconnect Use AsyncWriteExt::shutdown() to release all resources. --- CHANGELOG.md | 10 ++++++- Cargo.toml | 2 +- examples/rtu-client.rs | 2 +- examples/tcp-client.rs | 2 +- examples/tls-client.rs | 2 +- src/client/mod.rs | 45 +++++++++++++++------------- src/codec/mod.rs | 2 -- src/codec/rtu.rs | 36 +++------------------- src/codec/tcp.rs | 33 +++----------------- src/frame/mod.rs | 23 +------------- src/frame/rtu.rs | 1 - src/frame/tcp.rs | 1 - src/service/rtu.rs | 68 ++++++++++++++++++++++++++++-------------- src/service/tcp.rs | 64 +++++++++++++++++++++++++-------------- 14 files changed, 134 insertions(+), 157 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 668e0c66..10332a96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,14 @@ # Changelog +## v0.15.0 (Unreleased) + +### Breaking Changes + +- Client: Added new `disconnect()` method to trait that returns + `std::io::Result`. +- Removed `FunctionCode::Disconnect`. + ## v0.14.0 (2024-07-21) ### Breaking Changes @@ -15,7 +23,7 @@ ### Breaking Changes -- Add `FunctionCode::Disconnect`. +- Added `FunctionCode::Disconnect`. ## v0.13.0 (2024-06-23, yanked) diff --git a/Cargo.toml b/Cargo.toml index c9cea5d0..db5ffe0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ log = "0.4.20" smallvec = { version = "1.13.1", optional = true, default-features = false } socket2 = { version = "0.5.5", optional = true, default-features = false } thiserror = "1.0.58" -tokio = { version = "1.35.1", default-features = false } +tokio = { version = "1.35.1", default-features = false, features = ["io-util"] } # Disable default-features to exclude unused dependency on libudev tokio-serial = { version = "5.4.4", optional = true, default-features = false } tokio-util = { version = "0.7.10", optional = true, default-features = false, features = [ diff --git a/examples/rtu-client.rs b/examples/rtu-client.rs index 80ffe062..bf896542 100644 --- a/examples/rtu-client.rs +++ b/examples/rtu-client.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { println!("Sensor value is: {rsp:?}"); println!("Disconnecting"); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs index 48783355..8f2adf76 100644 --- a/examples/tcp-client.rs +++ b/examples/tcp-client.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box> { println!("The coupler ID is '{id}'"); println!("Disconnecting"); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/examples/tls-client.rs b/examples/tls-client.rs index 6508abb3..7ad8602c 100644 --- a/examples/tls-client.rs +++ b/examples/tls-client.rs @@ -117,7 +117,7 @@ async fn main() -> Result<(), Box> { println!("Reading Holding Registers"); let data = ctx.read_holding_registers(40000, 68).await?; println!("Holding Registers Data is '{:?}'", data); - ctx.disconnect().await??; + ctx.disconnect().await?; Ok(()) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 40141371..651f6125 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -7,7 +7,7 @@ use std::{borrow::Cow, fmt::Debug, io}; use async_trait::async_trait; -use crate::{frame::*, slave::*, Error, Result}; +use crate::{frame::*, slave::*, Result}; #[cfg(feature = "rtu")] pub mod rtu; @@ -21,11 +21,22 @@ pub mod sync; /// Transport independent asynchronous client trait #[async_trait] pub trait Client: SlaveContext + Send + Debug { - /// Invoke a _Modbus_ function + /// Invokes a _Modbus_ function. async fn call(&mut self, request: Request<'_>) -> Result; + + /// Disconnects the client. + /// + /// Permanently disconnects the client by shutting down the + /// underlying stream in a graceful manner (`AsyncDrop`). + /// + /// Dropping the client without explicitly disconnecting it + /// beforehand should also work and free all resources. The + /// actual behavior might depend on the underlying transport + /// protocol (RTU/TCP) that is used by the client. + async fn disconnect(&mut self) -> io::Result<()>; } -/// Asynchronous Modbus reader +/// Asynchronous _Modbus_ reader #[async_trait] pub trait Reader: Client { /// Read multiple coils (0x01) @@ -83,22 +94,6 @@ pub struct Context { client: Box, } -impl Context { - /// Disconnect the client - pub async fn disconnect(&mut self) -> Result<()> { - // Disconnecting is expected to fail! - let res = self.client.call(Request::Disconnect).await; - match res { - Ok(_) => unreachable!(), - Err(Error::Transport(err)) => match err.kind() { - io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => Ok(Ok(())), - _ => Err(Error::Transport(err)), - }, - Err(err) => Err(err), - } - } -} - impl From> for Context { fn from(client: Box) -> Self { Self { client } @@ -116,6 +111,10 @@ impl Client for Context { async fn call(&mut self, request: Request<'_>) -> Result { self.client.call(request).await } + + async fn disconnect(&mut self) -> io::Result<()> { + self.client.disconnect().await + } } impl SlaveContext for Context { @@ -319,10 +318,10 @@ impl Writer for Context { #[cfg(test)] mod tests { - use crate::Result; + use crate::{Error, Result}; use super::*; - use std::sync::Mutex; + use std::{io, sync::Mutex}; #[derive(Default, Debug)] pub(crate) struct ClientMock { @@ -358,6 +357,10 @@ mod tests { Err(err) => Err(err), } } + + async fn disconnect(&mut self) -> io::Result<()> { + Ok(()) + } } impl SlaveContext for ClientMock { diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 057927ea..bf8c2806 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -102,7 +102,6 @@ impl<'a> TryFrom> for Bytes { data.put_u8(*d); } } - Disconnect => unreachable!(), } Ok(data.freeze()) } @@ -452,7 +451,6 @@ fn request_byte_count(req: &Request<'_>) -> usize { MaskWriteRegister(_, _, _) => 7, ReadWriteMultipleRegisters(_, _, _, ref data) => 10 + data.len() * 2, Custom(_, ref data) => 1 + data.len(), - Disconnect => unreachable!(), } } diff --git a/src/codec/rtu.rs b/src/codec/rtu.rs index adb101d5..b6d44502 100644 --- a/src/codec/rtu.rs +++ b/src/codec/rtu.rs @@ -310,13 +310,7 @@ impl Decoder for ServerCodec { // to transmission errors, because the frame's bytes // have already been verified with the CRC. RequestPdu::try_from(pdu_data) - .map(|pdu| { - Some(RequestAdu { - hdr, - pdu, - disconnect: false, - }) - }) + .map(|pdu| Some(RequestAdu { hdr, pdu })) .map_err(|err| { // Unrecoverable error log::error!("Failed to decode request PDU: {}", err); @@ -329,21 +323,7 @@ impl<'a> Encoder> for ClientCodec { type Error = Error; fn encode(&mut self, adu: RequestAdu<'a>, buf: &mut BytesMut) -> Result<()> { - let RequestAdu { - hdr, - pdu, - disconnect, - } = adu; - if disconnect { - // The disconnect happens implicitly after letting this request - // fail by returning an error. This will drop the attached - // transport, e.g. for closing a stale, exclusive connection - // to a serial port before trying to reconnect. - return Err(Error::new( - ErrorKind::NotConnected, - "Disconnecting - not an error", - )); - } + let RequestAdu { hdr, pdu } = adu; let pdu_data: Bytes = pdu.try_into()?; buf.reserve(pdu_data.len() + 3); buf.put_u8(hdr.slave_id); @@ -729,11 +709,7 @@ mod tests { let pdu = req.into(); let slave_id = 0x01; let hdr = Header { slave_id }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; codec.encode(adu, &mut buf).unwrap(); assert_eq!( @@ -749,11 +725,7 @@ mod tests { let pdu = req.into(); let slave_id = 0x01; let hdr = Header { slave_id }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; let mut buf = BytesMut::with_capacity(40); #[allow(unsafe_code)] unsafe { diff --git a/src/codec/tcp.rs b/src/codec/tcp.rs index 168dfb85..5b045992 100644 --- a/src/codec/tcp.rs +++ b/src/codec/tcp.rs @@ -122,11 +122,7 @@ impl Decoder for ServerCodec { fn decode(&mut self, buf: &mut BytesMut) -> Result>> { if let Some((hdr, pdu_data)) = self.decoder.decode(buf)? { let pdu = RequestPdu::try_from(pdu_data)?; - Ok(Some(RequestAdu { - hdr, - pdu, - disconnect: false, - })) + Ok(Some(RequestAdu { hdr, pdu })) } else { Ok(None) } @@ -137,20 +133,7 @@ impl<'a> Encoder> for ClientCodec { type Error = Error; fn encode(&mut self, adu: RequestAdu<'a>, buf: &mut BytesMut) -> Result<()> { - let RequestAdu { - hdr, - pdu, - disconnect, - } = adu; - if disconnect { - // The disconnect happens implicitly after letting this request - // fail by returning an error. This will drop the attached - // transport, e.g. for terminating an open connection. - return Err(Error::new( - ErrorKind::NotConnected, - "Disconnecting - not an error", - )); - } + let RequestAdu { hdr, pdu } = adu; let pdu_data: Bytes = pdu.try_into()?; buf.reserve(pdu_data.len() + 7); buf.put_u16(hdr.transaction_id); @@ -284,11 +267,7 @@ mod tests { transaction_id: TRANSACTION_ID, unit_id: UNIT_ID, }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; codec.encode(adu, &mut buf).unwrap(); // header assert_eq!(buf[0], TRANSACTION_ID_HI); @@ -312,11 +291,7 @@ mod tests { transaction_id: TRANSACTION_ID, unit_id: UNIT_ID, }; - let adu = RequestAdu { - hdr, - pdu, - disconnect: false, - }; + let adu = RequestAdu { hdr, pdu }; let mut buf = BytesMut::with_capacity(40); #[allow(unsafe_code)] unsafe { diff --git a/src/frame/mod.rs b/src/frame/mod.rs index f3952e7d..dd2c761b 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -46,8 +46,6 @@ pub enum FunctionCode { /// Custom Modbus Function Code. Custom(u8), - - Disconnect, } impl FunctionCode { @@ -69,11 +67,7 @@ impl FunctionCode { } } - /// Get the [`u8`] value of the current [`FunctionCode`]. - /// - /// # Panics - /// - /// Panics on [`Disconnect`](Self::Disconnect) which has no corresponding Modbus function code. + /// Gets the [`u8`] value of the current [`FunctionCode`]. #[must_use] pub const fn value(self) -> u8 { match self { @@ -88,7 +82,6 @@ impl FunctionCode { Self::MaskWriteRegister => 0x16, Self::ReadWriteMultipleRegisters => 0x17, Self::Custom(code) => code, - Self::Disconnect => unreachable!(), } } } @@ -181,17 +174,6 @@ pub enum Request<'a> { /// The first parameter is the Modbus function code. /// The second parameter is the raw bytes of the request. Custom(u8, Cow<'a, [u8]>), - - /// A poison pill for stopping the client service and to release - /// the underlying transport, e.g. for disconnecting from an - /// exclusively used serial port. - /// - /// This is an ugly workaround, because `tokio-proto` does not - /// provide other means to gracefully shut down the `ClientService`. - /// Otherwise the bound transport is never freed as long as the - /// executor is active, even when dropping the Modbus client - /// context. - Disconnect, } impl<'a> Request<'a> { @@ -220,7 +202,6 @@ impl<'a> Request<'a> { ReadWriteMultipleRegisters(addr, qty, write_addr, Cow::Owned(words.into_owned())) } Custom(func, bytes) => Custom(func, Cow::Owned(bytes.into_owned())), - Disconnect => Disconnect, } } @@ -247,8 +228,6 @@ impl<'a> Request<'a> { ReadWriteMultipleRegisters(_, _, _, _) => FunctionCode::ReadWriteMultipleRegisters, Custom(code, _) => FunctionCode::Custom(*code), - - Disconnect => unreachable!(), } } } diff --git a/src/frame/rtu.rs b/src/frame/rtu.rs index b41b36ef..07b6989b 100644 --- a/src/frame/rtu.rs +++ b/src/frame/rtu.rs @@ -14,7 +14,6 @@ pub(crate) struct Header { pub struct RequestAdu<'a> { pub(crate) hdr: Header, pub(crate) pdu: RequestPdu<'a>, - pub(crate) disconnect: bool, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/frame/tcp.rs b/src/frame/tcp.rs index 078630d3..4c7b7f85 100644 --- a/src/frame/tcp.rs +++ b/src/frame/tcp.rs @@ -16,7 +16,6 @@ pub(crate) struct Header { pub struct RequestAdu<'a> { pub(crate) hdr: Header, pub(crate) pdu: RequestPdu<'a>, - pub(crate) disconnect: bool, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/service/rtu.rs b/src/service/rtu.rs index 5e91c03e..41a7016a 100644 --- a/src/service/rtu.rs +++ b/src/service/rtu.rs @@ -4,7 +4,7 @@ use std::{fmt, io}; use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tokio_util::codec::Framed; use crate::{ @@ -19,7 +19,7 @@ use super::verify_response_header; /// Modbus RTU client #[derive(Debug)] pub(crate) struct Client { - framed: Framed, + framed: Option>, slave_id: SlaveId, } @@ -30,37 +30,42 @@ where pub(crate) fn new(transport: T, slave: Slave) -> Self { let framed = Framed::new(transport, codec::rtu::ClientCodec::default()); let slave_id = slave.into(); - Self { framed, slave_id } + Self { + slave_id, + framed: Some(framed), + } + } + + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) } - fn next_request_adu<'a, R>(&self, req: R, disconnect: bool) -> RequestAdu<'a> + fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> where R: Into>, { let slave_id = self.slave_id; let hdr = Header { slave_id }; let pdu = req.into(); - RequestAdu { - hdr, - pdu, - disconnect, - } + RequestAdu { hdr, pdu } } async fn call(&mut self, req: Request<'_>) -> Result { - let req_function_code = if matches!(req, Request::Disconnect) { - None - } else { - Some(req.function_code()) - }; - let req_adu = self.next_request_adu(req, req_function_code.is_none()); + log::debug!("Call {:?}", req); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); let req_hdr = req_adu.hdr; - self.framed.read_buffer_mut().clear(); - self.framed.send(req_adu).await?; + let framed = self.framed()?; + + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; - let res_adu = self - .framed + let res_adu = framed .next() .await .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; @@ -75,9 +80,6 @@ where return Err(ProtocolError::HeaderMismatch { message, result }.into()); } - debug_assert!(req_function_code.is_some()); - let req_function_code = req_function_code.unwrap(); - // Match function codes of request and response. let rsp_function_code = match &result { Ok(response) => response.function_code(), @@ -98,6 +100,24 @@ where }| exception, )) } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) + } } impl SlaveContext for Client { @@ -114,6 +134,10 @@ where async fn call(&mut self, req: Request<'_>) -> Result { self.call(req).await } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } } #[cfg(test)] diff --git a/src/service/tcp.rs b/src/service/tcp.rs index a61f07b7..a4a478ac 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -7,7 +7,7 @@ use std::{ }; use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tokio_util::codec::Framed; use crate::{ @@ -23,7 +23,7 @@ const INITIAL_TRANSACTION_ID: TransactionId = 0; /// Modbus TCP client #[derive(Debug)] pub(crate) struct Client { - framed: Framed, + framed: Option>, unit_id: UnitId, transaction_id: AtomicU16, } @@ -37,7 +37,7 @@ where let unit_id: UnitId = slave.into(); let transaction_id = AtomicU16::new(INITIAL_TRANSACTION_ID); Self { - framed, + framed: Some(framed), unit_id, transaction_id, } @@ -58,35 +58,36 @@ where } } - fn next_request_adu<'a, R>(&self, req: R, disconnect: bool) -> RequestAdu<'a> + fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> where R: Into>, { RequestAdu { hdr: self.next_request_hdr(self.unit_id), pdu: req.into(), - disconnect, } } + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) + } + pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { log::debug!("Call {:?}", req); - let req_function_code = if matches!(req, Request::Disconnect) { - None - } else { - Some(req.function_code()) - }; - let req_adu = self.next_request_adu(req, req_function_code.is_none()); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); let req_hdr = req_adu.hdr; - self.framed.read_buffer_mut().clear(); - self.framed.send(req_adu).await?; + let framed = self.framed()?; - let res_adu = self - .framed - .next() - .await - .ok_or_else(io::Error::last_os_error)??; + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; + + let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; let ResponseAdu { hdr: res_hdr, pdu: res_pdu, @@ -98,9 +99,6 @@ where return Err(ProtocolError::HeaderMismatch { message, result }.into()); } - debug_assert!(req_function_code.is_some()); - let req_function_code = req_function_code.unwrap(); - // Match function codes of request and response. let rsp_function_code = match &result { Ok(response) => response.function_code(), @@ -121,6 +119,24 @@ where }| exception, )) } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) + } } impl SlaveContext for Client { @@ -135,7 +151,11 @@ where T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, { async fn call(&mut self, req: Request<'_>) -> Result { - Client::call(self, req).await + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await } }