diff --git a/src/client/mod.rs b/src/client/mod.rs index 651f612..92b9f2a 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -316,6 +316,28 @@ impl Writer for Context { } } +#[cfg(any(feature = "rtu", feature = "tcp"))] +pub(crate) async fn disconnect_framed( + framed: tokio_util::codec::Framed, +) -> std::io::Result<()> +where + T: tokio::io::AsyncWrite + Unpin, +{ + use tokio::io::AsyncWriteExt as _; + + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + std::io::ErrorKind::NotConnected | std::io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) +} + #[cfg(test)] mod tests { use crate::{Error, Result}; diff --git a/src/client/rtu.rs b/src/client/rtu.rs index 606ffad..1d2eb0d 100644 --- a/src/client/rtu.rs +++ b/src/client/rtu.rs @@ -3,25 +3,37 @@ //! RTU client connections +use std::{fmt, io}; + +use futures_util::{SinkExt as _, StreamExt as _}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; -use crate::rtu::{Client, ClientContext}; +use crate::{ + codec::rtu::ClientCodec, + frame::{ + rtu::{Header, RequestAdu}, + RequestPdu, + }, + slave::SlaveContext, + FunctionCode, Request, Response, Result, Slave, +}; -use super::*; +use super::{disconnect_framed, Context}; -/// Connect to no particular Modbus slave device for sending +/// Connect to no particular _Modbus_ slave device for sending /// broadcast messages. pub fn attach(transport: T) -> Context where - T: AsyncRead + AsyncWrite + Debug + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + 'static, { attach_slave(transport, Slave::broadcast()) } -/// Connect to any kind of Modbus slave device. +/// Connect to any kind of _Modbus_ slave device. pub fn attach_slave(transport: T, slave: Slave) -> Context where - T: AsyncRead + AsyncWrite + Debug + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + 'static, { let client = Client::new(transport); let context = ClientContext::new(client, slave); @@ -29,3 +41,240 @@ where client: Box::new(context), } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RequestContext { + pub(crate) function_code: FunctionCode, + pub(crate) header: Header, +} + +impl RequestContext { + #[must_use] + pub const fn function_code(&self) -> FunctionCode { + self.function_code + } +} + +/// _Modbus_ RTU client. +#[derive(Debug)] +pub struct Client { + framed: Framed, +} + +impl Client +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(transport: T) -> Self { + let framed = Framed::new(transport, ClientCodec::default()); + Self { framed } + } + + pub async fn disconnect(self) -> io::Result<()> { + let Self { framed } = self; + disconnect_framed(framed).await + } + + pub async fn call<'a>(&mut self, server: Slave, request: Request<'a>) -> Result { + let request_context = self.send_request(server, request).await?; + self.recv_response(request_context).await + } + + pub async fn send_request<'a>( + &mut self, + server: Slave, + request: Request<'a>, + ) -> io::Result { + self.send_request_pdu(server, request).await + } + + async fn send_request_pdu<'a, R>( + &mut self, + server: Slave, + request_pdu: R, + ) -> io::Result + where + R: Into>, + { + let request_adu = request_adu(server, request_pdu); + self.send_request_adu(request_adu).await + } + + async fn send_request_adu<'a>( + &mut self, + request_adu: RequestAdu<'a>, + ) -> io::Result { + let request_context = request_adu.context(); + + self.framed.read_buffer_mut().clear(); + self.framed.send(request_adu).await?; + + Ok(request_context) + } + + pub async fn recv_response(&mut self, request_context: RequestContext) -> Result { + let response_adu = self + .framed + .next() + .await + .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + response_adu.try_into_response(request_context) + } +} + +/// _Modbus_ RTU client with (server) context and connection state. +/// +/// Client that invokes methods (request/response) on a single or many (broadcast) server(s). +/// +/// The server can be switched between method calls. +#[derive(Debug)] +pub struct ClientContext { + client: Option>, + server: Slave, +} + +impl ClientContext { + pub fn new(client: Client, server: Slave) -> Self { + Self { + client: Some(client), + server, + } + } + + #[must_use] + pub const fn is_connected(&self) -> bool { + self.client.is_some() + } + + #[must_use] + pub const fn server(&self) -> Slave { + self.server + } + + pub fn set_server(&mut self, server: Slave) { + self.server = server; + } +} + +impl ClientContext +where + T: AsyncWrite + Unpin, +{ + pub async fn disconnect(&mut self) -> io::Result<()> { + let Some(client) = self.client.take() else { + // Already disconnected. + return Ok(()); + }; + disconnect_framed(client.framed).await + } +} + +impl ClientContext +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub async fn call(&mut self, request: Request<'_>) -> Result { + log::debug!("Call {:?}", request); + + let Some(client) = &mut self.client else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected").into()); + }; + + client.call(self.server, request).await + } +} + +impl ClientContext +where + T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + Send + 'static, +{ + #[must_use] + pub fn boxed(self) -> Box { + Box::new(self) + } +} + +impl SlaveContext for ClientContext { + fn set_slave(&mut self, slave: Slave) { + self.set_server(slave); + } +} + +#[async_trait::async_trait] +impl crate::client::Client for ClientContext +where + T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, +{ + async fn call(&mut self, req: Request<'_>) -> Result { + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } +} + +fn request_adu<'a, R>(server: Slave, request_pdu: R) -> RequestAdu<'a> +where + R: Into>, +{ + let hdr = Header { slave: server }; + let pdu = request_pdu.into(); + RequestAdu { hdr, pdu } +} + +#[cfg(test)] +mod tests { + use core::{ + pin::Pin, + task::{Context, Poll}, + }; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; + + use crate::Error; + + use super::*; + + #[derive(Debug)] + struct MockTransport; + + impl Unpin for MockTransport {} + + impl AsyncRead for MockTransport { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockTransport { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + Poll::Ready(Ok(2)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + unimplemented!() + } + } + + #[tokio::test] + async fn handle_broken_pipe() { + let transport = MockTransport; + let client = Client::new(transport); + let mut context = ClientContext::new(client, Slave::broadcast()); + let res = context.call(Request::ReadCoils(0x00, 5)).await; + assert!(res.is_err()); + let err = res.err().unwrap(); + assert!( + matches!(err, Error::Transport(err) if err.kind() == std::io::ErrorKind::BrokenPipe) + ); + } +} diff --git a/src/client/tcp.rs b/src/client/tcp.rs index da72ce4..6c0381e 100644 --- a/src/client/tcp.rs +++ b/src/client/tcp.rs @@ -5,12 +5,24 @@ use std::{fmt, io, net::SocketAddr}; +use futures_util::{SinkExt as _, StreamExt as _}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, }; +use tokio_util::codec::Framed; -use super::*; +use crate::{ + codec::tcp::ClientCodec, + frame::{ + tcp::{Header, RequestAdu, ResponseAdu, TransactionId, UnitId}, + verify_response_header, RequestPdu, ResponsePdu, + }, + slave::SlaveContext, + ExceptionResponse, ProtocolError, Request, Response, Result, Slave, +}; + +use super::{disconnect_framed, Context}; /// Establish a direct connection to a Modbus TCP coupler. pub async fn connect(socket_addr: SocketAddr) -> io::Result { @@ -43,8 +55,213 @@ pub fn attach_slave(transport: T, slave: Slave) -> Context where T: AsyncRead + AsyncWrite + Send + Unpin + fmt::Debug + 'static, { - let client = crate::tcp::Client::new(transport, slave); + let client = Client::new(transport, slave); Context { client: Box::new(client), } } + +const INITIAL_TRANSACTION_ID: TransactionId = 0; + +#[derive(Debug)] +struct TransactionIdGenerator { + next_transaction_id: TransactionId, +} + +impl TransactionIdGenerator { + const fn new() -> Self { + Self { + next_transaction_id: INITIAL_TRANSACTION_ID, + } + } + + fn next(&mut self) -> TransactionId { + let next_transaction_id = self.next_transaction_id; + self.next_transaction_id = next_transaction_id.wrapping_add(1); + next_transaction_id + } +} + +/// Modbus TCP client +#[derive(Debug)] +pub(crate) struct Client { + framed: Option>, + transaction_id_generator: TransactionIdGenerator, + unit_id: UnitId, +} + +impl Client +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new(transport: T, slave: Slave) -> Self { + let framed = Framed::new(transport, ClientCodec::new()); + let transaction_id_generator = TransactionIdGenerator::new(); + let unit_id: UnitId = slave.into(); + Self { + framed: Some(framed), + transaction_id_generator, + unit_id, + } + } + + fn next_request_hdr(&mut self, unit_id: UnitId) -> Header { + let transaction_id = self.transaction_id_generator.next(); + Header { + transaction_id, + unit_id, + } + } + + fn next_request_adu<'a, R>(&mut self, req: R) -> RequestAdu<'a> + where + R: Into>, + { + RequestAdu { + hdr: self.next_request_hdr(self.unit_id), + pdu: req.into(), + } + } + + fn framed(&mut self) -> io::Result<&mut Framed> { + let Some(framed) = &mut self.framed else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); + }; + Ok(framed) + } + + pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { + log::debug!("Call {:?}", req); + + let req_function_code = req.function_code(); + let req_adu = self.next_request_adu(req); + let req_hdr = req_adu.hdr; + + let framed = self.framed()?; + + framed.read_buffer_mut().clear(); + framed.send(req_adu).await?; + + let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; + let ResponseAdu { + hdr: res_hdr, + pdu: res_pdu, + } = res_adu; + let ResponsePdu(result) = res_pdu; + + // Match headers of request and response. + if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { + return Err(ProtocolError::HeaderMismatch { message, result }.into()); + } + + // Match function codes of request and response. + let rsp_function_code = match &result { + Ok(response) => response.function_code(), + Err(ExceptionResponse { function, .. }) => *function, + }; + if req_function_code != rsp_function_code { + return Err(ProtocolError::FunctionCodeMismatch { + request: req_function_code, + result, + } + .into()); + } + + Ok(result.map_err( + |ExceptionResponse { + function: _, + exception, + }| exception, + )) + } + + async fn disconnect(&mut self) -> io::Result<()> { + let Some(framed) = self.framed.take() else { + // Already disconnected. + return Ok(()); + }; + disconnect_framed(framed).await + } +} + +impl SlaveContext for Client { + fn set_slave(&mut self, slave: Slave) { + self.unit_id = slave.into(); + } +} + +#[async_trait::async_trait] +impl crate::client::Client for Client +where + T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, +{ + async fn call(&mut self, req: Request<'_>) -> Result { + self.call(req).await + } + + async fn disconnect(&mut self) -> io::Result<()> { + self.disconnect().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_unit_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 5, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } + + #[test] + fn invalid_validate_not_same_transaction_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 86, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } +} diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 0745fe2..41f3a72 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -19,26 +19,6 @@ pub(crate) mod rtu; #[cfg(feature = "tcp")] pub(crate) mod tcp; -#[cfg(any(feature = "rtu", feature = "tcp"))] -pub(crate) async fn disconnect(framed: tokio_util::codec::Framed) -> std::io::Result<()> -where - T: tokio::io::AsyncWrite + Unpin, -{ - use tokio::io::AsyncWriteExt as _; - - framed - .into_inner() - .shutdown() - .await - .or_else(|err| match err.kind() { - std::io::ErrorKind::NotConnected | std::io::ErrorKind::BrokenPipe => { - // Already disconnected. - Ok(()) - } - _ => Err(err), - }) -} - #[allow(clippy::cast_possible_truncation)] fn u16_len(len: usize) -> u16 { // This type conversion should always be safe, because either diff --git a/src/frame/rtu.rs b/src/frame/rtu.rs index d7da018..0199298 100644 --- a/src/frame/rtu.rs +++ b/src/frame/rtu.rs @@ -3,7 +3,7 @@ use super::*; -use crate::{rtu::RequestContext, ProtocolError, Result, Slave}; +use crate::{client::rtu::RequestContext, ProtocolError, Result, Slave}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) struct Header { diff --git a/src/lib.rs b/src/lib.rs index 6451263..7ea11da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,12 +39,6 @@ pub mod client; pub mod slave; pub use self::slave::{Slave, SlaveId}; -#[cfg(feature = "rtu")] -pub mod rtu; - -#[cfg(feature = "tcp")] -pub mod tcp; - mod codec; mod error; diff --git a/src/prelude.rs b/src/prelude.rs index 35dcd3f..33c694e 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -12,7 +12,6 @@ pub use crate::client; #[cfg(feature = "rtu")] pub mod rtu { pub use crate::client::rtu::*; - pub use crate::rtu::{Client, ClientContext, RequestContext}; } #[allow(missing_docs)] diff --git a/src/rtu.rs b/src/rtu.rs deleted file mode 100644 index 666019e..0000000 --- a/src/rtu.rs +++ /dev/null @@ -1,252 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::{fmt, io}; - -use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use crate::{ - codec::{self, disconnect}, - frame::{rtu::*, *}, - slave::*, - Result, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct RequestContext { - pub(crate) function_code: FunctionCode, - pub(crate) header: Header, -} - -impl RequestContext { - #[must_use] - pub const fn function_code(&self) -> FunctionCode { - self.function_code - } -} - -/// _Modbus_ RTU client. -#[derive(Debug)] -pub struct Client { - framed: Framed, -} - -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub fn new(transport: T) -> Self { - let framed = Framed::new(transport, codec::rtu::ClientCodec::default()); - Self { framed } - } - - pub async fn disconnect(self) -> io::Result<()> { - let Self { framed } = self; - disconnect(framed).await - } - - pub async fn call<'a>(&mut self, server: Slave, request: Request<'a>) -> Result { - let request_context = self.send_request(server, request).await?; - self.recv_response(request_context).await - } - - pub async fn send_request<'a>( - &mut self, - server: Slave, - request: Request<'a>, - ) -> io::Result { - self.send_request_pdu(server, request).await - } - - async fn send_request_pdu<'a, R>( - &mut self, - server: Slave, - request_pdu: R, - ) -> io::Result - where - R: Into>, - { - let request_adu = request_adu(server, request_pdu); - self.send_request_adu(request_adu).await - } - - async fn send_request_adu<'a>( - &mut self, - request_adu: RequestAdu<'a>, - ) -> io::Result { - let request_context = request_adu.context(); - - self.framed.read_buffer_mut().clear(); - self.framed.send(request_adu).await?; - - Ok(request_context) - } - - pub async fn recv_response(&mut self, request_context: RequestContext) -> Result { - let response_adu = self - .framed - .next() - .await - .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; - - response_adu.try_into_response(request_context) - } -} - -/// _Modbus_ RTU client with (server) context and connection state. -/// -/// Client that invokes methods (request/response) on a single or many (broadcast) server(s). -/// -/// The server can be switched between method calls. -#[derive(Debug)] -pub struct ClientContext { - client: Option>, - server: Slave, -} - -impl ClientContext { - pub fn new(client: Client, server: Slave) -> Self { - Self { - client: Some(client), - server, - } - } - - #[must_use] - pub const fn is_connected(&self) -> bool { - self.client.is_some() - } - - #[must_use] - pub const fn server(&self) -> Slave { - self.server - } - - pub fn set_server(&mut self, server: Slave) { - self.server = server; - } -} - -impl ClientContext -where - T: AsyncWrite + Unpin, -{ - pub async fn disconnect(&mut self) -> io::Result<()> { - let Some(client) = self.client.take() else { - // Already disconnected. - return Ok(()); - }; - disconnect(client.framed).await - } -} - -impl ClientContext -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub async fn call(&mut self, request: Request<'_>) -> Result { - log::debug!("Call {:?}", request); - - let Some(client) = &mut self.client else { - return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected").into()); - }; - - client.call(self.server, request).await - } -} - -impl ClientContext -where - T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + Send + 'static, -{ - #[must_use] - pub fn boxed(self) -> Box { - Box::new(self) - } -} - -impl SlaveContext for ClientContext { - fn set_slave(&mut self, slave: Slave) { - self.set_server(slave); - } -} - -#[async_trait::async_trait] -impl crate::client::Client for ClientContext -where - T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, -{ - async fn call(&mut self, req: Request<'_>) -> Result { - self.call(req).await - } - - async fn disconnect(&mut self) -> io::Result<()> { - self.disconnect().await - } -} - -fn request_adu<'a, R>(server: Slave, request_pdu: R) -> RequestAdu<'a> -where - R: Into>, -{ - let hdr = Header { slave: server }; - let pdu = request_pdu.into(); - RequestAdu { hdr, pdu } -} - -#[cfg(test)] -mod tests { - use core::{ - pin::Pin, - task::{Context, Poll}, - }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; - - use crate::Error; - - use super::*; - - #[derive(Debug)] - struct MockTransport; - - impl Unpin for MockTransport {} - - impl AsyncRead for MockTransport { - fn poll_read( - self: Pin<&mut Self>, - _: &mut Context<'_>, - _: &mut ReadBuf<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - impl AsyncWrite for MockTransport { - fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { - Poll::Ready(Ok(2)) - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - unimplemented!() - } - } - - #[tokio::test] - async fn handle_broken_pipe() { - let transport = MockTransport; - let client = Client::new(transport); - let mut context = ClientContext::new(client, Slave::broadcast()); - let res = context.call(Request::ReadCoils(0x00, 5)).await; - assert!(res.is_err()); - let err = res.err().unwrap(); - assert!( - matches!(err, Error::Transport(err) if err.kind() == std::io::ErrorKind::BrokenPipe) - ); - } -} diff --git a/src/tcp.rs b/src/tcp.rs deleted file mode 100644 index f6914da..0000000 --- a/src/tcp.rs +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::{fmt, io}; - -use futures_util::{SinkExt as _, StreamExt as _}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use crate::{ - codec::{self, disconnect}, - frame::{ - tcp::{Header, RequestAdu, ResponseAdu, TransactionId, UnitId}, - verify_response_header, RequestPdu, ResponsePdu, - }, - slave::*, - ExceptionResponse, ProtocolError, Request, Response, Result, -}; - -const INITIAL_TRANSACTION_ID: TransactionId = 0; - -#[derive(Debug)] -struct TransactionIdGenerator { - next_transaction_id: TransactionId, -} - -impl TransactionIdGenerator { - const fn new() -> Self { - Self { - next_transaction_id: INITIAL_TRANSACTION_ID, - } - } - - fn next(&mut self) -> TransactionId { - let next_transaction_id = self.next_transaction_id; - self.next_transaction_id = next_transaction_id.wrapping_add(1); - next_transaction_id - } -} - -/// Modbus TCP client -#[derive(Debug)] -pub(crate) struct Client { - framed: Option>, - transaction_id_generator: TransactionIdGenerator, - unit_id: UnitId, -} - -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new(transport: T, slave: Slave) -> Self { - let framed = Framed::new(transport, codec::tcp::ClientCodec::new()); - let transaction_id_generator = TransactionIdGenerator::new(); - let unit_id: UnitId = slave.into(); - Self { - framed: Some(framed), - transaction_id_generator, - unit_id, - } - } - - fn next_request_hdr(&mut self, unit_id: UnitId) -> Header { - let transaction_id = self.transaction_id_generator.next(); - Header { - transaction_id, - unit_id, - } - } - - fn next_request_adu<'a, R>(&mut self, req: R) -> RequestAdu<'a> - where - R: Into>, - { - RequestAdu { - hdr: self.next_request_hdr(self.unit_id), - pdu: req.into(), - } - } - - fn framed(&mut self) -> io::Result<&mut Framed> { - let Some(framed) = &mut self.framed else { - return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); - }; - Ok(framed) - } - - pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { - log::debug!("Call {:?}", req); - - let req_function_code = req.function_code(); - let req_adu = self.next_request_adu(req); - let req_hdr = req_adu.hdr; - - let framed = self.framed()?; - - framed.read_buffer_mut().clear(); - framed.send(req_adu).await?; - - let res_adu = framed.next().await.ok_or_else(io::Error::last_os_error)??; - let ResponseAdu { - hdr: res_hdr, - pdu: res_pdu, - } = res_adu; - let ResponsePdu(result) = res_pdu; - - // Match headers of request and response. - if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { - return Err(ProtocolError::HeaderMismatch { message, result }.into()); - } - - // Match function codes of request and response. - let rsp_function_code = match &result { - Ok(response) => response.function_code(), - Err(ExceptionResponse { function, .. }) => *function, - }; - if req_function_code != rsp_function_code { - return Err(ProtocolError::FunctionCodeMismatch { - request: req_function_code, - result, - } - .into()); - } - - Ok(result.map_err( - |ExceptionResponse { - function: _, - exception, - }| exception, - )) - } - - async fn disconnect(&mut self) -> io::Result<()> { - let Some(framed) = self.framed.take() else { - // Already disconnected. - return Ok(()); - }; - disconnect(framed).await - } -} - -impl SlaveContext for Client { - fn set_slave(&mut self, slave: Slave) { - self.unit_id = slave.into(); - } -} - -#[async_trait::async_trait] -impl crate::client::Client for Client -where - T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, -{ - async fn call(&mut self, req: Request<'_>) -> Result { - self.call(req).await - } - - async fn disconnect(&mut self) -> io::Result<()> { - self.disconnect().await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_same_headers() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_ok()); - } - - #[test] - fn invalid_validate_not_same_unit_id() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 5, - transaction_id: 42, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } - - #[test] - fn invalid_validate_not_same_transaction_id() { - // Given - let req_hdr = Header { - unit_id: 0, - transaction_id: 42, - }; - let rsp_hdr = Header { - unit_id: 0, - transaction_id: 86, - }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } -}