From aae4969fde8a3b2bc091c1fbbad51071a3fa542c Mon Sep 17 00:00:00 2001 From: Sam Rijs Date: Mon, 13 Jan 2025 13:53:52 +0100 Subject: [PATCH] Add specialized `Buf::chunks_vectored` for `Take` (#617) Co-authored-by: Alice Ryhl Co-authored-by: Michal 'vorner' Vaner --- src/buf/take.rs | 49 +++++++++++++++++++++++++++++++++++++++++++ tests/test_buf.rs | 2 +- tests/test_take.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/buf/take.rs b/src/buf/take.rs index a16a434ee..fc4e39dda 100644 --- a/src/buf/take.rs +++ b/src/buf/take.rs @@ -2,6 +2,9 @@ use crate::{Buf, Bytes}; use core::cmp; +#[cfg(feature = "std")] +use std::io::IoSlice; + /// A `Buf` adapter which limits the bytes read from an underlying buffer. /// /// This struct is generally created by calling `take()` on `Buf`. See @@ -152,4 +155,50 @@ impl Buf for Take { self.limit -= len; r } + + #[cfg(feature = "std")] + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + if self.limit == 0 { + return 0; + } + + const LEN: usize = 16; + let mut slices: [IoSlice<'a>; LEN] = [ + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + ]; + + let cnt = self + .inner + .chunks_vectored(&mut slices[..dst.len().min(LEN)]); + let mut limit = self.limit; + for (i, (dst, slice)) in dst[..cnt].iter_mut().zip(slices.iter()).enumerate() { + if let Some(buf) = slice.get(..limit) { + // SAFETY: We could do this safely with `IoSlice::advance` if we had a larger MSRV. + let buf = unsafe { std::mem::transmute::<&[u8], &'a [u8]>(buf) }; + *dst = IoSlice::new(buf); + return i + 1; + } else { + // SAFETY: We could do this safely with `IoSlice::advance` if we had a larger MSRV. + let buf = unsafe { std::mem::transmute::<&[u8], &'a [u8]>(slice) }; + *dst = IoSlice::new(buf); + limit -= slice.len(); + } + } + cnt + } } diff --git a/tests/test_buf.rs b/tests/test_buf.rs index 5a5ac7e80..099016e24 100644 --- a/tests/test_buf.rs +++ b/tests/test_buf.rs @@ -404,7 +404,7 @@ mod chain_limited_slices { Buf::take(Buf::chain(Buf::chain(a, b), Buf::chain(c, d)), buf.len()) } - buf_tests!(make_input, /* `Limit` does not forward `chucks_vectored */ false); + buf_tests!(make_input, true); } #[allow(unused_allocation)] // This is intentional. diff --git a/tests/test_take.rs b/tests/test_take.rs index 51df91d14..0c0159be1 100644 --- a/tests/test_take.rs +++ b/tests/test_take.rs @@ -30,3 +30,55 @@ fn take_copy_to_bytes_panics() { let abcd = Bytes::copy_from_slice(b"abcd"); abcd.take(2).copy_to_bytes(3); } + +#[cfg(feature = "std")] +#[test] +fn take_chunks_vectored() { + fn chain() -> impl Buf { + Bytes::from([1, 2, 3].to_vec()).chain(Bytes::from([4, 5, 6].to_vec())) + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(0); + assert_eq!(take.chunks_vectored(&mut dst), 0); + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(1); + assert_eq!(take.chunks_vectored(&mut dst), 1); + assert_eq!(&*dst[0], &[1]); + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(3); + assert_eq!(take.chunks_vectored(&mut dst), 1); + assert_eq!(&*dst[0], &[1, 2, 3]); + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(4); + assert_eq!(take.chunks_vectored(&mut dst), 2); + assert_eq!(&*dst[0], &[1, 2, 3]); + assert_eq!(&*dst[1], &[4]); + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(6); + assert_eq!(take.chunks_vectored(&mut dst), 2); + assert_eq!(&*dst[0], &[1, 2, 3]); + assert_eq!(&*dst[1], &[4, 5, 6]); + } + + { + let mut dst = [std::io::IoSlice::new(&[]); 2]; + let take = chain().take(7); + assert_eq!(take.chunks_vectored(&mut dst), 2); + assert_eq!(&*dst[0], &[1, 2, 3]); + assert_eq!(&*dst[1], &[4, 5, 6]); + } +}