From c0d40436c1f556a0e1defacdd8295ac61f6609cc Mon Sep 17 00:00:00 2001 From: Uwe Klotz Date: Mon, 31 Jan 2022 13:59:05 +0100 Subject: [PATCH] Replace casts between primitive types with safe type conversions Casts would silently drop the least significant bits when applied incorrectly whereas explicit type conversions are checked by the compiler. Type conversions are only permitted when no information is lost. --- examples/rtu-server.rs | 2 +- examples/tcp-server.rs | 2 +- src/client/mod.rs | 18 ++++++------ src/codec/mod.rs | 66 ++++++++++++++++++++++++++---------------- src/codec/rtu.rs | 14 ++++++--- src/codec/tcp.rs | 6 ++-- src/frame/mod.rs | 7 +++++ 7 files changed, 72 insertions(+), 43 deletions(-) diff --git a/examples/rtu-server.rs b/examples/rtu-server.rs index 4be83e51..23d3bf62 100644 --- a/examples/rtu-server.rs +++ b/examples/rtu-server.rs @@ -17,7 +17,7 @@ pub async fn main() -> Result<(), Box> { fn call(&self, req: Self::Request) -> Self::Future { match req { Request::ReadInputRegisters(_addr, cnt) => { - let mut registers = vec![0; cnt as usize]; + let mut registers = vec![0; cnt.into()]; registers[2] = 0x77; future::ready(Ok(Response::ReadInputRegisters(registers))) } diff --git a/examples/tcp-server.rs b/examples/tcp-server.rs index 27dda712..fa747f9e 100644 --- a/examples/tcp-server.rs +++ b/examples/tcp-server.rs @@ -15,7 +15,7 @@ impl Service for MbServer { fn call(&self, req: Self::Request) -> Self::Future { match req { Request::ReadInputRegisters(_addr, cnt) => { - let mut registers = vec![0; cnt as usize]; + let mut registers = vec![0; cnt.into()]; registers[2] = 77; future::ready(Ok(Response::ReadInputRegisters(registers))) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 0debf38e..025293be 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -110,8 +110,8 @@ impl Reader for Context { let rsp = self.client.call(Request::ReadCoils(addr, cnt)).await?; if let Response::ReadCoils(mut coils) = rsp { - debug_assert!(coils.len() >= cnt as usize); - coils.truncate(cnt as usize); + debug_assert!(coils.len() >= cnt.into()); + coils.truncate(cnt.into()); Ok(coils) } else { Err(Error::new(ErrorKind::InvalidData, "unexpected response")) @@ -129,8 +129,8 @@ impl Reader for Context { .await?; if let Response::ReadDiscreteInputs(mut coils) = rsp { - debug_assert!(coils.len() >= cnt as usize); - coils.truncate(cnt as usize); + debug_assert!(coils.len() >= cnt.into()); + coils.truncate(cnt.into()); Ok(coils) } else { Err(Error::new(ErrorKind::InvalidData, "unexpected response")) @@ -148,7 +148,7 @@ impl Reader for Context { .await?; if let Response::ReadInputRegisters(rsp) = rsp { - if rsp.len() != cnt as usize { + if rsp.len() != cnt.into() { return Err(Error::new(ErrorKind::InvalidData, "invalid response")); } Ok(rsp) @@ -168,7 +168,7 @@ impl Reader for Context { .await?; if let Response::ReadHoldingRegisters(rsp) = rsp { - if rsp.len() != cnt as usize { + if rsp.len() != cnt.into() { return Err(Error::new(ErrorKind::InvalidData, "invalid response")); } Ok(rsp) @@ -195,7 +195,7 @@ impl Reader for Context { .await?; if let Response::ReadWriteMultipleRegisters(rsp) = rsp { - if rsp.len() != read_cnt as usize { + if rsp.len() != read_cnt.into() { return Err(Error::new(ErrorKind::InvalidData, "invalid response")); } Ok(rsp) @@ -235,7 +235,7 @@ impl Writer for Context { .await?; if let Response::WriteMultipleCoils(rsp_addr, rsp_cnt) = rsp { - if rsp_addr != addr || rsp_cnt as usize != cnt { + if rsp_addr != addr || usize::from(rsp_cnt) != cnt { return Err(Error::new(ErrorKind::InvalidData, "invalid response")); } Ok(()) @@ -276,7 +276,7 @@ impl Writer for Context { .await?; if let Response::WriteMultipleRegisters(rsp_addr, rsp_cnt) = rsp { - if rsp_addr != addr || rsp_cnt as usize != cnt { + if rsp_addr != addr || usize::from(rsp_cnt) != cnt { return Err(Error::new(ErrorKind::InvalidData, "invalid response")); } Ok(()) diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 23315622..0837fba1 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -12,6 +12,22 @@ use bytes::{BufMut, Bytes, BytesMut}; use std::convert::TryFrom; use std::io::{self, Cursor, Error, ErrorKind}; +fn u16_len(len: usize) -> u16 { + // This type conversion should always be safe, because either + // the caller is responsible to pass a valid usize or the + // possible values are limited by the protocol. + debug_assert!(len <= u16::MAX.into()); + len as u16 +} + +fn u8_len(len: usize) -> u8 { + // This type conversion should always be safe, because either + // the caller is responsible to pass a valid usize or the + // possible values are limited by the protocol. + debug_assert!(len <= u8::MAX.into()); + len as u8 +} + impl From for Bytes { fn from(req: Request) -> Bytes { let cnt = request_byte_count(&req); @@ -33,9 +49,9 @@ impl From for Bytes { WriteMultipleCoils(address, coils) => { data.put_u16(address); let len = coils.len(); - data.put_u16(len as u16); + data.put_u16(u16_len(len)); let packed_coils = pack_coils(&coils); - data.put_u8(packed_coils.len() as u8); + data.put_u8(u8_len(packed_coils.len())); for b in packed_coils { data.put_u8(b); } @@ -47,8 +63,8 @@ impl From for Bytes { WriteMultipleRegisters(address, words) => { data.put_u16(address); let len = words.len(); - data.put_u16(len as u16); - data.put_u8((len as u8) * 2); + data.put_u16(u16_len(len)); + data.put_u8(u8_len(len * 2)); for w in words { data.put_u16(w); } @@ -58,8 +74,8 @@ impl From for Bytes { data.put_u16(quantity); data.put_u16(write_address); let n = words.len(); - data.put_u16(n as u16); - data.put_u8(n as u8 * 2); + data.put_u16(u16_len(n)); + data.put_u8(u8_len(n * 2)); for w in words { data.put_u16(w); } @@ -90,7 +106,7 @@ impl From for Bytes { match rsp { ReadCoils(coils) | ReadDiscreteInputs(coils) => { let packed_coils = pack_coils(&coils); - data.put_u8(packed_coils.len() as u8); + data.put_u8(u8_len(packed_coils.len())); for b in packed_coils { data.put_u8(b); } @@ -98,7 +114,7 @@ impl From for Bytes { ReadInputRegisters(registers) | ReadHoldingRegisters(registers) | ReadWriteMultipleRegisters(registers) => { - data.put_u8((registers.len() * 2) as u8); + data.put_u8(u8_len(registers.len() * 2)); for r in registers { data.put_u16(r); } @@ -130,7 +146,7 @@ impl From for Bytes { let mut data = BytesMut::with_capacity(2); debug_assert!(ex.function < 0x80); data.put_u8(ex.function + 0x80); - data.put_u8(ex.exception as u8); + data.put_u8(ex.exception.into()); data.freeze() } } @@ -160,7 +176,7 @@ impl TryFrom for Request { let address = rdr.read_u16::()?; let quantity = rdr.read_u16::()?; let byte_count = rdr.read_u8()?; - if bytes.len() < (6 + byte_count as usize) { + if bytes.len() < 6 + usize::from(byte_count) { return Err(Error::new(ErrorKind::InvalidData, "Invalid byte count")); } let x = &bytes[6..]; @@ -175,11 +191,11 @@ impl TryFrom for Request { 0x10 => { let address = rdr.read_u16::()?; let quantity = rdr.read_u16::()?; - let byte_count = rdr.read_u8()? as usize; - if bytes.len() < (6 + byte_count as usize) { + let byte_count = rdr.read_u8()?; + if bytes.len() < 6 + usize::from(byte_count) { return Err(Error::new(ErrorKind::InvalidData, "Invalid byte count")); } - let mut data = Vec::with_capacity(quantity as usize); + let mut data = Vec::with_capacity(quantity.into()); for _ in 0..quantity { data.push(rdr.read_u16::()?); } @@ -190,11 +206,11 @@ impl TryFrom for Request { let read_quantity = rdr.read_u16::()?; let write_address = rdr.read_u16::()?; let write_quantity = rdr.read_u16::()?; - let write_count = rdr.read_u8()? as usize; - if bytes.len() < (10 + write_count as usize) { + let write_count = rdr.read_u8()?; + if bytes.len() < 10 + usize::from(write_count) { return Err(Error::new(ErrorKind::InvalidData, "Invalid byte count")); } - let mut data = Vec::with_capacity(write_quantity as usize); + let mut data = Vec::with_capacity(write_quantity.into()); for _ in 0..write_quantity { data.push(rdr.read_u16::()?); } @@ -253,7 +269,7 @@ impl TryFrom for Response { 0x04 => { let byte_count = rdr.read_u8()?; let quantity = byte_count / 2; - let mut data = Vec::with_capacity(quantity as usize); + let mut data = Vec::with_capacity(quantity.into()); for _ in 0..quantity { data.push(rdr.read_u16::()?); } @@ -262,7 +278,7 @@ impl TryFrom for Response { 0x03 => { let byte_count = rdr.read_u8()?; let quantity = byte_count / 2; - let mut data = Vec::with_capacity(quantity as usize); + let mut data = Vec::with_capacity(quantity.into()); for _ in 0..quantity { data.push(rdr.read_u16::()?); } @@ -276,7 +292,7 @@ impl TryFrom for Response { 0x17 => { let byte_count = rdr.read_u8()?; let quantity = byte_count / 2; - let mut data = Vec::with_capacity(quantity as usize); + let mut data = Vec::with_capacity(quantity.into()); for _ in 0..quantity { data.push(rdr.read_u16::()?); } @@ -371,15 +387,15 @@ fn pack_coils(coils: &[Coil]) -> Vec { let mut res = vec![0; packed_size]; for (i, b) in coils.iter().enumerate() { let v = if *b { 0b1 } else { 0b0 }; - res[(i / 8) as usize] |= v << (i % 8); + res[i / 8] |= v << (i % 8); } res } fn unpack_coils(bytes: &[u8], count: u16) -> Vec { - let mut res = Vec::with_capacity(count as usize); - for i in 0..count { - res.push((bytes[(i / 8u16) as usize] >> (i % 8)) & 0b1 > 0); + let mut res = Vec::with_capacity(count.into()); + for i in 0usize..count.into() { + res.push((bytes[i / 8] >> (i % 8)) & 0b1 > 0); } res } @@ -973,7 +989,7 @@ mod tests { fn read_coils_max_quantity() { let quantity = 2000; let byte_count = quantity / 8; - let mut raw: Vec = vec![1, byte_count as u8]; + let mut raw: Vec = vec![1, u8_len(byte_count)]; let mut values: Vec = (0..byte_count).map(|_| 0b_1111_1111).collect(); raw.append(&mut values); let bytes = Bytes::from(raw); @@ -997,7 +1013,7 @@ mod tests { fn read_discrete_inputs_max_quantity() { let quantity = 2000; let byte_count = quantity / 8; - let mut raw: Vec = vec![2, byte_count as u8]; + let mut raw: Vec = vec![2, u8_len(byte_count)]; let mut values: Vec = (0..byte_count).map(|_| 0b_1111_1111).collect(); raw.append(&mut values); let bytes = Bytes::from(raw); diff --git a/src/codec/rtu.rs b/src/codec/rtu.rs index 546fc43a..a9d8f5c4 100644 --- a/src/codec/rtu.rs +++ b/src/codec/rtu.rs @@ -120,12 +120,16 @@ fn get_request_pdu_len(adu_buf: &BytesMut) -> Result> { 0x01..=0x06 => 5, 0x07 | 0x0B | 0x0C | 0x11 => 1, 0x0F | 0x10 => { - return Ok(adu_buf.get(6).map(|&byte_count| 6 + byte_count as usize)); + return Ok(adu_buf + .get(6) + .map(|&byte_count| 6 + usize::from(byte_count))); } 0x16 => 7, 0x18 => 3, 0x17 => { - return Ok(adu_buf.get(10).map(|&byte_count| 10 + byte_count as usize)); + return Ok(adu_buf + .get(10) + .map(|&byte_count| 10 + usize::from(byte_count))); } _ => { return Err(Error::new( @@ -144,14 +148,16 @@ fn get_response_pdu_len(adu_buf: &BytesMut) -> Result> { if let Some(fn_code) = adu_buf.get(1) { let len = match fn_code { 0x01..=0x04 | 0x0C | 0x17 => { - return Ok(adu_buf.get(2).map(|&byte_count| 2 + byte_count as usize)); + return Ok(adu_buf + .get(2) + .map(|&byte_count| 2 + usize::from(byte_count))); } 0x05 | 0x06 | 0x0B | 0x0F | 0x10 => 5, 0x07 => 2, 0x16 => 7, 0x18 => { if adu_buf.len() > 3 { - 3 + Cursor::new(&adu_buf[2..=3]).read_u16::()? as usize + 3 + usize::from(Cursor::new(&adu_buf[2..=3]).read_u16::()?) } else { // incomplete frame return Ok(None); diff --git a/src/codec/tcp.rs b/src/codec/tcp.rs index c871033f..69ae6ebe 100644 --- a/src/codec/tcp.rs +++ b/src/codec/tcp.rs @@ -51,7 +51,7 @@ impl Decoder for AduDecoder { } debug_assert!(HEADER_LEN >= 6); - let len = BigEndian::read_u16(&buf[4..6]) as usize; + let len = usize::from(BigEndian::read_u16(&buf[4..6])); let pdu_len = if len > 0 { // len = bytes of PDU + one byte (unit ID) len - 1 @@ -146,7 +146,7 @@ impl Encoder for ClientCodec { buf.reserve(pdu_data.len() + 7); buf.put_u16(hdr.transaction_id); buf.put_u16(PROTOCOL_ID); - buf.put_u16((pdu_data.len() + 1) as u16); + buf.put_u16(u16_len(pdu_data.len() + 1)); buf.put_u8(hdr.unit_id); buf.put_slice(&*pdu_data); Ok(()) @@ -162,7 +162,7 @@ impl Encoder for ServerCodec { buf.reserve(pdu_data.len() + 7); buf.put_u16(hdr.transaction_id); buf.put_u16(PROTOCOL_ID); - buf.put_u16((pdu_data.len() + 1) as u16); + buf.put_u16(u16_len(pdu_data.len() + 1)); buf.put_u8(hdr.unit_id); buf.put_slice(&*pdu_data); Ok(()) diff --git a/src/frame/mod.rs b/src/frame/mod.rs index a620fac1..b507f2e5 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -154,6 +154,7 @@ pub enum Response { /// A server (slave) exception. #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] pub enum Exception { IllegalFunction = 0x01, IllegalDataAddress = 0x02, @@ -166,6 +167,12 @@ pub enum Exception { GatewayTargetDevice = 0x0B, } +impl From for u8 { + fn from(from: Exception) -> Self { + from as u8 + } +} + impl Exception { pub(crate) fn description(&self) -> &str { use crate::frame::Exception::*;