diff --git a/examples/rtu-client.rs b/examples/rtu-client.rs index bf89654..0e6c5c2 100644 --- a/examples/rtu-client.rs +++ b/examples/rtu-client.rs @@ -15,13 +15,18 @@ async fn main() -> Result<(), Box> { let builder = tokio_serial::new(tty_path, 19200); let port = SerialStream::open(&builder).unwrap(); - let mut ctx = rtu::attach_slave(port, slave); + let mut conn = rtu::ClientConnection::new(port); println!("Reading a sensor value"); - let rsp = ctx.read_holding_registers(0x082B, 2).await??; - println!("Sensor value is: {rsp:?}"); + let request = Request::ReadHoldingRegisters(0x082B, 2); + let request_context = conn.send_request(request, slave).await?; + let Response::ReadHoldingRegisters(value) = conn.recv_response(request_context).await?? else { + // The response variant will always match its corresponding request variant if successful. + unreachable!(); + }; + println!("Sensor value is: {value:?}"); println!("Disconnecting"); - ctx.disconnect().await?; + conn.disconnect().await?; Ok(()) } diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 336c436..26ce98e 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -18,7 +18,7 @@ use crate::bytes::Bytes; /// A Modbus function code. /// /// All function codes as defined by the protocol specification V1.1b3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum FunctionCode { /// 01 (0x01) Read Coils. ReadCoils, @@ -586,6 +586,24 @@ impl error::Error for ExceptionResponse { } } +/// Check that `req_hdr` is the same `Header` as `rsp_hdr`. +/// +/// # Errors +/// +/// If the 2 headers are different, an error message with the details will be returned. +#[cfg(any(feature = "rtu", feature = "tcp"))] +pub(crate) fn verify_response_header( + req_hdr: &H, + rsp_hdr: &H, +) -> Result<(), String> { + if req_hdr != rsp_hdr { + return Err(format!( + "expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" + )); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/frame/rtu.rs b/src/frame/rtu.rs index b483fd9..0576928 100644 --- a/src/frame/rtu.rs +++ b/src/frame/rtu.rs @@ -3,9 +3,22 @@ use super::*; -use crate::slave::SlaveId; +use crate::{ProtocolError, Result, SlaveId}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RequestContext { + function_code: FunctionCode, + header: Header, +} + +impl RequestContext { + #[must_use] + pub const fn function_code(&self) -> FunctionCode { + self.function_code + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) struct Header { pub(crate) slave_id: SlaveId, } @@ -16,12 +29,60 @@ pub struct RequestAdu<'a> { pub(crate) pdu: RequestPdu<'a>, } +impl RequestAdu<'_> { + pub(crate) fn context(&self) -> RequestContext { + RequestContext { + function_code: self.pdu.0.function_code(), + header: self.hdr, + } + } +} + #[derive(Debug, Clone)] pub(crate) struct ResponseAdu { pub(crate) hdr: Header, pub(crate) pdu: ResponsePdu, } +impl ResponseAdu { + pub(crate) fn try_into_response(self, request_context: RequestContext) -> Result { + let RequestContext { + function_code: req_function_code, + header: req_hdr, + } = request_context; + + let ResponseAdu { + hdr: rsp_hdr, + pdu: rsp_pdu, + } = self; + let ResponsePdu(result) = rsp_pdu; + + if let Err(message) = verify_response_header(&req_hdr, &rsp_hdr) { + return Err(ProtocolError::HeaderMismatch { message, result }.into()); + } + + // Match function codes of request and response. + let rsp_function_code = match &result { + Ok(response) => response.function_code(), + Err(ExceptionResponse { function, .. }) => *function, + }; + if req_function_code != rsp_function_code { + return Err(ProtocolError::FunctionCodeMismatch { + request: req_function_code, + result, + } + .into()); + } + + Ok(result.map_err( + |ExceptionResponse { + function: _, + exception, + }| exception, + )) + } +} + impl<'a> From> for Request<'a> { fn from(from: RequestAdu<'a>) -> Self { from.pdu.into() @@ -37,3 +98,34 @@ impl<'a> From> for SlaveRequest<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { slave_id: 0 }; + let rsp_hdr = Header { slave_id: 0 }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_slave_id() { + // Given + let req_hdr = Header { slave_id: 0 }; + let rsp_hdr = Header { slave_id: 5 }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_err()); + } +} diff --git a/src/prelude.rs b/src/prelude.rs index 33c694e..f769509 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -12,6 +12,8 @@ pub use crate::client; #[cfg(feature = "rtu")] pub mod rtu { pub use crate::client::rtu::*; + pub use crate::frame::rtu::RequestContext; + pub use crate::service::rtu::ClientConnection; } #[allow(missing_docs)] diff --git a/src/service/mod.rs b/src/service/mod.rs index fb36240..48d7bd0 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -6,18 +6,3 @@ pub(crate) mod rtu; #[cfg(feature = "tcp")] pub(crate) mod tcp; - -/// Check that `req_hdr` is the same `Header` as `rsp_hdr`. -/// -/// # Errors -/// -/// If the 2 headers are different, an error message with the details will be returned. -#[cfg(any(feature = "rtu", feature = "tcp"))] -fn verify_response_header(req_hdr: &H, rsp_hdr: &H) -> Result<(), String> { - if req_hdr != rsp_hdr { - return Err(format!( - "expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" - )); - } - Ok(()) -} diff --git a/src/service/rtu.rs b/src/service/rtu.rs index 41a7016..a296ea6 100644 --- a/src/service/rtu.rs +++ b/src/service/rtu.rs @@ -11,102 +11,114 @@ use crate::{ codec, frame::{rtu::*, *}, slave::*, - ProtocolError, Result, + Result, }; -use super::verify_response_header; - -/// Modbus RTU client #[derive(Debug)] -pub(crate) struct Client { - framed: Option>, - slave_id: SlaveId, +pub struct ClientConnection { + framed: Framed, } -impl Client +impl ClientConnection where T: AsyncRead + AsyncWrite + Unpin, { - pub(crate) fn new(transport: T, slave: Slave) -> Self { + pub fn new(transport: T) -> Self { let framed = Framed::new(transport, codec::rtu::ClientCodec::default()); - let slave_id = slave.into(); - Self { - slave_id, - framed: Some(framed), - } + Self { framed } } - fn framed(&mut self) -> io::Result<&mut Framed> { - let Some(framed) = &mut self.framed else { - return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")); - }; - Ok(framed) + pub async fn disconnect(self) -> io::Result<()> { + let Self { framed } = self; + framed + .into_inner() + .shutdown() + .await + .or_else(|err| match err.kind() { + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => { + // Already disconnected. + Ok(()) + } + _ => Err(err), + }) } - fn next_request_adu<'a, R>(&self, req: R) -> RequestAdu<'a> + pub async fn send_request<'a>( + &mut self, + request: Request<'a>, + server: Slave, + ) -> io::Result { + self.send_request_pdu(request, server).await + } + + async fn send_request_pdu<'a, R>( + &mut self, + request: R, + server: Slave, + ) -> io::Result where R: Into>, { - let slave_id = self.slave_id; - let hdr = Header { slave_id }; - let pdu = req.into(); - RequestAdu { hdr, pdu } - } - - async fn call(&mut self, req: Request<'_>) -> Result { - log::debug!("Call {:?}", req); + let request_adu = request_adu(request, server); + let context = request_adu.context(); - let req_function_code = req.function_code(); - let req_adu = self.next_request_adu(req); - let req_hdr = req_adu.hdr; - - let framed = self.framed()?; + let Self { framed } = self; framed.read_buffer_mut().clear(); - framed.send(req_adu).await?; + framed.send(request_adu).await?; + + Ok(context) + } - let res_adu = framed + pub async fn recv_response(&mut self, request_context: RequestContext) -> Result { + let res_adu = self + .framed .next() .await .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; - let ResponseAdu { - hdr: res_hdr, - pdu: res_pdu, - } = res_adu; - let ResponsePdu(result) = res_pdu; - - // Match headers of request and response. - if let Err(message) = verify_response_header(&req_hdr, &res_hdr) { - return Err(ProtocolError::HeaderMismatch { message, result }.into()); - } - // Match function codes of request and response. - let rsp_function_code = match &result { - Ok(response) => response.function_code(), - Err(ExceptionResponse { function, .. }) => *function, - }; - if req_function_code != rsp_function_code { - return Err(ProtocolError::FunctionCodeMismatch { - request: req_function_code, - result, - } - .into()); - } + res_adu.try_into_response(request_context) + } +} - Ok(result.map_err( - |ExceptionResponse { - function: _, - exception, - }| exception, - )) +fn request_adu<'a, R>(req: R, server: Slave) -> RequestAdu<'a> +where + R: Into>, +{ + let hdr = Header { + slave_id: server.into(), + }; + let pdu = req.into(); + RequestAdu { hdr, pdu } +} + +/// Modbus RTU client +#[derive(Debug)] +pub(crate) struct Client { + connection: Option>, + slave_id: SlaveId, +} + +impl Client +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new(transport: T, slave: Slave) -> Self { + let connection = ClientConnection::new(transport); + let slave_id = slave.into(); + Self { + connection: Some(connection), + slave_id, + } } async fn disconnect(&mut self) -> io::Result<()> { - let Some(framed) = self.framed.take() else { + let Some(connection) = self.connection.take() else { // Already disconnected. return Ok(()); }; - framed + connection + .framed .into_inner() .shutdown() .await @@ -118,6 +130,19 @@ where _ => Err(err), }) } + + async fn call(&mut self, request: Request<'_>) -> Result { + log::debug!("Call {:?}", request); + + let Some(connection) = &mut self.connection else { + return Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected").into()); + }; + + let request_context = connection + .send_request(request, Slave(self.slave_id)) + .await?; + connection.recv_response(request_context).await + } } impl SlaveContext for Client { @@ -149,36 +174,7 @@ mod tests { }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; - use crate::{ - service::{rtu::Header, verify_response_header}, - Error, - }; - - #[test] - fn validate_same_headers() { - // Given - let req_hdr = Header { slave_id: 0 }; - let rsp_hdr = Header { slave_id: 0 }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_ok()); - } - - #[test] - fn invalid_validate_not_same_slave_id() { - // Given - let req_hdr = Header { slave_id: 0 }; - let rsp_hdr = Header { slave_id: 5 }; - - // When - let result = verify_response_header(&req_hdr, &rsp_hdr); - - // Then - assert!(result.is_err()); - } + use crate::Error; #[derive(Debug)] struct MockTransport; diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 0e2d08b..03309f4 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -11,9 +11,8 @@ use crate::{ codec, frame::{ tcp::{Header, RequestAdu, ResponseAdu, TransactionId, UnitId}, - RequestPdu, ResponsePdu, + verify_response_header, RequestPdu, ResponsePdu, }, - service::verify_response_header, slave::*, ExceptionResponse, ProtocolError, Request, Response, Result, }; diff --git a/src/slave.rs b/src/slave.rs index 0fa5b94..9dc7765 100644 --- a/src/slave.rs +++ b/src/slave.rs @@ -10,6 +10,7 @@ pub type SlaveId = u8; /// A single byte for addressing Modbus slave devices. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] pub struct Slave(pub SlaveId); impl Slave {