Skip to content

Commit

Permalink
(wip)realm-io: refactor & add accessor for mmsghdr
Browse files Browse the repository at this point in the history
  • Loading branch information
zephyrchien committed Apr 29, 2024
1 parent 5d406ff commit 6b5f03e
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 108 deletions.
4 changes: 1 addition & 3 deletions realm_io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ mod linux;
#[cfg(any(target_os = "linux", doc))]
#[cfg_attr(doc, doc(cfg(target_os = "linux")))]
pub use linux::{
AsyncRawIO,
mmsg::{Packet, PacketMut, PacketStore, SockAddrStore, STORE_LEN},
mmsg::{send_mul_pkts, recv_mul_pkts},
AsyncRawIO, mmsg,
zero_copy::{Pipe, bidi_zero_copy, pipe_size, set_pipe_size},
};

Expand Down
318 changes: 215 additions & 103 deletions realm_io/src/linux/mmsg.rs
Original file line number Diff line number Diff line change
@@ -1,85 +1,24 @@
//! Mmsg impl.
use std::task::{Poll, Context};
use std::io::Result;
use std::io::{IoSlice, IoSliceMut};
use std::os::unix::io::RawFd;

use crate::AsyncRawIO;

#[derive(Debug, Clone, Copy)]
pub struct Packet<'a, 'buf> {
addr: &'a SockAddrStore,
iovec: IoSlice<'buf>,
}

#[derive(Debug)]
pub struct PacketMut<'a, 'buf> {
addr: &'a mut SockAddrStore,
iovec: IoSliceMut<'buf>,
}

impl<'a, 'buf> Packet<'a, 'buf> {
pub fn new(addr: &'a SockAddrStore, data: &'buf [u8]) -> Self {
Self {
addr,
iovec: IoSlice::new(data),
}
}

pub fn into_store<'pkt>(&'pkt self) -> PacketStore<'a, 'buf, 'pkt> {
use std::marker::PhantomData;
let Packet { addr, iovec } = self;
PacketStore {
msg: libc::msghdr {
msg_name: addr.0.as_ptr() as *mut _,
msg_namelen: addr.0.len(),
msg_iov: iovec as *const IoSlice as *mut _,
msg_iovlen: 1,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
},
addr: PhantomData,
iovec: PhantomData,
packet: PhantomData,
}
}
}

impl<'a, 'buf> PacketMut<'a, 'buf> {
pub fn new(addr: &'a mut SockAddrStore, data: &'buf mut [u8]) -> Self {
Self {
addr,
iovec: IoSliceMut::new(data),
}
}

pub fn into_store<'pkt>(&'pkt mut self) -> PacketStore<'a, 'buf, 'pkt> {
use std::marker::PhantomData;
let PacketMut { addr, iovec } = self;
PacketStore {
msg: libc::msghdr {
msg_name: addr.0.as_ptr() as *mut _,
msg_namelen: addr.0.len(),
msg_iov: iovec as *mut IoSliceMut as *mut _,
msg_iovlen: 1,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
},
addr: PhantomData,
iovec: PhantomData,
packet: PhantomData,
}
}
}
pub use store::{PacketStore, Const, Mutable};
pub use store::{PacketRef, PacketMutRef};
pub use store::{SockAddrStore, SOCK_STORE_LEN};
pub type Packet<'a, 'iov, 'ctrl> = PacketStore<'a, 'iov, 'ctrl, Const>;
pub type PacketMut<'a, 'iov, 'ctrl> = PacketStore<'a, 'iov, 'ctrl, Mutable>;

