From 749687814cc8455349b52b8c098a36220528a3c5 Mon Sep 17 00:00:00 2001 From: zephyr Date: Mon, 20 May 2024 06:46:33 +0900 Subject: [PATCH] realm_core: batched udp --- realm_core/Cargo.toml | 7 +- realm_core/src/udp/batched.rs | 174 ++++++++++++++++++++++++++++++++ realm_core/src/udp/middle.rs | 182 ++++++++++++++++++++++++---------- realm_core/src/udp/mod.rs | 11 +- realm_core/src/udp/socket.rs | 5 +- realm_core/src/udp/sockmap.rs | 15 +++ 6 files changed, 332 insertions(+), 62 deletions(-) create mode 100644 realm_core/src/udp/batched.rs diff --git a/realm_core/Cargo.toml b/realm_core/Cargo.toml index d52ade6f..c5cf0d9f 100644 --- a/realm_core/Cargo.toml +++ b/realm_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "realm_core" -version = "0.3.9" +version = "0.4.0" authors = ["Realm Contributors"] description = "Realm's core facilities." repository = "https://github.com/zhboner/realm" @@ -12,7 +12,7 @@ license = "MIT" [dependencies] # realm -realm_io = "0.4" +realm_io = { version = "0.5" } realm_syscall = "0.1" realm_hook = { version = "0.1", optional = true } realm_lb = { version = "0.1", optional = true } @@ -36,8 +36,9 @@ brutal-shutdown = ["realm_io/brutal-shutdown"] transport = ["kaminari"] transport-boost = [] proxy = ["proxy-protocol", "bytes", "tokio/io-util"] +batched-udp = [] multi-thread = [] [dev-dependencies] -env_logger = "0.10" +env_logger = "0.11" tokio = { version = "1", features = ["macros"] } diff --git a/realm_core/src/udp/batched.rs b/realm_core/src/udp/batched.rs new file mode 100644 index 00000000..d9adde79 --- /dev/null +++ b/realm_core/src/udp/batched.rs @@ -0,0 +1,174 @@ +use std::io::Result; +use std::net::SocketAddr; +use tokio::net::UdpSocket; + +pub const PACKET_SIZE: usize = 1500; +pub const MAX_PACKETS: usize = 128; + +#[repr(transparent)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SockAddrStore { + #[cfg(all(target_os = "linux", feature = "batched-udp"))] + inner: realm_io::mmsg::SockAddrStore, + + #[cfg(not(all(target_os = "linux", feature = "batched-udp")))] + inner: std::net::SocketAddr, +} + +impl SockAddrStore { + pub const fn new() -> Self { + Self { + #[cfg(all(target_os = "linux", feature = "batched-udp"))] + inner: realm_io::mmsg::SockAddrStore::new(), + + #[cfg(not(all(target_os = "linux", feature = "batched-udp")))] + inner: { + use std::net::{IpAddr, Ipv4Addr}; + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) + }, + } + } +} + +impl Default for SockAddrStore { + fn default() -> Self { + Self::new() + } +} + +impl From for SockAddrStore { + fn from(value: SocketAddr) -> Self { + SockAddrStore { inner: value.into() } + } +} + +impl From for SocketAddr { + fn from(value: SockAddrStore) -> Self { + value.inner.into() + } +} + +#[derive(Debug, Clone)] +pub struct Packet { + pub(super) buf: [u8; PACKET_SIZE], + pub(super) addr: SockAddrStore, + pub(super) cursor: u16, +} + +#[derive(Debug, Clone, Copy)] +pub struct PacketRef<'buf, 'addr> { + buf: &'buf [u8], + addr: &'addr SockAddrStore, +} + +impl Packet { + pub const fn new() -> Self { + Self { + buf: [0u8; PACKET_SIZE], + addr: SockAddrStore::new(), + cursor: 0u16, + } + } + + pub fn ref_with_addr<'a>(&self, addr: &'a SockAddrStore) -> PacketRef<'_, 'a> { + PacketRef { + buf: &self.buf[..self.cursor as usize], + addr, + } + } +} + +#[cfg(not(all(target_os = "linux", feature = "batched-udp")))] +pub use common::{recv_some, send_all}; +#[cfg(not(all(target_os = "linux", feature = "batched-udp")))] +mod common { + use super::*; + pub async fn recv_some(sock: &UdpSocket, pkts: &mut [Packet]) -> Result { + debug_assert!(!pkts.is_empty()); + let pkt = &mut pkts[0]; + let (bytes, addr) = sock.recv_from(&mut pkt.buf).await?; + pkt.addr.inner = addr; + pkt.cursor = bytes as u16; + Ok(1) + } + + pub async fn send_all<'a, 'b, I>(sock: &UdpSocket, pkts: I) -> Result<()> + where + I: ExactSizeIterator>, + { + for pkt in pkts { + let _ = sock.send_to(pkt.buf, &pkt.addr.inner).await?; + } + Ok(()) + } +} + +#[cfg(all(target_os = "linux", feature = "batched-udp"))] +pub use linux::{recv_some, send_all}; +#[cfg(all(target_os = "linux", feature = "batched-udp"))] +mod linux { + use super::*; + use std::io::{IoSlice, IoSliceMut}; + use std::mem::MaybeUninit; + use realm_io::mmsg::{MmsgHdr, MmsgHdrMut}; + use realm_io::mmsg::{send_mul_pkts, recv_mul_pkts}; + + pub async fn recv_some(sock: &UdpSocket, pkts: &mut [Packet]) -> Result { + const MAX_PKTS: usize = MAX_PACKETS; + debug_assert!(pkts.len() <= MAX_PKTS); + + let pkt_amt = pkts.len(); + let mut iovs: MaybeUninit<[IoSliceMut; MAX_PKTS]> = MaybeUninit::uninit(); + let mut msgs: MaybeUninit<[MmsgHdrMut; MAX_PKTS]> = MaybeUninit::uninit(); + let iovs = unsafe { iovs.assume_init_mut() }; + let msgs = unsafe { msgs.assume_init_mut() }; + + for ((pkt, iov), msg) in pkts.iter_mut().zip(iovs.iter_mut()).zip(msgs.iter_mut()) { + *iov = IoSliceMut::new(&mut pkt.buf); + *msg = MmsgHdrMut::new() + .with_addr(&mut pkt.addr.inner) + .with_iovec(std::slice::from_mut(iov)) + } + + let pkt_amt = recv_mul_pkts(sock, &mut msgs[..pkt_amt]).await?; + { + let mut bytes: [u16; MAX_PKTS] = unsafe { std::mem::zeroed() }; + for (msg, byte) in msgs.iter().zip(bytes.iter_mut()).take(pkt_amt) { + *byte = msg.get_ref().nbytes() as u16 + } + + for (pkt, byte) in pkts.iter_mut().zip(bytes).take(pkt_amt) { + pkt.cursor = byte + } + } + Ok(pkt_amt) + } + + pub async fn send_all<'a, 'b, I>(sock: &UdpSocket, pkts: I) -> Result<()> + where + I: ExactSizeIterator>, + { + const MAX_PKTS: usize = MAX_PACKETS; + debug_assert!(pkts.len() <= MAX_PKTS); + + let pkt_amt = pkts.len(); + let mut iovs: MaybeUninit<[IoSlice; MAX_PKTS]> = MaybeUninit::uninit(); + let mut msgs: MaybeUninit<[MmsgHdr; MAX_PKTS]> = MaybeUninit::uninit(); + let iovs = unsafe { iovs.assume_init_mut() }; + let msgs = unsafe { msgs.assume_init_mut() }; + + for ((pkt, iov), msg) in pkts.zip(iovs.iter_mut()).zip(msgs.iter_mut()) { + *iov = IoSlice::new(pkt.buf); + *msg = MmsgHdr::new() + .with_addr(&pkt.addr.inner) + .with_iovec(std::slice::from_ref(iov)) + } + + let mut cursor = 0; + while cursor < pkt_amt { + let n = send_mul_pkts(sock, &mut msgs[cursor..pkt_amt]).await?; + cursor += n; + } + Ok(()) + } +} diff --git a/realm_core/src/udp/middle.rs b/realm_core/src/udp/middle.rs index 0a9fe67f..ac38d0ae 100644 --- a/realm_core/src/udp/middle.rs +++ b/realm_core/src/udp/middle.rs @@ -1,92 +1,170 @@ use std::io::Result; use std::net::SocketAddr; use std::sync::Arc; - use tokio::net::UdpSocket; use super::SockMap; -use super::BUF_SIZE; -use super::socket; +use super::{socket, batched}; use crate::trick::Ref; use crate::time::timeoutfut; use crate::dns::resolve_addr; use crate::endpoint::{RemoteAddr, ConnectOpts}; -pub async fn associate_and_relay( - lis: &UdpSocket, - raddr: &RemoteAddr, - conn_opts: &ConnectOpts, - sockmap: &SockMap, -) -> Result<()> { - let mut buf = vec![0u8; BUF_SIZE]; - let associate_timeout = conn_opts.associate_timeout; +use batched::{Packet, SockAddrStore}; +use registry::Registry; +mod registry { + use super::*; + type Range = std::ops::Range; - loop { - let (n, laddr) = lis.recv_from(&mut buf).await?; - log::debug!("[udp]recvfrom client {}", &laddr); + pub struct Registry { + pkts: Box<[Packet]>, + groups: Vec, + cursor: u16, + } - let addr = resolve_addr(raddr).await?.iter().next().unwrap(); - log::debug!("[udp]{} resolved as {}", raddr, &addr); + impl Registry { + pub fn new(npkts: usize) -> Self { + debug_assert!(npkts <= batched::MAX_PACKETS); + Self { + pkts: vec![Packet::new(); npkts].into_boxed_slice(), + groups: Vec::with_capacity(npkts), + cursor: 0u16, + } + } - // get the socket associated with a unique client - let remote = match sockmap.find(&laddr) { - Some(x) => x, - None => { - log::info!("[udp]new association {} => {} as {}", &laddr, raddr, &addr); + pub async fn batched_recv_on(&mut self, sock: &UdpSocket) -> Result<()> { + let n = batched::recv_some(sock, &mut self.pkts).await?; + self.cursor = n as u16; + Ok(()) + } - let remote = Arc::new(socket::associate(&addr, conn_opts).await?); + pub fn group_by_addr(&mut self) { + let n = self.cursor as usize; + self.groups.clear(); + group_by_inner(&mut self.pkts[..n], &mut self.groups, |a, b| a.addr == b.addr); + } - sockmap.insert(laddr, remote.clone()); + pub fn group_iter(&self) -> GroupIter { + GroupIter { + pkts: &self.pkts, + ranges: self.groups.iter(), + } + } - // spawn sending back task - tokio::spawn(send_back( - Ref::new(lis), - laddr, - remote.clone(), - Ref::new(sockmap), - associate_timeout, - )); + pub fn iter(&self) -> std::slice::Iter<'_, Packet> { + self.pkts[..self.cursor as usize].iter() + } + + pub const fn count(&self) -> usize { + self.cursor as usize + } + } + + use std::slice::Iter; + use std::iter::Iterator; + pub struct GroupIter<'a> { + pkts: &'a [Packet], + ranges: Iter<'a, Range>, + } + + impl<'a> Iterator for GroupIter<'a> { + type Item = &'a [Packet]; + + fn next(&mut self) -> Option { + self.ranges + .next() + .map(|Range { start, end }| &self.pkts[*start as usize..*end as usize]) + } + } - remote + fn group_by_inner(data: &mut [T], groups: &mut Vec, eq: F) + where + F: Fn(&T, &T) -> bool, + { + let maxn = data.len(); + let (mut beg, mut end) = (0, 1); + while end < maxn { + // go ahead if addr is same + if eq(&data[end], &data[beg]) { + end += 1; + continue; } - }; + // pick packets afterwards + let mut probe = end + 1; + while probe < maxn { + if eq(&data[probe], &data[beg]) { + data.swap(probe, end); + end += 1; + } + probe += 1; + } + groups.push(beg as _..end as _); + (beg, end) = (end, end + 1); + } + groups.push(beg as _..end as _); + } +} + +pub async fn associate_and_relay( + lis: Ref, + rname: Ref, + conn_opts: Ref, + sockmap: Ref, +) -> Result<()> { + let mut registry = Registry::new(batched::MAX_PACKETS); - remote.send_to(&buf[..n], &addr).await?; + loop { + registry.batched_recv_on(&lis).await?; + log::debug!("[udp]entry batched recvfrom[{}]", registry.count()); + let raddr = resolve_addr(&rname).await?.iter().next().unwrap(); + log::debug!("[udp]{} resolved as {}", *rname, raddr); + + registry.group_by_addr(); + for pkts in registry.group_iter() { + let laddr = pkts[0].addr.clone().into(); + let rsock = sockmap.find_or_insert(&laddr, || { + let s = Arc::new(socket::associate(&raddr, &conn_opts)?); + tokio::spawn(send_back(lis, laddr, s.clone(), conn_opts, sockmap)); + log::info!("[udp]new association {} => {} as {}", laddr, *rname, raddr); + Result::Ok(s) + })?; + let raddr: SockAddrStore = raddr.into(); + batched::send_all(&rsock, pkts.iter().map(|x| x.ref_with_addr(&raddr))).await?; + } } } async fn send_back( - lis: Ref, + lsock: Ref, laddr: SocketAddr, - remote: Arc, + rsock: Arc, + conn_opts: Ref, sockmap: Ref, - associate_timeout: usize, ) { - let mut buf = vec![0u8; BUF_SIZE]; + let mut registry = Registry::new(batched::MAX_PACKETS); + let timeout = conn_opts.associate_timeout; + let laddr_s: SockAddrStore = laddr.into(); loop { - let res = match timeoutfut(remote.recv_from(&mut buf), associate_timeout).await { - Ok(x) => x, + match timeoutfut(registry.batched_recv_on(&rsock), timeout).await { Err(_) => { - log::debug!("[udp]association for {} timeout", &laddr); + log::debug!("[udp]rear recvfrom timeout"); break; } - }; - - let (n, raddr) = match res { - Ok(x) => x, - Err(e) => { - log::error!("[udp]failed to recvfrom remote: {}", e); - continue; + Ok(Err(e)) => { + log::error!("[udp]rear recvfrom failed: {}", e); + break; + } + Ok(Ok(())) => { + log::debug!("[udp]rear batched recvfrom[{}]", registry.count()) } }; - log::debug!("[udp]recvfrom remote {}", &raddr); - - if let Err(e) = lis.send_to(&buf[..n], &laddr).await { + let pkts = registry.iter().map(|pkt| pkt.ref_with_addr(&laddr_s)); + if let Err(e) = batched::send_all(&lsock, pkts).await { log::error!("[udp]failed to sendto client{}: {}", &laddr, e); - continue; + break; } } diff --git a/realm_core/src/udp/mod.rs b/realm_core/src/udp/mod.rs index e963a430..6f045dc4 100644 --- a/realm_core/src/udp/mod.rs +++ b/realm_core/src/udp/mod.rs @@ -3,17 +3,16 @@ mod socket; mod sockmap; mod middle; +mod batched; use std::io::Result; +use crate::trick::Ref; use crate::endpoint::Endpoint; use sockmap::SockMap; use middle::associate_and_relay; -/// UDP Buffer size. -pub const BUF_SIZE: usize = 2048; - /// Launch a udp relay. pub async fn run_udp(endpoint: Endpoint) -> Result<()> { let Endpoint { @@ -28,8 +27,12 @@ pub async fn run_udp(endpoint: Endpoint) -> Result<()> { let lis = socket::bind(&laddr, bind_opts).unwrap_or_else(|e| panic!("[udp]failed to bind {}: {}", laddr, e)); + let lis = Ref::new(&lis); + let raddr = Ref::new(&raddr); + let conn_opts = Ref::new(&conn_opts); + let sockmap = Ref::new(&sockmap); loop { - if let Err(e) = associate_and_relay(&lis, &raddr, &conn_opts, &sockmap).await { + if let Err(e) = associate_and_relay(lis, raddr, conn_opts, sockmap).await { log::error!("[udp]error: {}", e); } } diff --git a/realm_core/src/udp/socket.rs b/realm_core/src/udp/socket.rs index d97122ed..492842d8 100644 --- a/realm_core/src/udp/socket.rs +++ b/realm_core/src/udp/socket.rs @@ -6,7 +6,6 @@ use realm_syscall::new_udp_socket; use crate::endpoint::{BindOpts, ConnectOpts}; -#[allow(clippy::clone_on_copy)] pub fn bind(laddr: &SocketAddr, bind_opts: BindOpts) -> Result { let BindOpts { ipv6_only } = bind_opts; let socket = new_udp_socket(laddr)?; @@ -19,12 +18,12 @@ pub fn bind(laddr: &SocketAddr, bind_opts: BindOpts) -> Result { // ignore error let _ = socket.set_reuse_address(true); - socket.bind(&laddr.clone().into())?; + socket.bind(&(*laddr).into())?; UdpSocket::from_std(socket.into()) } -pub async fn associate(raddr: &SocketAddr, conn_opts: &ConnectOpts) -> Result { +pub fn associate(raddr: &SocketAddr, conn_opts: &ConnectOpts) -> Result { let ConnectOpts { bind_address, diff --git a/realm_core/src/udp/sockmap.rs b/realm_core/src/udp/sockmap.rs index c5b954f8..1e50848f 100644 --- a/realm_core/src/udp/sockmap.rs +++ b/realm_core/src/udp/sockmap.rs @@ -32,6 +32,21 @@ impl SockMap { // drop the lock } + #[inline] + pub fn find_or_insert(&self, addr: &SocketAddr, f: F) -> Result, E> + where + F: Fn() -> Result, E>, + { + match self.find(addr) { + Some(x) => Ok(x), + None => { + let socket = f()?; + self.insert(*addr, Arc::clone(&socket)); + Ok(socket) + } + } + } + #[inline] pub fn remove(&self, addr: &SocketAddr) { // fetch the lock