diff --git a/Cargo.toml b/Cargo.toml index 501d8c23..3bc03abb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,13 @@ tcp-server = [ "tokio/macros", "tokio/rt-multi-thread", ] +rtu-over-tcp-server = [ + "rtu", + "server", + "socket2/all", + "tokio/macros", + "tokio/rt-multi-thread", +] # The following features are internal and must not be used in dependencies. sync = ["dep:futures", "tokio/time", "tokio/rt"] server = ["dep:futures"] @@ -117,6 +124,11 @@ name = "tcp-server" path = "examples/tcp-server.rs" required-features = ["tcp-server"] +[[example]] +name = "rtu-over-tcp-server" +path = "examples/rtu-over-tcp-server.rs" +required-features = ["rtu-over-tcp-server"] + [[example]] name = "tls-client" path = "examples/tls-client.rs" diff --git a/examples/rtu-over-tcp-server.rs b/examples/rtu-over-tcp-server.rs new file mode 100644 index 00000000..57f3aa02 --- /dev/null +++ b/examples/rtu-over-tcp-server.rs @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: Copyright (c) 2017-2023 slowtec GmbH +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! # RTU over TCP server example +//! +//! This example shows how to start a server and implement basic register +//! read/write operations. + +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; + +use futures::future; +use tokio::net::TcpListener; + +use tokio_modbus::{ + prelude::*, + server::rtu_over_tcp::{accept_tcp_connection, Server}, +}; + +struct ExampleService { + input_registers: Arc>>, + holding_registers: Arc>>, +} + +impl tokio_modbus::server::Service for ExampleService { + type Request = SlaveRequest<'static>; + type Response = Response; + type Error = std::io::Error; + type Future = future::Ready>; + + fn call(&self, req: Self::Request) -> Self::Future { + println!("{}", req.slave); + match req.request { + Request::ReadInputRegisters(addr, cnt) => { + match register_read(&self.input_registers.lock().unwrap(), addr, cnt) { + Ok(values) => future::ready(Ok(Response::ReadInputRegisters(values))), + Err(err) => future::ready(Err(err)), + } + } + Request::ReadHoldingRegisters(addr, cnt) => { + match register_read(&self.holding_registers.lock().unwrap(), addr, cnt) { + Ok(values) => future::ready(Ok(Response::ReadHoldingRegisters(values))), + Err(err) => future::ready(Err(err)), + } + } + Request::WriteMultipleRegisters(addr, values) => { + match register_write(&mut self.holding_registers.lock().unwrap(), addr, &values) { + Ok(_) => future::ready(Ok(Response::WriteMultipleRegisters( + addr, + values.len() as u16, + ))), + Err(err) => future::ready(Err(err)), + } + } + Request::WriteSingleRegister(addr, value) => { + match register_write( + &mut self.holding_registers.lock().unwrap(), + addr, + std::slice::from_ref(&value), + ) { + Ok(_) => future::ready(Ok(Response::WriteSingleRegister(addr, value))), + Err(err) => future::ready(Err(err)), + } + } + _ => { + println!("SERVER: Exception::IllegalFunction - Unimplemented function code in request: {req:?}"); + // TODO: We want to return a Modbus Exception response `IllegalFunction`. https://github.com/slowtec/tokio-modbus/issues/165 + future::ready(Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + "Unimplemented function code in request".to_string(), + ))) + } + } + } +} + +impl ExampleService { + fn new() -> Self { + // Insert some test data as register values. + let mut input_registers = HashMap::new(); + input_registers.insert(0, 1234); + input_registers.insert(1, 5678); + let mut holding_registers = HashMap::new(); + holding_registers.insert(0, 10); + holding_registers.insert(1, 20); + holding_registers.insert(2, 30); + holding_registers.insert(3, 40); + Self { + input_registers: Arc::new(Mutex::new(input_registers)), + holding_registers: Arc::new(Mutex::new(holding_registers)), + } + } +} + +/// Helper function implementing reading registers from a HashMap. +fn register_read( + registers: &HashMap, + addr: u16, + cnt: u16, +) -> Result, std::io::Error> { + let mut response_values = vec![0; cnt.into()]; + for i in 0..cnt { + let reg_addr = addr + i; + if let Some(r) = registers.get(®_addr) { + response_values[i as usize] = *r; + } else { + // TODO: Return a Modbus Exception response `IllegalDataAddress` https://github.com/slowtec/tokio-modbus/issues/165 + println!("SERVER: Exception::IllegalDataAddress"); + return Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + format!("no register at address {reg_addr}"), + )); + } + } + + Ok(response_values) +} + +/// Write a holding register. Used by both the write single register +/// and write multiple registers requests. +fn register_write( + registers: &mut HashMap, + addr: u16, + values: &[u16], +) -> Result<(), std::io::Error> { + for (i, value) in values.iter().enumerate() { + let reg_addr = addr + i as u16; + if let Some(r) = registers.get_mut(®_addr) { + *r = *value; + } else { + // TODO: Return a Modbus Exception response `IllegalDataAddress` https://github.com/slowtec/tokio-modbus/issues/165 + println!("SERVER: Exception::IllegalDataAddress"); + return Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + format!("no register at address {reg_addr}"), + )); + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let socket_addr = "127.0.0.1:5502".parse().unwrap(); + + tokio::select! { + _ = server_context(socket_addr) => unreachable!(), + _ = client_context(socket_addr) => println!("Exiting"), + } + + Ok(()) +} + +async fn server_context(socket_addr: SocketAddr) -> anyhow::Result<()> { + println!("Starting up server on {socket_addr}"); + let listener = TcpListener::bind(socket_addr).await?; + let server = Server::new(listener); + let new_service = |_socket_addr| Ok(Some(ExampleService::new())); + let on_connected = |stream, socket_addr| async move { + accept_tcp_connection(stream, socket_addr, new_service) + }; + let on_process_error = |err| { + eprintln!("{err}"); + }; + server.serve(&on_connected, on_process_error).await?; + Ok(()) +} + +async fn client_context(socket_addr: SocketAddr) { + tokio::join!( + async { + // Give the server some time for starting up + tokio::time::sleep(Duration::from_secs(1)).await; + + println!("CLIENT: Connecting client..."); + let transport = tokio::net::TcpStream::connect(socket_addr).await.unwrap(); + let mut ctx = tokio_modbus::prelude::rtu::attach_slave(transport, Slave(1)); + + println!("CLIENT: Reading 2 input registers..."); + let response = ctx.read_input_registers(0x00, 2).await.unwrap(); + println!("CLIENT: The result is '{response:?}'"); + assert_eq!(response, [1234, 5678]); + + println!("CLIENT: Writing 2 holding registers..."); + ctx.write_multiple_registers(0x01, &[7777, 8888]) + .await + .unwrap(); + + // Read back a block including the two registers we wrote. + println!("CLIENT: Reading 4 holding registers..."); + let response = ctx.read_holding_registers(0x00, 4).await.unwrap(); + println!("CLIENT: The result is '{response:?}'"); + assert_eq!(response, [10, 7777, 8888, 40]); + + // Now we try to read with an invalid register address. + // This should return a Modbus exception response with the code + // IllegalDataAddress. + println!("CLIENT: Reading nonexisting holding register address... (should return IllegalDataAddress)"); + let response = ctx.read_holding_registers(0x100, 1).await; + println!("CLIENT: The result is '{response:?}'"); + assert!(response.is_err()); + // TODO: How can Modbus client identify Modbus exception responses? E.g. here we expect IllegalDataAddress + // Question here: https://github.com/slowtec/tokio-modbus/issues/169 + + println!("CLIENT: Done.") + }, + tokio::time::sleep(Duration::from_secs(5)) + ); +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 2f23cb8a..84b7aee4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -12,6 +12,9 @@ pub mod rtu; #[cfg(feature = "tcp-server")] pub mod tcp; +#[cfg(feature = "rtu-over-tcp-server")] +pub mod rtu_over_tcp; + mod service; pub use self::service::Service; diff --git a/src/server/rtu_over_tcp.rs b/src/server/rtu_over_tcp.rs new file mode 100644 index 00000000..088c3a58 --- /dev/null +++ b/src/server/rtu_over_tcp.rs @@ -0,0 +1,281 @@ +// SPDX-FileCopyrightText: Copyright (c) 2017-2023 slowtec GmbH +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! Modbus RTU over TCP server skeleton + +use std::{io, net::SocketAddr}; + +use async_trait::async_trait; +use futures::{self, Future}; +use futures_util::{future::FutureExt as _, sink::SinkExt as _, stream::StreamExt as _}; +use socket2::{Domain, Socket, Type}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpListener, TcpStream}, +}; +use tokio_util::codec::Framed; + +use crate::{ + codec::rtu::ServerCodec, + frame::{ + rtu::{RequestAdu, ResponseAdu}, + OptionalResponsePdu, + }, + server::service::Service, +}; + +use super::Terminated; + +#[async_trait] +pub trait BindSocket { + type Error; + + async fn bind_socket(addr: SocketAddr) -> Result; +} + +/// Accept unencrypted TCP connections. +pub fn accept_tcp_connection( + stream: TcpStream, + socket_addr: SocketAddr, + new_service: NewService, +) -> io::Result> +where + S: Service + Send + Sync + 'static, + S::Request: From> + Send, + S::Response: Into + Send, + S::Error: Into, + NewService: Fn(SocketAddr) -> io::Result>, +{ + let service = new_service(socket_addr)?; + Ok(service.map(|service| (service, stream))) +} + +#[derive(Debug)] +pub struct Server { + listener: TcpListener, +} + +impl Server { + /// Attach the Modbus server to a TCP socket server. + #[must_use] + pub fn new(listener: TcpListener) -> Self { + Self { listener } + } + + /// Listens for incoming connections and starts a Modbus RTU over TCP server task for + /// each connection. + /// + /// `OnConnected` is responsible for creating both the service and the + /// transport layer for the underlying TCP stream. If `OnConnected` returns + /// with `Err` then listening stops and [`Self::serve()`] returns with an error. + /// If `OnConnected` returns `Ok(None)` then the connection is rejected + /// but [`Self::serve()`] continues listening for new connections. + pub async fn serve( + &self, + on_connected: &OnConnected, + on_process_error: OnProcessError, + ) -> io::Result<()> + where + S: Service + Send + Sync + 'static, + S::Request: From> + Send, + S::Response: Into + Send, + S::Error: Into, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + OnConnected: Fn(TcpStream, SocketAddr) -> F, + F: Future>>, + OnProcessError: FnOnce(io::Error) + Clone + Send + 'static, + { + loop { + let (stream, socket_addr) = self.listener.accept().await?; + log::debug!("Accepted connection from {socket_addr}"); + + let Some((service, transport)) = on_connected(stream, socket_addr).await? else { + log::debug!("No service for connection from {socket_addr}"); + continue; + }; + let on_process_error = on_process_error.clone(); + + // use RTU codec + let framed = Framed::new(transport, ServerCodec::default()); + + tokio::spawn(async move { + log::debug!("Processing requests from {socket_addr}"); + if let Err(err) = process(framed, service).await { + on_process_error(err); + } + }); + } + } + + /// Start an abortable Modbus RTU over TCP server task. + /// + /// Warning: Request processing is not scoped and could be aborted at any internal await point! + /// See also: + pub async fn serve_until( + self, + on_connected: &OnConnected, + on_process_error: OnProcessError, + abort_signal: X, + ) -> io::Result + where + S: Service + Send + Sync + 'static, + S::Request: From> + Send, + S::Response: Into + Send, + S::Error: Into, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + X: Future + Sync + Send + Unpin + 'static, + OnConnected: Fn(TcpStream, SocketAddr) -> F, + F: Future>>, + OnProcessError: FnOnce(io::Error) + Clone + Send + 'static, + { + let abort_signal = abort_signal.fuse(); + tokio::select! { + res = self.serve(on_connected, on_process_error) => { + res.map(|()| Terminated::Finished) + }, + () = abort_signal => { + Ok(Terminated::Aborted) + } + } + } +} + +/// The request-response loop spawned by [`serve_until`] for each client +async fn process(mut framed: Framed, service: S) -> io::Result<()> +where + S: Service + Send + Sync + 'static, + S::Request: From> + Send, + S::Response: Into + Send, + S::Error: Into, + T: AsyncRead + AsyncWrite + Unpin, +{ + loop { + let Some(request) = framed.next().await.transpose()? else { + log::debug!("TCP socket has been closed"); + break; + }; + + let hdr = request.hdr; + let OptionalResponsePdu(Some(response_pdu)) = service + .call(request.into()) + .await + .map_err(Into::into)? + .into() + else { + log::trace!("Sending no response for request {hdr:?}"); + continue; + }; + + framed + .send(ResponseAdu { + hdr, + pdu: response_pdu, + }) + .await?; + } + + Ok(()) +} + +/// Start TCP listener - configure and open TCP socket +#[allow(unused)] +fn listener(addr: SocketAddr, workers: usize) -> io::Result { + let listener = match addr { + SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?, + SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?, + }; + configure_tcp(workers, &listener)?; + listener.reuse_address()?; + listener.bind(&addr.into())?; + listener.listen(1024)?; + TcpListener::from_std(listener.into()) +} + +#[cfg(unix)] +#[allow(unused)] +fn configure_tcp(workers: usize, tcp: &Socket) -> io::Result<()> { + if workers > 1 { + tcp.reuse_port()?; + } + Ok(()) +} + +#[cfg(windows)] +#[allow(unused)] +fn configure_tcp(_workers: usize, _tcp: &Socket) -> io::Result<()> { + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{prelude::*, server::Service}; + + use std::sync::Arc; + + use futures::future; + + #[tokio::test] + async fn delegate_service_through_deref_for_server() { + #[derive(Clone)] + struct DummyService { + response: Response, + } + + impl Service for DummyService { + type Request = Request<'static>; + type Response = Response; + type Error = io::Error; + type Future = future::Ready>; + + fn call(&self, _: Self::Request) -> Self::Future { + future::ready(Ok(self.response.clone())) + } + } + + let service = Arc::new(DummyService { + response: Response::ReadInputRegisters(vec![0x33]), + }); + let svc = |_socket_addr| Ok(Some(Arc::clone(&service))); + let on_connected = + |stream, socket_addr| async move { accept_tcp_connection(stream, socket_addr, svc) }; + + // bind 0 to let the OS pick a random port + let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let listener = TcpListener::bind(addr).await.unwrap(); + let server = Server::new(listener); + + // passes type-check is the goal here + // added `mem::drop` to satisfy `must_use` compiler warnings + std::mem::drop(server.serve(&on_connected, |_err| {})); + } + + #[tokio::test] + async fn service_wrapper() { + #[derive(Clone)] + struct DummyService { + response: Response, + } + + impl Service for DummyService { + type Request = Request<'static>; + type Response = Response; + type Error = io::Error; + type Future = future::Ready>; + + fn call(&self, _: Self::Request) -> Self::Future { + future::ready(Ok(self.response.clone())) + } + } + + let service = DummyService { + response: Response::ReadInputRegisters(vec![0x33]), + }; + + let pdu = Request::ReadInputRegisters(0, 1); + let rsp_adu = service.call(pdu).await.unwrap(); + + assert_eq!(rsp_adu, service.response); + } +}