#[inline]
fn sendmpkts(fd: RawFd, pkts: &[PacketStore]) -> i32 {
fn sendmpkts<M>(fd: RawFd, pkts: &mut [PacketStore<'_, '_, '_, M>]) -> i32 {
unsafe { libc::sendmmsg(fd, pkts.as_ptr() as *mut _, pkts.len() as u32, 0) }
}

#[inline]
fn recvmpkts(fd: RawFd, pkts: &mut [PacketStore]) -> i32 {
fn recvmpkts(fd: RawFd, pkts: &mut [PacketMut]) -> i32 {
unsafe {
libc::recvmmsg(
fd,
Expand All @@ -91,77 +30,250 @@ fn recvmpkts(fd: RawFd, pkts: &mut [PacketStore]) -> i32 {
}
}

fn poll_sendmpkts<S>(stream: &mut S, cx: &mut Context<'_>, pkts: &[PacketStore]) -> Poll<Result<usize>>
fn poll_sendmpkts<S, M>(
stream: &mut S,
cx: &mut Context<'_>,
pkts: &mut [PacketStore<'_, '_, '_, M>],
) -> Poll<Result<usize>>
where
S: AsyncRawIO + Unpin,
{
stream.poll_read_raw(cx, || sendmpkts(stream.as_raw_fd(), pkts) as isize)
stream.poll_write_raw(cx, || sendmpkts(stream.as_raw_fd(), pkts) as isize)
}

fn poll_recvmpkts<S>(stream: &mut S, cx: &mut Context<'_>, pkts: &mut [PacketStore]) -> Poll<Result<usize>>
fn poll_recvmpkts<S>(stream: &mut S, cx: &mut Context<'_>, pkts: &mut [PacketMut]) -> Poll<Result<usize>>
where
S: AsyncRawIO + Unpin,
{
stream.poll_write_raw(cx, || recvmpkts(stream.as_raw_fd(), pkts) as isize)
stream.poll_read_raw(cx, || recvmpkts(stream.as_raw_fd(), pkts) as isize)
}

pub async fn send_mul_pkts<S>(stream: &mut S, pkts: &[PacketStore<'_, '_, '_>]) -> Result<usize>
/// Send multiple packets.
pub async fn send_mul_pkts<S>(stream: &mut S, pkts: &mut [Packet<'_, '_, '_>]) -> Result<usize>
where
S: AsyncRawIO + Unpin,
{
std::future::poll_fn(move |cx| poll_sendmpkts(stream, cx, pkts)).await
}

pub async fn recv_mul_pkts<S>(stream: &mut S, pkts: &mut [PacketStore<'_, '_, '_>]) -> Result<usize>
/// Recv multiple packets.
pub async fn recv_mul_pkts<S>(stream: &mut S, pkts: &mut [PacketMut<'_, '_, '_>]) -> Result<usize>
where
S: AsyncRawIO + Unpin,
{
std::future::poll_fn(move |cx| poll_recvmpkts(stream, cx, pkts)).await
}

pub use store::{PacketStore, SockAddrStore, STORE_LEN};
mod store {
use std::mem;
use std::{mem, ptr, slice};
use std::marker::PhantomData;
use std::io::{IoSlice, IoSliceMut};
use std::net::SocketAddr;
use socket2::SockAddr;
use libc::{msghdr, sockaddr_storage, socklen_t};
use libc::{msghdr, mmsghdr};
use libc::{sockaddr_storage, socklen_t};

/// Marker.
pub struct Const {}

/// Marker.
pub struct Mutable {}

/// Represent [`libc::msghdr`].
/// Represent [`libc::mmsghdr`].
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct PacketStore<'a, 'buf, 'pkt> {
pub(crate) msg: msghdr,
pub(crate) addr: PhantomData<&'a ()>,
pub(crate) iovec: PhantomData<&'buf ()>,
pub(crate) packet: PhantomData<&'pkt ()>,
pub struct PacketStore<'a, 'iov, 'ctrl, M> {
pub(crate) store: mmsghdr,
_type: PhantomData<M>,
_lifetime: PhantomData<(&'a (), &'iov (), &'ctrl ())>,
}

/// Constant field accessor for [`PacketStore`].
pub struct PacketRef<'a, 'iov, 'ctrl, 'this> {
addr: &'a SockAddrStore,
iovec: &'iov [IoSlice<'iov>],
control: &'ctrl [u8],
flags: i32,
nbytes: u32,
_lifetime: PhantomData<&'this ()>,
}

/// Mutable field accessor for [`PacketStore`].
pub struct PacketMutRef<'a, 'iov, 'ctrl, 'this> {
addr: &'a mut SockAddrStore,
iovec: &'iov mut [IoSlice<'iov>],
control: &'ctrl mut [u8],
flags: i32,
nbytes: u32,
_lifetime: PhantomData<&'this ()>,
}

#[rustfmt::skip]
macro_rules! access_fn {
(!ref, $field: ident, $type: ty) => {
pub fn $field(&self) -> $type { &self.$field }
};
(!mut, $field: ident, $type: ty) => {
pub fn $field(&mut self) -> $type { &mut self.$field }
};
(!val, $field: ident, $type: ty) => {
pub fn $field(&self) -> $type { self.$field }
};
}

impl<'a, 'iov, 'ctrl, 'this> PacketRef<'a, 'iov, 'ctrl, 'this> {
access_fn!(!ref, addr, &&'a SockAddrStore);
access_fn!(!ref, iovec, &&'iov [IoSlice<'iov>]);
access_fn!(!ref, control, &&'ctrl [u8]);
access_fn!(!val, flags, i32);
access_fn!(!val, nbytes, u32);
}

impl<'a, 'iov, 'ctrl, 'this> PacketMutRef<'a, 'iov, 'ctrl, 'this> {
access_fn!(!mut, addr, &mut &'a mut SockAddrStore);
access_fn!(!mut, iovec, &mut &'iov mut [IoSlice<'iov>]);
access_fn!(!mut, control, &mut &'ctrl mut [u8]);
access_fn!(!val, flags, i32);
access_fn!(!val, nbytes, u32);
}

impl<'a, 'iov, 'ctrl, M> PacketStore<'a, 'iov, 'ctrl, M> {
/// New zeroed storage.
pub const fn new() -> Self {
Self {
store: unsafe { mem::zeroed::<mmsghdr>() },
_type: PhantomData,
_lifetime: PhantomData,
}
}

#[rustfmt::skip]
pub fn get_ref<'this>(&'this self) -> PacketRef<'this, 'a, 'iov, 'ctrl> {
let msghdr {
msg_name, msg_namelen,
msg_iov, msg_iovlen,
msg_control, msg_controllen, msg_flags,
} = self.store.msg_hdr;
let msg_len = self.store.msg_len;
unsafe { PacketRef {
addr: &*msg_name.cast(),
iovec: slice::from_raw_parts(msg_iov as *const _, msg_iovlen),
control: slice::from_raw_parts(msg_control as *const _, msg_controllen),
flags: msg_flags,
nbytes: msg_len,
_lifetime: PhantomData,
}}
}
}

impl<'a, 'iov, 'ctrl, M> Default for PacketStore<'a, 'iov, 'ctrl, M> {
fn default() -> Self {
Self::new()
}
}

impl<'a, 'iov, 'ctrl> PacketStore<'a, 'iov, 'ctrl, Const> {
/// Set target address.
pub const fn with_addr(mut self, addr: &'a SockAddrStore) -> Self {
self.store.msg_hdr.msg_name = addr.0.as_ptr() as *mut _;
self.store.msg_hdr.msg_namelen = addr.0.len();
self
}

/// Set data to send.
pub const fn with_iovec(mut self, iov: &'iov [IoSlice]) -> Self {
self.store.msg_hdr.msg_iov = ptr::from_ref(iov) as *mut _;
self.store.msg_hdr.msg_iovlen = iov.len();
self
}

/// Set control message to send.
pub const fn with_control(mut self, ctrl: &'ctrl [u8]) -> Self {
self.store.msg_hdr.msg_control = ptr::from_ref(ctrl) as *mut _;
self.store.msg_hdr.msg_controllen = ctrl.len();
self
}

/// Set sending flags.
pub const fn with_flags(mut self, flags: i32) -> Self {
self.store.msg_hdr.msg_flags = flags;
self
}
}

impl<'a, 'iov, 'ctrl> PacketStore<'a, 'iov, 'ctrl, Mutable> {
/// Set storage to accommodate peer address.
pub fn with_addr(mut self, addr: &'a mut SockAddrStore) -> Self {
self.store.msg_hdr.msg_name = addr.0.as_ptr() as *mut _;
self.store.msg_hdr.msg_namelen = addr.0.len();
self
}

/// Set storage to receive data.
pub fn with_iovec(mut self, iov: &'iov mut [IoSliceMut]) -> Self {
self.store.msg_hdr.msg_iov = ptr::from_mut(iov) as *mut _;
self.store.msg_hdr.msg_iovlen = iov.len();
self
}

/// Set storage to receive control message.
pub fn with_control(mut self, ctrl: &'ctrl mut [u8]) -> Self {
self.store.msg_hdr.msg_control = ptr::from_mut(ctrl) as *mut _;
self.store.msg_hdr.msg_controllen = ctrl.len();
self
}

#[rustfmt::skip]
pub fn get_mut<'this>(&'this mut self) -> PacketMutRef<'this, 'a, 'iov, 'ctrl> {
let msghdr {
msg_name, msg_namelen,
msg_iov, msg_iovlen,
msg_control, msg_controllen, msg_flags,
} = self.store.msg_hdr;
let msg_len = self.store.msg_len;
unsafe { PacketMutRef {
addr: &mut *msg_name.cast(),
iovec: slice::from_raw_parts_mut(msg_iov as *mut _, msg_iovlen),
control: slice::from_raw_parts_mut(msg_control as *mut _, msg_controllen),
flags: msg_flags,
nbytes: msg_len,
_lifetime: PhantomData,
}}
}
}

/// Represent [`libc::sockaddr_storage`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SockAddrStore(pub(crate) SockAddr);
pub const STORE_LEN: socklen_t = mem::size_of::<sockaddr_storage>() as socklen_t;
mod addr {
use super::*;
impl SockAddrStore {
pub const fn new_zeroed() -> Self {
Self(unsafe { SockAddr::new(mem::zeroed::<sockaddr_storage>(), STORE_LEN) })
}

/// Size of [`libc::sockaddr_storage`].
pub const SOCK_STORE_LEN: socklen_t = mem::size_of::<sockaddr_storage>() as socklen_t;

impl SockAddrStore {
/// New zeroed storage.
pub const fn new() -> Self {
Self(unsafe { SockAddr::new(mem::zeroed::<sockaddr_storage>(), SOCK_STORE_LEN) })
}
}

impl<T> From<T> for SockAddrStore
where
SockAddr: From<T>,
{
fn from(addr: T) -> Self {
Self(addr.into())
}
impl Default for SockAddrStore {
fn default() -> Self {
Self::new()
}
}

impl<T> From<T> for SockAddrStore
where
SockAddr: From<T>,
{
fn from(addr: T) -> Self {
Self(addr.into())
}
}

impl From<SockAddrStore> for SocketAddr {
fn from(store: SockAddrStore) -> Self {
store.0.as_socket().unwrap()
}
impl From<SockAddrStore> for SocketAddr {
fn from(store: SockAddrStore) -> Self {
store.0.as_socket().unwrap()
}
}
}
Loading

0 comments on commit 6b5f03e

Please sign in to comment.