From f8a883406a7e2ad6b229a7eb30cef9ec83b0c5bc Mon Sep 17 00:00:00 2001 From: Benjamin Peter <152640169+creberust@users.noreply.github.com> Date: Fri, 15 Mar 2024 10:49:10 +0100 Subject: [PATCH] Client API: Handle Modbus exceptions (#248) --- CHANGELOG.md | 22 ++ examples/rtu-over-tcp-server.rs | 56 ++-- examples/rtu-server-address.rs | 14 +- examples/rtu-server.rs | 19 +- examples/tcp-client-custom-fn.rs | 2 +- examples/tcp-client.rs | 2 +- examples/tcp-server.rs | 58 ++-- examples/tls-client.rs | 2 +- examples/tls-server.rs | 56 ++-- src/client/mod.rs | 436 +++++++++++++++++++------------ src/client/sync/mod.rs | 24 +- src/client/sync/rtu.rs | 10 +- src/client/sync/tcp.rs | 13 +- src/client/tcp.rs | 6 +- src/error.rs | 20 ++ src/lib.rs | 9 + src/service/mod.rs | 21 ++ src/service/rtu.rs | 60 +++-- src/service/tcp.rs | 93 +++++-- tests/exception/mod.rs | 181 +++++-------- 20 files changed, 634 insertions(+), 470 deletions(-) create mode 100644 src/error.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 01aef383..44749b92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,28 @@ # Changelog +## v0.12.0 (unreleased) + +- Client: Support handling of _Modbus_ exceptions by using nested `Result`s. + +### Breaking Changes + +- Client: All methods in the `Client`, `Reader` and `Writer` traits now return + nested `Result` values that both need to be handled explicitly. + + ```diff + async fn read_coils(&mut self, _: Address, _: Quantity) + - -> Result, std::io::Error>; + + -> Result, Exception>, std::io::Error>; + ``` + +The type alias `tokio_modbus::Result` facilitates referencing the new return +types. + +```rust +pub type Result = Result, std::io::Error> +``` + ## v0.11.0 (2024-01-28) - Server: Remove `Sync` and `Unpin` trait bounds from `Service::call()` future diff --git a/examples/rtu-over-tcp-server.rs b/examples/rtu-over-tcp-server.rs index 857d0429..898929e7 100644 --- a/examples/rtu-over-tcp-server.rs +++ b/examples/rtu-over-tcp-server.rs @@ -33,37 +33,26 @@ impl tokio_modbus::server::Service for ExampleService { fn call(&self, req: Self::Request) -> Self::Future { println!("{}", req.slave); match req.request { - Request::ReadInputRegisters(addr, cnt) => { - match register_read(&self.input_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadInputRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::ReadHoldingRegisters(addr, cnt) => { - match register_read(&self.holding_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadHoldingRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteMultipleRegisters(addr, values) => { - match register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) { - Ok(_) => future::ready(Ok(Response::WriteMultipleRegisters( - addr, - values.len() as u16, - ))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteSingleRegister(addr, value) => { - match register_write( + Request::ReadInputRegisters(addr, cnt) => future::ready( + register_read(&self.input_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadInputRegisters), + ), + Request::ReadHoldingRegisters(addr, cnt) => future::ready( + register_read(&self.holding_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadHoldingRegisters), + ), + Request::WriteMultipleRegisters(addr, values) => future::ready( + register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) + .map(|_| Response::WriteMultipleRegisters(addr, values.len() as u16)), + ), + Request::WriteSingleRegister(addr, value) => future::ready( + register_write( &mut self.holding_registers.lock().unwrap(), addr, std::slice::from_ref(&value), - ) { - Ok(_) => future::ready(Ok(Response::WriteSingleRegister(addr, value))), - Err(err) => future::ready(Err(err)), - } - } + ) + .map(|_| Response::WriteSingleRegister(addr, value)), + ), _ => { println!("SERVER: Exception::IllegalFunction - Unimplemented function code in request: {req:?}"); future::ready(Err(Exception::IllegalFunction)) @@ -170,28 +159,27 @@ async fn client_context(socket_addr: SocketAddr) { println!("CLIENT: Reading 2 input registers..."); let response = ctx.read_input_registers(0x00, 2).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [1234, 5678]); + assert_eq!(response, Ok(vec![1234, 5678])); println!("CLIENT: Writing 2 holding registers..."); ctx.write_multiple_registers(0x01, &[7777, 8888]) .await + .unwrap() .unwrap(); // Read back a block including the two registers we wrote. println!("CLIENT: Reading 4 holding registers..."); let response = ctx.read_holding_registers(0x00, 4).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [10, 7777, 8888, 40]); + assert_eq!(response, Ok(vec![10, 7777, 8888, 40])); // Now we try to read with an invalid register address. // This should return a Modbus exception response with the code // IllegalDataAddress. println!("CLIENT: Reading nonexisting holding register address... (should return IllegalDataAddress)"); - let response = ctx.read_holding_registers(0x100, 1).await; + let response = ctx.read_holding_registers(0x100, 1).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert!(response.is_err()); - // TODO: How can Modbus client identify Modbus exception responses? E.g. here we expect IllegalDataAddress - // Question here: https://github.com/slowtec/tokio-modbus/issues/169 + assert_eq!(response, Err(Exception::IllegalDataAddress)); println!("CLIENT: Done.") }, diff --git a/examples/rtu-server-address.rs b/examples/rtu-server-address.rs index 2779f2f2..e558b633 100644 --- a/examples/rtu-server-address.rs +++ b/examples/rtu-server-address.rs @@ -51,12 +51,20 @@ async fn main() -> Result<(), Box> { // Give the server some time for stating up thread::sleep(Duration::from_secs(1)); - println!("Connecting client..."); + println!("CLIENT: Connecting client..."); let client_serial = tokio_serial::SerialStream::open(&builder).unwrap(); let mut ctx = rtu::attach_slave(client_serial, slave); - println!("Reading input registers..."); + println!("CLIENT: Reading input registers..."); let rsp = ctx.read_input_registers(0x00, 7).await?; - println!("The result is '{rsp:#x?}'"); // The result is '[0x0,0x0,0x77,0x0,0x0,0x0,0x0,]' + println!("CLIENT: The result is '{rsp:#x?}'"); + assert_eq!(rsp, Ok(vec![0x0, 0x0, 0x77, 0x0, 0x0, 0x0, 0x0])); + + println!("CLIENT: Reading with illegal function... (should return IllegalFunction)"); + let response = ctx.read_holding_registers(0x100, 1).await.unwrap(); + println!("CLIENT: The result is '{response:?}'"); + assert_eq!(response, Err(Exception::IllegalFunction)); + + println!("CLIENT: Done."); Ok(()) } diff --git a/examples/rtu-server.rs b/examples/rtu-server.rs index 4ea97fd6..d1001f94 100644 --- a/examples/rtu-server.rs +++ b/examples/rtu-server.rs @@ -20,6 +20,9 @@ impl tokio_modbus::server::Service for Service { registers[2] = 0x77; future::ready(Ok(Response::ReadInputRegisters(registers))) } + Request::ReadHoldingRegisters(_, _) => { + future::ready(Err(Exception::IllegalDataAddress)) + } _ => unimplemented!(), } } @@ -45,12 +48,22 @@ async fn main() -> Result<(), Box> { // Give the server some time for stating up thread::sleep(Duration::from_secs(1)); - println!("Connecting client..."); + println!("CLIENT: Connecting client..."); let client_serial = tokio_serial::SerialStream::open(&builder).unwrap(); let mut ctx = rtu::attach(client_serial); - println!("Reading input registers..."); + println!("CLIENT: Reading input registers..."); let rsp = ctx.read_input_registers(0x00, 7).await?; - println!("The result is '{rsp:#x?}'"); // The result is '[0x0,0x0,0x77,0x0,0x0,0x0,0x0,]' + println!("CLIENT: The result is '{rsp:#x?}'"); + assert_eq!(rsp, Ok(vec![0x0, 0x0, 0x77, 0x0, 0x0, 0x0, 0x0])); + + // Now we try to read with an invalid register address. + // This should return a Modbus exception response with the code + // IllegalDataAddress. + println!("CLIENT: Reading nonexisting holding register address... (should return IllegalDataAddress)"); + let response = ctx.read_holding_registers(0x100, 1).await.unwrap(); + println!("CLIENT: The result is '{response:?}'"); + assert_eq!(response, Err(Exception::IllegalDataAddress)); + println!("CLIENT: Done."); Ok(()) } diff --git a/examples/tcp-client-custom-fn.rs b/examples/tcp-client-custom-fn.rs index c0f690d7..610d6712 100644 --- a/examples/tcp-client-custom-fn.rs +++ b/examples/tcp-client-custom-fn.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), Box> { println!("Fetching the coupler ID"); let rsp = ctx .call(Request::Custom(0x66, Cow::Borrowed(&[0x11, 0x42]))) - .await?; + .await??; match rsp { Response::Custom(f, rsp) => { diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs index f35948c7..36e06bbb 100644 --- a/examples/tcp-client.rs +++ b/examples/tcp-client.rs @@ -12,7 +12,7 @@ async fn main() -> Result<(), Box> { let mut ctx = tcp::connect(socket_addr).await?; println!("Fetching the coupler ID"); - let data = ctx.read_input_registers(0x1000, 7).await?; + let data = ctx.read_input_registers(0x1000, 7).await??; let bytes: Vec = data.iter().fold(vec![], |mut x, elem| { x.push((elem & 0xff) as u8); diff --git a/examples/tcp-server.rs b/examples/tcp-server.rs index 2f3565cb..2466e3c3 100644 --- a/examples/tcp-server.rs +++ b/examples/tcp-server.rs @@ -32,37 +32,26 @@ impl tokio_modbus::server::Service for ExampleService { fn call(&self, req: Self::Request) -> Self::Future { match req { - Request::ReadInputRegisters(addr, cnt) => { - match register_read(&self.input_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadInputRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::ReadHoldingRegisters(addr, cnt) => { - match register_read(&self.holding_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadHoldingRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteMultipleRegisters(addr, values) => { - match register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) { - Ok(_) => future::ready(Ok(Response::WriteMultipleRegisters( - addr, - values.len() as u16, - ))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteSingleRegister(addr, value) => { - match register_write( + Request::ReadInputRegisters(addr, cnt) => future::ready( + register_read(&self.input_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadInputRegisters), + ), + Request::ReadHoldingRegisters(addr, cnt) => future::ready( + register_read(&self.holding_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadHoldingRegisters), + ), + Request::WriteMultipleRegisters(addr, values) => future::ready( + register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) + .map(|_| Response::WriteMultipleRegisters(addr, values.len() as u16)), + ), + Request::WriteSingleRegister(addr, value) => future::ready( + register_write( &mut self.holding_registers.lock().unwrap(), addr, std::slice::from_ref(&value), - ) { - Ok(_) => future::ready(Ok(Response::WriteSingleRegister(addr, value))), - Err(err) => future::ready(Err(err)), - } - } + ) + .map(|_| Response::WriteSingleRegister(addr, value)), + ), _ => { println!("SERVER: Exception::IllegalFunction - Unimplemented function code in request: {req:?}"); future::ready(Err(Exception::IllegalFunction)) @@ -101,7 +90,6 @@ fn register_read( if let Some(r) = registers.get(®_addr) { response_values[i as usize] = *r; } else { - // TODO: Return a Modbus Exception response `IllegalDataAddress` https://github.com/slowtec/tokio-modbus/issues/165 println!("SERVER: Exception::IllegalDataAddress"); return Err(Exception::IllegalDataAddress); } @@ -122,7 +110,6 @@ fn register_write( if let Some(r) = registers.get_mut(®_addr) { *r = *value; } else { - // TODO: Return a Modbus Exception response `IllegalDataAddress` https://github.com/slowtec/tokio-modbus/issues/165 println!("SERVER: Exception::IllegalDataAddress"); return Err(Exception::IllegalDataAddress); } @@ -170,28 +157,27 @@ async fn client_context(socket_addr: SocketAddr) { println!("CLIENT: Reading 2 input registers..."); let response = ctx.read_input_registers(0x00, 2).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [1234, 5678]); + assert_eq!(response, Ok(vec![1234, 5678])); println!("CLIENT: Writing 2 holding registers..."); ctx.write_multiple_registers(0x01, &[7777, 8888]) .await + .unwrap() .unwrap(); // Read back a block including the two registers we wrote. println!("CLIENT: Reading 4 holding registers..."); let response = ctx.read_holding_registers(0x00, 4).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [10, 7777, 8888, 40]); + assert_eq!(response, Ok(vec![10, 7777, 8888, 40])); // Now we try to read with an invalid register address. // This should return a Modbus exception response with the code // IllegalDataAddress. println!("CLIENT: Reading nonexisting holding register address... (should return IllegalDataAddress)"); - let response = ctx.read_holding_registers(0x100, 1).await; + let response = ctx.read_holding_registers(0x100, 1).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert!(response.is_err()); - // TODO: How can Modbus client identify Modbus exception responses? E.g. here we expect IllegalDataAddress - // Question here: https://github.com/slowtec/tokio-modbus/issues/169 + assert_eq!(response, Err(Exception::IllegalDataAddress)); println!("CLIENT: Done.") }, diff --git a/examples/tls-client.rs b/examples/tls-client.rs index 918ca901..f37fe8d0 100644 --- a/examples/tls-client.rs +++ b/examples/tls-client.rs @@ -114,7 +114,7 @@ async fn main() -> Result<(), Box> { println!("Reading Holding Registers"); let data = ctx.read_holding_registers(40000, 68).await?; println!("Holding Registers Data is '{:?}'", data); - ctx.disconnect().await?; + ctx.disconnect().await??; Ok(()) } diff --git a/examples/tls-server.rs b/examples/tls-server.rs index a6a809a8..6f4d83e2 100644 --- a/examples/tls-server.rs +++ b/examples/tls-server.rs @@ -82,37 +82,26 @@ impl tokio_modbus::server::Service for ExampleService { fn call(&self, req: Self::Request) -> Self::Future { match req { - Request::ReadInputRegisters(addr, cnt) => { - match register_read(&self.input_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadInputRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::ReadHoldingRegisters(addr, cnt) => { - match register_read(&self.holding_registers.lock().unwrap(), addr, cnt) { - Ok(values) => future::ready(Ok(Response::ReadHoldingRegisters(values))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteMultipleRegisters(addr, values) => { - match register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) { - Ok(_) => future::ready(Ok(Response::WriteMultipleRegisters( - addr, - values.len() as u16, - ))), - Err(err) => future::ready(Err(err)), - } - } - Request::WriteSingleRegister(addr, value) => { - match register_write( + Request::ReadInputRegisters(addr, cnt) => future::ready( + register_read(&self.input_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadInputRegisters), + ), + Request::ReadHoldingRegisters(addr, cnt) => future::ready( + register_read(&self.holding_registers.lock().unwrap(), addr, cnt) + .map(Response::ReadHoldingRegisters), + ), + Request::WriteMultipleRegisters(addr, values) => future::ready( + register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) + .map(|_| Response::WriteMultipleRegisters(addr, values.len() as u16)), + ), + Request::WriteSingleRegister(addr, value) => future::ready( + register_write( &mut self.holding_registers.lock().unwrap(), addr, std::slice::from_ref(&value), - ) { - Ok(_) => future::ready(Ok(Response::WriteSingleRegister(addr, value))), - Err(err) => future::ready(Err(err)), - } - } + ) + .map(|_| Response::WriteSingleRegister(addr, value)), + ), _ => { println!("SERVER: Exception::IllegalFunction - Unimplemented function code in request: {req:?}"); future::ready(Err(Exception::IllegalFunction)) @@ -273,28 +262,27 @@ async fn client_context(socket_addr: SocketAddr) { println!("CLIENT: Reading 2 input registers..."); let response = ctx.read_input_registers(0x00, 2).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [1234, 5678]); + assert_eq!(response, Ok(vec![1234, 5678])); println!("CLIENT: Writing 2 holding registers..."); ctx.write_multiple_registers(0x01, &[7777, 8888]) .await + .unwrap() .unwrap(); // Read back a block including the two registers we wrote. println!("CLIENT: Reading 4 holding registers..."); let response = ctx.read_holding_registers(0x00, 4).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert_eq!(response, [10, 7777, 8888, 40]); + assert_eq!(response, Ok(vec![10, 7777, 8888, 40])); // Now we try to read with an invalid register address. // This should return a Modbus exception response with the code // IllegalDataAddress. println!("CLIENT: Reading nonexisting holding register address... (should return IllegalDataAddress)"); - let response = ctx.read_holding_registers(0x100, 1).await; + let response = ctx.read_holding_registers(0x100, 1).await.unwrap(); println!("CLIENT: The result is '{response:?}'"); - assert!(response.is_err()); - // TODO: How can Modbus client identify Modbus exception responses? E.g. here we expect IllegalDataAddress - // Question here: https://github.com/slowtec/tokio-modbus/issues/169 + assert_eq!(response, Err(Exception::IllegalDataAddress)); println!("CLIENT: Done.") }, diff --git a/src/client/mod.rs b/src/client/mod.rs index e89149e9..dc7cf72f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,15 +3,11 @@ //! Modbus clients -use std::{ - borrow::Cow, - fmt::Debug, - io::{Error, ErrorKind}, -}; +use std::{borrow::Cow, fmt::Debug, io}; use async_trait::async_trait; -use crate::{frame::*, slave::*}; +use crate::{error::unexpected_rsp_code_panic_msg, frame::*, slave::*, Result}; #[cfg(feature = "rtu")] pub mod rtu; @@ -26,24 +22,23 @@ pub mod sync; #[async_trait] pub trait Client: SlaveContext + Send + Debug { /// Invoke a Modbus function - async fn call(&mut self, request: Request<'_>) -> Result; + async fn call(&mut self, request: Request<'_>) -> Result; } /// Asynchronous Modbus reader #[async_trait] pub trait Reader: Client { /// Read multiple coils (0x01) - async fn read_coils(&mut self, _: Address, _: Quantity) -> Result, Error>; + async fn read_coils(&mut self, addr: Address, cnt: Quantity) -> Result>; /// Read multiple discrete inputs (0x02) - async fn read_discrete_inputs(&mut self, _: Address, _: Quantity) -> Result, Error>; + async fn read_discrete_inputs(&mut self, addr: Address, cnt: Quantity) -> Result>; /// Read multiple holding registers (0x03) - async fn read_holding_registers(&mut self, _: Address, _: Quantity) - -> Result, Error>; + async fn read_holding_registers(&mut self, addr: Address, cnt: Quantity) -> Result>; /// Read multiple input registers (0x04) - async fn read_input_registers(&mut self, _: Address, _: Quantity) -> Result, Error>; + async fn read_input_registers(&mut self, addr: Address, cnt: Quantity) -> Result>; /// Read and write multiple holding registers (0x17) /// @@ -55,27 +50,31 @@ pub trait Reader: Client { read_count: Quantity, write_addr: Address, write_data: &[Word], - ) -> Result, Error>; + ) -> Result>; } /// Asynchronous Modbus writer #[async_trait] pub trait Writer: Client { /// Write a single coil (0x05) - async fn write_single_coil(&mut self, _: Address, _: Coil) -> Result<(), Error>; + async fn write_single_coil(&mut self, addr: Address, coil: Coil) -> Result<()>; /// Write a single holding register (0x06) - async fn write_single_register(&mut self, _: Address, _: Word) -> Result<(), Error>; + async fn write_single_register(&mut self, addr: Address, word: Word) -> Result<()>; /// Write multiple coils (0x0F) - async fn write_multiple_coils(&mut self, addr: Address, data: &'_ [Coil]) -> Result<(), Error>; + async fn write_multiple_coils(&mut self, addr: Address, coils: &'_ [Coil]) -> Result<()>; /// Write multiple holding registers (0x10) - async fn write_multiple_registers(&mut self, addr: Address, data: &[Word]) - -> Result<(), Error>; + async fn write_multiple_registers(&mut self, addr: Address, words: &[Word]) -> Result<()>; /// Set or clear individual bits of a holding register (0x16) - async fn masked_write_register(&mut self, _: Address, _: Word, _: Word) -> Result<(), Error>; + async fn masked_write_register( + &mut self, + addr: Address, + and_mask: Word, + or_mask: Word, + ) -> Result<()>; } /// Asynchronous Modbus client context @@ -86,13 +85,13 @@ pub struct Context { impl Context { /// Disconnect the client - pub async fn disconnect(&mut self) -> Result<(), Error> { + pub async fn disconnect(&mut self) -> Result<()> { // Disconnecting is expected to fail! let res = self.client.call(Request::Disconnect).await; match res { Ok(_) => unreachable!(), Err(err) => match err.kind() { - ErrorKind::NotConnected | ErrorKind::BrokenPipe => Ok(()), + io::ErrorKind::NotConnected | io::ErrorKind::BrokenPipe => Ok(Ok(())), _ => Err(err), }, } @@ -113,7 +112,7 @@ impl From for Box { #[async_trait] impl Client for Context { - async fn call(&mut self, request: Request<'_>) -> Result { + async fn call(&mut self, request: Request<'_>) -> Result { self.client.call(request).await } } @@ -126,79 +125,118 @@ impl SlaveContext for Context { #[async_trait] impl Reader for Context { - async fn read_coils<'a>( - &'a mut self, - addr: Address, - cnt: Quantity, - ) -> Result, Error> { - let rsp = self.client.call(Request::ReadCoils(addr, cnt)).await?; - - if let Response::ReadCoils(mut coils) = rsp { - debug_assert!(coils.len() >= cnt.into()); - coils.truncate(cnt.into()); - Ok(coils) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + async fn read_coils<'a>(&'a mut self, addr: Address, cnt: Quantity) -> Result> { + self.client + .call(Request::ReadCoils(addr, cnt)) + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::ReadCoils(mut coils) => { + debug_assert!(coils.len() >= cnt.into()); + coils.truncate(cnt.into()); + coils + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::ReadCoils, + others.function_code() + ), + ) + } + }) + }) } async fn read_discrete_inputs<'a>( &'a mut self, addr: Address, cnt: Quantity, - ) -> Result, Error> { - let rsp = self - .client + ) -> Result> { + self.client .call(Request::ReadDiscreteInputs(addr, cnt)) - .await?; - - if let Response::ReadDiscreteInputs(mut coils) = rsp { - debug_assert!(coils.len() >= cnt.into()); - coils.truncate(cnt.into()); - Ok(coils) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::ReadDiscreteInputs(mut coils) => { + debug_assert!(coils.len() >= cnt.into()); + coils.truncate(cnt.into()); + coils + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::ReadDiscreteInputs, + others.function_code() + ), + ) + } + }) + }) } async fn read_input_registers<'a>( &'a mut self, addr: Address, cnt: Quantity, - ) -> Result, Error> { - let rsp = self - .client + ) -> Result> { + self.client .call(Request::ReadInputRegisters(addr, cnt)) - .await?; - - if let Response::ReadInputRegisters(rsp) = rsp { - if rsp.len() != cnt.into() { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(rsp) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::ReadInputRegisters(words) => { + debug_assert_eq!(words.len(), cnt.into()); + words + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::ReadInputRegisters, + others.function_code() + ), + ) + } + }) + }) } async fn read_holding_registers<'a>( &'a mut self, addr: Address, cnt: Quantity, - ) -> Result, Error> { - let rsp = self - .client + ) -> Result> { + self.client .call(Request::ReadHoldingRegisters(addr, cnt)) - .await?; - - if let Response::ReadHoldingRegisters(rsp) = rsp { - if rsp.len() != cnt.into() { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(rsp) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::ReadHoldingRegisters(words) => { + debug_assert_eq!(words.len(), cnt.into()); + words + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::ReadHoldingRegisters, + others.function_code() + ), + ) + } + }) + }) } async fn read_write_multiple_registers<'a>( @@ -207,127 +245,176 @@ impl Reader for Context { read_count: Quantity, write_addr: Address, write_data: &[Word], - ) -> Result, Error> { - let rsp = self - .client + ) -> Result> { + self.client .call(Request::ReadWriteMultipleRegisters( read_addr, read_count, write_addr, Cow::Borrowed(write_data), )) - .await?; - - if let Response::ReadWriteMultipleRegisters(rsp) = rsp { - if rsp.len() != read_count.into() { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(rsp) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::ReadWriteMultipleRegisters(words) => { + debug_assert_eq!(words.len(), read_count.into()); + words + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::ReadWriteMultipleRegisters, + others.function_code() + ), + ) + } + }) + }) } } #[async_trait] impl Writer for Context { - async fn write_single_coil<'a>(&'a mut self, addr: Address, coil: Coil) -> Result<(), Error> { - let rsp = self - .client + async fn write_single_coil<'a>(&'a mut self, addr: Address, coil: Coil) -> Result<()> { + self.client .call(Request::WriteSingleCoil(addr, coil)) - .await?; - - if let Response::WriteSingleCoil(rsp_addr, rsp_coil) = rsp { - if rsp_addr != addr || rsp_coil != coil { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(()) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::WriteSingleCoil(rsp_addr, rsp_coil) => { + debug_assert_eq!(addr, rsp_addr); + debug_assert_eq!(coil, rsp_coil); + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::WriteSingleCoil, + others.function_code() + ), + ) + } + }) + }) } - async fn write_multiple_coils<'a>( - &'a mut self, - addr: Address, - coils: &[Coil], - ) -> Result<(), Error> { + async fn write_multiple_coils<'a>(&'a mut self, addr: Address, coils: &[Coil]) -> Result<()> { let cnt = coils.len(); - let rsp = self - .client - .call(Request::WriteMultipleCoils(addr, Cow::Borrowed(coils))) - .await?; - if let Response::WriteMultipleCoils(rsp_addr, rsp_cnt) = rsp { - if rsp_addr != addr || usize::from(rsp_cnt) != cnt { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(()) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + self.client + .call(Request::WriteMultipleCoils(addr, Cow::Borrowed(coils))) + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::WriteMultipleCoils(rsp_addr, rsp_cnt) => { + debug_assert_eq!(addr, rsp_addr); + debug_assert_eq!(cnt, rsp_cnt.into()); + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::WriteMultipleCoils, + others.function_code() + ), + ) + } + }) + }) } - async fn write_single_register<'a>( - &'a mut self, - addr: Address, - data: Word, - ) -> Result<(), Error> { - let rsp = self - .client - .call(Request::WriteSingleRegister(addr, data)) - .await?; - - if let Response::WriteSingleRegister(rsp_addr, rsp_word) = rsp { - if rsp_addr != addr || rsp_word != data { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(()) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + async fn write_single_register<'a>(&'a mut self, addr: Address, word: Word) -> Result<()> { + self.client + .call(Request::WriteSingleRegister(addr, word)) + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::WriteSingleRegister(rsp_addr, rsp_word) => { + debug_assert_eq!(addr, rsp_addr); + debug_assert_eq!(word, rsp_word); + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::WriteSingleRegister, + others.function_code() + ), + ) + } + }) + }) } async fn write_multiple_registers<'a>( &'a mut self, addr: Address, data: &[Word], - ) -> Result<(), Error> { + ) -> Result<()> { let cnt = data.len(); - let rsp = self - .client - .call(Request::WriteMultipleRegisters(addr, Cow::Borrowed(data))) - .await?; - if let Response::WriteMultipleRegisters(rsp_addr, rsp_cnt) = rsp { - if rsp_addr != addr || usize::from(rsp_cnt) != cnt { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(()) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + self.client + .call(Request::WriteMultipleRegisters(addr, Cow::Borrowed(data))) + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::WriteMultipleRegisters(rsp_addr, rsp_cnt) => { + debug_assert_eq!(addr, rsp_addr); + debug_assert_eq!(cnt, rsp_cnt.into()); + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::WriteMultipleRegisters, + others.function_code() + ), + ) + } + }) + }) } async fn masked_write_register<'a>( &'a mut self, - address: Address, + addr: Address, and_mask: Word, or_mask: Word, - ) -> Result<(), Error> { - let rsp = self - .client - .call(Request::MaskWriteRegister(address, and_mask, or_mask)) - .await?; - - if let Response::MaskWriteRegister(addr, and, or) = rsp { - if addr != address || and != and_mask || or != or_mask { - return Err(Error::new(ErrorKind::InvalidData, "invalid response")); - } - Ok(()) - } else { - Err(Error::new(ErrorKind::InvalidData, "unexpected response")) - } + ) -> Result<()> { + self.client + .call(Request::MaskWriteRegister(addr, and_mask, or_mask)) + .await + .map(|modbus_rsp| { + modbus_rsp.map(|rsp| match rsp { + Response::MaskWriteRegister(rsp_addr, rsp_and_mask, rsp_or_mask) => { + debug_assert_eq!(addr, rsp_addr); + debug_assert_eq!(and_mask, rsp_and_mask); + debug_assert_eq!(or_mask, rsp_or_mask); + } + others => { + // NOTE: A call to `Client::call` implementation *MUST* always return the `Response` variant matching the `Request` one. + // TIPS: This can be ensured via a call to `verify_response_header`( in 'src/service/mod.rs') before returning from `Client::call`. + unreachable!( + "{}", + unexpected_rsp_code_panic_msg( + FunctionCode::MaskWriteRegister, + others.function_code() + ), + ) + } + }) + }) } } @@ -340,7 +427,7 @@ mod tests { pub(crate) struct ClientMock { slave: Option, last_request: Mutex>>, - next_response: Option>, + next_response: Option>, } #[allow(dead_code)] @@ -353,18 +440,18 @@ mod tests { &self.last_request } - pub(crate) fn set_next_response(&mut self, next_response: Result) { + pub(crate) fn set_next_response(&mut self, next_response: Result) { self.next_response = Some(next_response); } } #[async_trait] impl Client for ClientMock { - async fn call(&mut self, request: Request<'_>) -> Result { + async fn call(&mut self, request: Request<'_>) -> Result { *self.last_request.lock().unwrap() = Some(request.into_owned()); match self.next_response.as_ref().unwrap() { Ok(response) => Ok(response.clone()), - Err(err) => Err(Error::new(err.kind(), format!("{err}"))), + Err(err) => Err(io::Error::new(err.kind(), format!("{err}"))), } } } @@ -382,10 +469,12 @@ mod tests { let response_coils = [true, false, false, true, false, true, false, true]; for num_coils in 1..8 { let mut client = Box::::default(); - client.set_next_response(Ok(Response::ReadCoils(response_coils.to_vec()))); + client.set_next_response(Ok(Ok(Response::ReadCoils(response_coils.to_vec())))); let mut context = Context { client }; context.set_slave(Slave(1)); - let coils = futures::executor::block_on(context.read_coils(1, num_coils)).unwrap(); + let coils = futures::executor::block_on(context.read_coils(1, num_coils)) + .unwrap() + .unwrap(); assert_eq!(&response_coils[0..num_coils as usize], &coils[..]); } } @@ -397,11 +486,14 @@ mod tests { let response_inputs = [true, false, false, true, false, true, false, true]; for num_inputs in 1..8 { let mut client = Box::::default(); - client.set_next_response(Ok(Response::ReadDiscreteInputs(response_inputs.to_vec()))); + client.set_next_response(Ok(Ok(Response::ReadDiscreteInputs( + response_inputs.to_vec(), + )))); let mut context = Context { client }; context.set_slave(Slave(1)); - let inputs = - futures::executor::block_on(context.read_discrete_inputs(1, num_inputs)).unwrap(); + let inputs = futures::executor::block_on(context.read_discrete_inputs(1, num_inputs)) + .unwrap() + .unwrap(); assert_eq!(&response_inputs[0..num_inputs as usize], &inputs[..]); } } diff --git a/src/client/sync/mod.rs b/src/client/sync/mod.rs index d2898a80..2686561b 100644 --- a/src/client/sync/mod.rs +++ b/src/client/sync/mod.rs @@ -9,11 +9,11 @@ pub mod rtu; #[cfg(feature = "tcp-sync")] pub mod tcp; -use std::{future::Future, io::Result, time::Duration}; +use std::{future::Future, io, time::Duration}; use futures_util::future::Either; -use crate::{frame::*, slave::*}; +use crate::{frame::*, slave::*, Result}; use super::{ Client as AsyncClient, Context as AsyncContext, Reader as AsyncReader, SlaveContext, @@ -23,8 +23,8 @@ use super::{ fn block_on_with_timeout( runtime: &tokio::runtime::Runtime, timeout: Option, - task: impl Future>, -) -> Result { + task: impl Future>, +) -> io::Result { let task = if let Some(duration) = timeout { Either::Left(async move { tokio::time::timeout(duration, task) @@ -48,10 +48,10 @@ pub trait Client: SlaveContext { /// /// The synchronous counterpart of the asynchronous [`Reader`](`crate::client::Reader`) trait. pub trait Reader: Client { - fn read_coils(&mut self, _: Address, _: Quantity) -> Result>; - fn read_discrete_inputs(&mut self, _: Address, _: Quantity) -> Result>; - fn read_input_registers(&mut self, _: Address, _: Quantity) -> Result>; - fn read_holding_registers(&mut self, _: Address, _: Quantity) -> Result>; + fn read_coils(&mut self, addr: Address, cnt: Quantity) -> Result>; + fn read_discrete_inputs(&mut self, addr: Address, cnt: Quantity) -> Result>; + fn read_input_registers(&mut self, addr: Address, cnt: Quantity) -> Result>; + fn read_holding_registers(&mut self, addr: Address, cnt: Quantity) -> Result>; fn read_write_multiple_registers( &mut self, read_addr: Address, @@ -65,10 +65,10 @@ pub trait Reader: Client { /// /// The synchronous counterpart of the asynchronous [`Writer`](`crate::client::Writer`) trait. pub trait Writer: Client { - fn write_single_coil(&mut self, _: Address, _: Coil) -> Result<()>; - fn write_multiple_coils(&mut self, addr: Address, data: &[Coil]) -> Result<()>; - fn write_single_register(&mut self, _: Address, _: Word) -> Result<()>; - fn write_multiple_registers(&mut self, addr: Address, data: &[Word]) -> Result<()>; + fn write_single_coil(&mut self, addr: Address, coil: Coil) -> Result<()>; + fn write_multiple_coils(&mut self, addr: Address, coils: &[Coil]) -> Result<()>; + fn write_single_register(&mut self, addr: Address, word: Word) -> Result<()>; + fn write_multiple_registers(&mut self, addr: Address, words: &[Word]) -> Result<()>; } /// A synchronous Modbus client context. diff --git a/src/client/sync/rtu.rs b/src/client/sync/rtu.rs index 2603709d..c311137e 100644 --- a/src/client/sync/rtu.rs +++ b/src/client/sync/rtu.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::{io::Result, time::Duration}; +use std::{io, time::Duration}; use super::{block_on_with_timeout, Context}; @@ -11,7 +11,7 @@ use crate::slave::Slave; /// Connect to no particular Modbus slave device for sending /// broadcast messages. -pub fn connect(builder: &SerialPortBuilder) -> Result { +pub fn connect(builder: &SerialPortBuilder) -> io::Result { connect_slave(builder, Slave::broadcast()) } @@ -20,12 +20,12 @@ pub fn connect(builder: &SerialPortBuilder) -> Result { pub fn connect_with_timeout( builder: &SerialPortBuilder, timeout: Option, -) -> Result { +) -> io::Result { connect_slave_with_timeout(builder, Slave::broadcast(), timeout) } /// Connect to any kind of Modbus slave device. -pub fn connect_slave(builder: &SerialPortBuilder, slave: Slave) -> Result { +pub fn connect_slave(builder: &SerialPortBuilder, slave: Slave) -> io::Result { connect_slave_with_timeout(builder, slave, None) } @@ -34,7 +34,7 @@ pub fn connect_slave_with_timeout( builder: &SerialPortBuilder, slave: Slave, timeout: Option, -) -> Result { +) -> io::Result { let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() diff --git a/src/client/sync/tcp.rs b/src/client/sync/tcp.rs index 51a8fd28..4a7d6e11 100644 --- a/src/client/sync/tcp.rs +++ b/src/client/sync/tcp.rs @@ -3,26 +3,29 @@ //! TCP client connections -use std::{io::Result, net::SocketAddr, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use crate::{client::tcp::connect_slave as async_connect_slave, slave::Slave}; use super::{block_on_with_timeout, Context}; /// Establish a direct connection to a Modbus TCP coupler. -pub fn connect(socket_addr: SocketAddr) -> Result { +pub fn connect(socket_addr: SocketAddr) -> io::Result { connect_slave(socket_addr, Slave::tcp_device()) } /// Establish a direct connection to a Modbus TCP coupler with a timeout. -pub fn connect_with_timeout(socket_addr: SocketAddr, timeout: Option) -> Result { +pub fn connect_with_timeout( + socket_addr: SocketAddr, + timeout: Option, +) -> io::Result { connect_slave_with_timeout(socket_addr, Slave::tcp_device(), timeout) } /// Connect to any kind of Modbus slave device, probably through a Modbus TCP/RTU /// gateway that is forwarding messages to/from the corresponding unit identified /// by the slave parameter. -pub fn connect_slave(socket_addr: SocketAddr, slave: Slave) -> Result { +pub fn connect_slave(socket_addr: SocketAddr, slave: Slave) -> io::Result { connect_slave_with_timeout(socket_addr, slave, None) } @@ -33,7 +36,7 @@ pub fn connect_slave_with_timeout( socket_addr: SocketAddr, slave: Slave, timeout: Option, -) -> Result { +) -> io::Result { let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() diff --git a/src/client/tcp.rs b/src/client/tcp.rs index 90b6257b..acd0df56 100644 --- a/src/client/tcp.rs +++ b/src/client/tcp.rs @@ -3,7 +3,7 @@ //! TCP client connections -use std::{fmt, io::Error, net::SocketAddr}; +use std::{fmt, io, net::SocketAddr}; use tokio::{ io::{AsyncRead, AsyncWrite}, @@ -13,14 +13,14 @@ use tokio::{ use super::*; /// Establish a direct connection to a Modbus TCP coupler. -pub async fn connect(socket_addr: SocketAddr) -> Result { +pub async fn connect(socket_addr: SocketAddr) -> io::Result { connect_slave(socket_addr, Slave::tcp_device()).await } /// Connect to a physical, broadcast, or custom Modbus device, /// probably through a Modbus TCP gateway that is forwarding /// messages to/from the corresponding slave device. -pub async fn connect_slave(socket_addr: SocketAddr, slave: Slave) -> Result { +pub async fn connect_slave(socket_addr: SocketAddr, slave: Slave) -> io::Result { let transport = TcpStream::connect(socket_addr).await?; let context = attach_slave(transport, slave); Ok(context) diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..22bd892b --- /dev/null +++ b/src/error.rs @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! Modbus Error helpers. + +use crate::FunctionCode; + +/// Message to show when a bug has been found during runtime execution. +const REPORT_ISSUE_MSG: &str = + "Please report the issue at `https://github.com/slowtec/tokio-modbus/issues` with a minimal example reproducing this bug."; + +/// Create a panic message for `unexpected response code` with `req_code` and `rsp_code`. +pub(crate) fn unexpected_rsp_code_panic_msg( + req_code: FunctionCode, + rsp_code: FunctionCode, +) -> String { + format!( + "unexpected response code: {rsp_code} (request code: {req_code})\nnote: {REPORT_ISSUE_MSG}" + ) +} diff --git a/src/lib.rs b/src/lib.rs index 0b0462de..9127c0a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,10 +40,19 @@ pub mod slave; pub use self::slave::{Slave, SlaveId}; mod codec; +mod error; mod frame; pub use self::frame::{Address, Exception, FunctionCode, Quantity, Request, Response}; +/// Specialized [`std::result::Result`] type for `Modbus` client API. +/// +/// This [`Result`] type contains 2 layers of errors. +/// +/// 1. [`std::io::Error`]: An error occurred while performing I/O operations. +/// 2. [`Exception`]: An error occurred on the `Modbus` server. +pub type Result = std::io::Result>; + mod service; #[cfg(feature = "server")] diff --git a/src/service/mod.rs b/src/service/mod.rs index 48d7bd06..0dd738e9 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -6,3 +6,24 @@ pub(crate) mod rtu; #[cfg(feature = "tcp")] pub(crate) mod tcp; + +/// Check that `req_hdr` is the same `Header` as `rsp_hdr`. +/// +/// # Errors +/// +/// If the 2 headers are different, an [`io::Error`] will be returned with [`io::ErrorKind::InvalidData`]. +#[cfg(any(feature = "rtu", feature = "tcp"))] +fn verify_response_header( + req_hdr: &H, + rsp_hdr: &H, +) -> std::io::Result<()> { + if req_hdr != rsp_hdr { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Invalid response header: expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" + ), + )); + } + Ok(()) +} diff --git a/src/service/rtu.rs b/src/service/rtu.rs index 9c12cdde..2b0e15e4 100644 --- a/src/service/rtu.rs +++ b/src/service/rtu.rs @@ -1,10 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::{ - fmt, - io::{Error, ErrorKind}, -}; +use std::{fmt, io}; use futures_util::{SinkExt as _, StreamExt as _}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -14,8 +11,11 @@ use crate::{ codec, frame::{rtu::*, *}, slave::*, + Result, }; +use super::verify_response_header; + /// Modbus RTU client #[derive(Debug)] pub(crate) struct Client { @@ -47,7 +47,7 @@ where } } - async fn call(&mut self, req: Request<'_>) -> Result { + async fn call(&mut self, req: Request<'_>) -> Result { let disconnect = req == Request::Disconnect; let req_adu = self.next_request_adu(req, disconnect); let req_hdr = req_adu.hdr; @@ -59,27 +59,15 @@ where .framed .next() .await - .unwrap_or_else(|| Err(Error::from(ErrorKind::BrokenPipe)))?; + .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?; match res_adu.pdu { - ResponsePdu(Ok(res)) => verify_response_header(req_hdr, res_adu.hdr).and(Ok(res)), - ResponsePdu(Err(err)) => Err(Error::new(ErrorKind::Other, err)), + ResponsePdu(Ok(res)) => verify_response_header(&req_hdr, &res_adu.hdr).and(Ok(Ok(res))), + ResponsePdu(Err(err)) => Ok(Err(err.exception)), } } } -fn verify_response_header(req_hdr: Header, rsp_hdr: Header) -> Result<(), Error> { - if req_hdr != rsp_hdr { - return Err(Error::new( - ErrorKind::InvalidData, - format!( - "Invalid response header: expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" - ), - )); - } - Ok(()) -} - impl SlaveContext for Client { fn set_slave(&mut self, slave: Slave) { self.slave_id = slave.into(); @@ -91,7 +79,7 @@ impl crate::client::Client for Client where T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, { - async fn call(&mut self, req: Request<'_>) -> Result { + async fn call(&mut self, req: Request<'_>) -> Result { self.call(req).await } } @@ -105,6 +93,36 @@ mod tests { }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result}; + use crate::service::{rtu::Header, verify_response_header}; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { slave_id: 0 }; + let rsp_hdr = Header { slave_id: 0 }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_slave_id() { + // Given + let req_hdr = Header { slave_id: 0 }; + let rsp_hdr = Header { slave_id: 5 }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(matches!( + result, + Err(err) if err.kind() == std::io::ErrorKind::InvalidData)); + } + #[derive(Debug)] struct MockTransport; diff --git a/src/service/tcp.rs b/src/service/tcp.rs index cb992f41..c52498ed 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -2,8 +2,7 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use std::{ - fmt, - io::{Error, ErrorKind}, + fmt, io, sync::atomic::{AtomicU16, Ordering}, }; @@ -14,7 +13,9 @@ use tokio_util::codec::Framed; use crate::{ codec, frame::{tcp::*, *}, + service::verify_response_header, slave::*, + Result, }; const INITIAL_TRANSACTION_ID: TransactionId = 0; @@ -68,7 +69,7 @@ where } } - pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { + pub(crate) async fn call(&mut self, req: Request<'_>) -> Result { log::debug!("Call {:?}", req); let disconnect = req == Request::Disconnect; let req_adu = self.next_request_adu(req, disconnect); @@ -81,27 +82,15 @@ where .framed .next() .await - .ok_or_else(Error::last_os_error)??; + .ok_or_else(io::Error::last_os_error)??; match res_adu.pdu { - ResponsePdu(Ok(res)) => verify_response_header(req_hdr, res_adu.hdr).and(Ok(res)), - ResponsePdu(Err(err)) => Err(Error::new(ErrorKind::Other, err)), + ResponsePdu(Ok(res)) => verify_response_header(&req_hdr, &res_adu.hdr).and(Ok(Ok(res))), + ResponsePdu(Err(err)) => Ok(Err(err.exception)), } } } -fn verify_response_header(req_hdr: Header, rsp_hdr: Header) -> Result<(), Error> { - if req_hdr != rsp_hdr { - return Err(Error::new( - ErrorKind::InvalidData, - format!( - "Invalid response header: expected/request = {req_hdr:?}, actual/response = {rsp_hdr:?}" - ), - )); - } - Ok(()) -} - impl SlaveContext for Client { fn set_slave(&mut self, slave: Slave) { self.unit_id = slave.into(); @@ -113,7 +102,73 @@ impl crate::client::Client for Client where T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin, { - async fn call(&mut self, req: Request<'_>) -> Result { + async fn call(&mut self, req: Request<'_>) -> Result { Client::call(self, req).await } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_same_headers() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(result.is_ok()); + } + + #[test] + fn invalid_validate_not_same_unit_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 5, + transaction_id: 42, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(matches!( + result, + Err(err) if err.kind() == std::io::ErrorKind::InvalidData)); + } + + #[test] + fn invalid_validate_not_same_transaction_id() { + // Given + let req_hdr = Header { + unit_id: 0, + transaction_id: 42, + }; + let rsp_hdr = Header { + unit_id: 0, + transaction_id: 86, + }; + + // When + let result = verify_response_header(&req_hdr, &rsp_hdr); + + // Then + assert!(matches!( + result, + Err(err) if err.kind() == std::io::ErrorKind::InvalidData)); + } +} diff --git a/tests/exception/mod.rs b/tests/exception/mod.rs index 54c107ae..2bc3e8b9 100644 --- a/tests/exception/mod.rs +++ b/tests/exception/mod.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2017-2024 slowtec GmbH // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::{borrow::Cow, future}; +use std::future; use tokio_modbus::{ client::{Context, Reader, Writer}, @@ -42,125 +42,66 @@ impl Service for TestService { // TODO: Update the `assert_eq` with a check on Exception once Client trait can return Exception pub async fn check_client_context(mut ctx: Context) { - let response = ctx.read_coils(0x00, 2).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::ReadCoils(0, 0).function_code(), - Exception::Acknowledge - ), - ); - - let response = ctx.read_discrete_inputs(0x00, 2).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::ReadDiscreteInputs(0, 0).function_code(), - Exception::GatewayPathUnavailable - ), - ); - - let response = ctx.write_single_coil(0x00, true).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::WriteSingleCoil(0, true).function_code(), - Exception::GatewayTargetDevice - ), - ); - - let response = ctx.write_multiple_coils(0x00, &[true]).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::WriteMultipleCoils(0, Cow::Owned(vec![true])).function_code(), - Exception::IllegalDataAddress - ), - ); - - let response = ctx.read_input_registers(0x00, 2).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::ReadInputRegisters(0, 2).function_code(), - Exception::IllegalDataValue - ), - ); - - let response = ctx.read_holding_registers(0x00, 2).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::ReadHoldingRegisters(0, 2).function_code(), - Exception::IllegalFunction - ), - ); - - let response = ctx.write_single_register(0x00, 42).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::WriteSingleRegister(0, 42).function_code(), - Exception::MemoryParityError - ), - ); - - let response = ctx.write_multiple_registers(0x00, &[42]).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::WriteMultipleRegisters(0, Cow::Owned(vec![42])).function_code(), - Exception::ServerDeviceBusy - ), - ); - - let response = ctx.masked_write_register(0x00, 0, 0).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::MaskWriteRegister(0, 0, 0).function_code(), - Exception::ServerDeviceFailure - ), - ); - - let response = ctx.read_write_multiple_registers(0x00, 0, 0, &[42]).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::ReadWriteMultipleRegisters(0, 0, 0, Cow::Owned(vec![42])).function_code(), - Exception::IllegalFunction - ), - ); + let response = ctx.read_coils(0x00, 2).await.expect("communication failed"); + assert_eq!(response, Err(Exception::Acknowledge)); + + let response = ctx + .read_discrete_inputs(0x00, 2) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::GatewayPathUnavailable)); + + let response = ctx + .write_single_coil(0x00, true) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::GatewayTargetDevice)); + + let response = ctx + .write_multiple_coils(0x00, &[true]) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::IllegalDataAddress)); + + let response = ctx + .read_input_registers(0x00, 2) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::IllegalDataValue)); + + let response = ctx + .read_holding_registers(0x00, 2) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::IllegalFunction)); + + let response = ctx + .write_single_register(0x00, 42) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::MemoryParityError)); + + let response = ctx + .write_multiple_registers(0x00, &[42]) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::ServerDeviceBusy)); + + let response = ctx + .masked_write_register(0x00, 0, 0) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::ServerDeviceFailure)); + + let response = ctx + .read_write_multiple_registers(0x00, 0, 0, &[42]) + .await + .expect("communication failed"); + assert_eq!(response, Err(Exception::IllegalFunction)); // TODO: This codes hangs if used with `rtu-over-tcp-server`, need to check why - /*let response = ctx.call(Request::Custom(70, Cow::Owned(vec![42]))).await; - assert!(response.is_err()); - assert_eq!( - response.unwrap_err().to_string(), - format!( - "Modbus function {}: {}", - Request::Custom(70, Cow::Owned(vec![42])).function_code(), - Exception::IllegalFunction - ), - );*/ + /* + let response = ctx.call(Request::Custom(70, Cow::Owned(vec![42]))).await.expect("communication failed"); + assert_eq!(response, Err(Exception::IllegalFunction)); + */ }