From 200b2eedd29c3c6ae6c9413a2332414e92d0aabb Mon Sep 17 00:00:00 2001 From: why Date: Sun, 21 Aug 2022 14:05:26 +0800 Subject: [PATCH 01/14] fix port_num bug port_num should always be 1. --- src/context.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/context.rs b/src/context.rs index abaf6cc..543a870 100644 --- a/src/context.rs +++ b/src/context.rs @@ -80,8 +80,9 @@ impl Context { // SAFETY: POD FFI type let mut inner_port_attr = unsafe { std::mem::zeroed() }; - let errno = - unsafe { rdma_sys::___ibv_query_port(inner_ctx.as_ptr(), 1, &mut inner_port_attr) }; + let errno = unsafe { + rdma_sys::___ibv_query_port(inner_ctx.as_ptr(), port_num, &mut inner_port_attr) + }; if errno != 0_i32 { return Err(log_ret_last_os_err_with_note("ibv_query_port failed")); } From a3759f37d6eb275f66c11d21b18c2c6e89d7e6bc Mon Sep 17 00:00:00 2001 From: why Date: Sun, 21 Aug 2022 14:10:26 +0800 Subject: [PATCH 02/14] add multi-connection support Add APIs for `Rdma` to create a new `Rdma` that has the same `mr_allocator` and `event_listener` as parent. Change server and client examples to add multi-connection APIs demo. 
--- examples/client.rs | 18 ++- examples/server.rs | 14 +++ src/agent.rs | 2 +- src/lib.rs | 268 ++++++++++++++++++++++++++++++++++++++++++--- src/queue_pair.rs | 33 ++++-- 5 files changed, 305 insertions(+), 30 deletions(-) diff --git a/examples/client.rs b/examples/client.rs index 248480d..6adcee3 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -102,10 +102,8 @@ async fn request_then_write_with_imm(rdma: &Rdma) -> io::Result<()> { #[tokio::main] async fn main() { println!("client start"); - let rdma = RdmaBuilder::default() - .connect("localhost:5555") - .await - .unwrap(); + let addr = "localhost:5555"; + let rdma = RdmaBuilder::default().connect(addr).await.unwrap(); println!("connected"); send_data_to_server(&rdma).await.unwrap(); send_data_with_imm_to_server(&rdma).await.unwrap(); @@ -113,4 +111,16 @@ async fn main() { request_then_write(&rdma).await.unwrap(); request_then_write_with_imm(&rdma).await.unwrap(); println!("client done"); + + // create new `Rdma`s (connections) that has the same `mr_allocator` and `event_listener` as parent + for _ in 0..3 { + let rdma = rdma.new_connect(addr).await.unwrap(); + println!("connected"); + send_data_to_server(&rdma).await.unwrap(); + send_data_with_imm_to_server(&rdma).await.unwrap(); + send_lmr_to_server(&rdma).await.unwrap(); + request_then_write(&rdma).await.unwrap(); + request_then_write_with_imm(&rdma).await.unwrap(); + } + println!("client done"); } diff --git a/examples/server.rs b/examples/server.rs index d9793c7..9c34082 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -88,4 +88,18 @@ async fn main() { .await .unwrap(); println!("server done"); + + // create new `Rdma`s (connections) that has the same `mr_allocator` and `event_listener` as parent + for _ in 0..3 { + let rdma = rdma.listen().await.unwrap(); + println!("accepted"); + receive_data_from_client(&rdma).await.unwrap(); + receive_data_with_imm_from_client(&rdma).await.unwrap(); + read_rmr_from_client(&rdma).await.unwrap(); + 
receive_mr_after_being_written(&rdma).await.unwrap(); + receive_mr_after_being_written_with_imm(&rdma) + .await + .unwrap(); + } + println!("server done"); } diff --git a/src/agent.rs b/src/agent.rs index 791cba2..7651628 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -60,7 +60,7 @@ pub(crate) struct Agent { #[allow(dead_code)] agent_thread: Arc, /// Max message length - max_sr_data_len: usize, + pub(crate) max_sr_data_len: usize, } impl Drop for Agent { diff --git a/src/lib.rs b/src/lib.rs index 80a3fda..a1deddc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,7 +162,6 @@ use clippy_utilities::Cast; use completion_queue::{DEFAULT_CQ_SIZE, DEFAULT_MAX_CQE}; use context::Context; use enumflags2::{bitflags, BitFlags}; -#[cfg(feature = "cm")] use error_utilities::log_ret_last_os_err; use event_listener::EventListener; pub use memory_region::{ @@ -185,10 +184,11 @@ use rdma_sys::{ use rmr_manager::DEFAULT_RMR_TIMEOUT; #[cfg(feature = "cm")] use std::ptr::null_mut; -use std::{alloc::Layout, fmt::Debug, io, sync::Arc, time::Duration}; +use std::{alloc::Layout, fmt::Debug, io, ptr::NonNull, sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, ToSocketAddrs}, + sync::Mutex, }; use tracing::debug; @@ -765,6 +765,7 @@ impl RdmaBuilder { // wait for the remote agent to prepare tokio::time::sleep(Duration::from_secs(1)).await; } + rdma.tcp_listener = Arc::new(Mutex::new(Some(tcp_listener))); Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -925,6 +926,8 @@ pub struct Rdma { conn_type: ConnectionType, /// If send/recv raw data raw: bool, + /// Tcp listener used for new connections + tcp_listener: Arc>>, } impl Rdma { @@ -949,26 +952,65 @@ impl Rdma { Arc::::clone(&pd), mr_attr, )); - let qp = Arc::new( - pd.create_queue_pair_builder() - .set_event_listener(event_listener) - .set_port_num(dev_attr.port_num) - .set_gid_index(dev_attr.gid_index) - .set_max_send_wr(qp_attr.max_send_wr) - 
.set_max_send_sge(qp_attr.max_send_sge) - .set_max_recv_wr(qp_attr.max_recv_wr) - .set_max_recv_sge(qp_attr.max_recv_sge) - .build()?, - ); + let mut qp = pd + .create_queue_pair_builder() + .set_event_listener(event_listener) + .set_port_num(dev_attr.port_num) + .set_gid_index(dev_attr.gid_index) + .set_max_send_wr(qp_attr.max_send_wr) + .set_max_send_sge(qp_attr.max_send_sge) + .set_max_recv_wr(qp_attr.max_recv_wr) + .set_max_recv_sge(qp_attr.max_recv_sge) + .build()?; qp.modify_to_init(qp_attr.access, dev_attr.port_num)?; Ok(Self { ctx, pd, - qp, + qp: Arc::new(qp), agent: None, allocator, conn_type: qp_attr.conn_type, raw: qp_attr.raw, + tcp_listener: Arc::new(Mutex::new(None)), + }) + } + + /// Create a new `Rdma` that has the same `mr_allocator` and `event_listener` as parent. + fn clone(&self) -> io::Result { + let access = self.qp.access.map_or_else( + || { + Err(io::Error::new( + io::ErrorKind::Other, + "parent qp access is none", + )) + }, + Ok, + )?; + let mut qp_init_attr = self.qp.qp_init_attr.clone(); + // SAFETY: ffi + let inner_qp = NonNull::new(unsafe { + rdma_sys::ibv_create_qp(self.pd.as_ptr(), &mut qp_init_attr.qp_init_attr_inner) + }) + .ok_or_else(log_ret_last_os_err)?; + let mut qp = QueuePair { + pd: Arc::clone(&self.pd), + event_listener: Arc::clone(&self.qp.event_listener), + inner_qp, + port_num: self.qp.port_num, + gid_index: self.qp.gid_index, + qp_init_attr, + access: self.qp.access, + }; + qp.modify_to_init(access, self.qp.port_num)?; + Ok(Self { + ctx: Arc::clone(&self.ctx), + pd: Arc::clone(&self.pd), + qp: Arc::new(qp), + agent: None, + allocator: Arc::clone(&self.allocator), + conn_type: self.conn_type, + raw: self.raw, + tcp_listener: Arc::clone(&self.tcp_listener), }) } @@ -986,6 +1028,204 @@ impl Rdma { Ok(()) } + /// Listen for new connections using the same `mr_allocator` and `event_listener` as parent `Rdma` + /// + /// Used with `connect` and `new_connect` + /// + /// # Examples + /// + /// ``` + /// use 
async_rdma::RdmaBuilder; + /// use portpicker::pick_unused_port; + /// use std::{ + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// for _ in 0..3 { + /// let _new_rdma = rdma.new_connect(addr).await?; + /// } + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// for _ in 0..3 { + /// let _new_rdma = rdma.listen().await?; + /// } + /// Ok(()) + /// } + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + pub async fn listen(&self) -> io::Result { + match self.conn_type { + ConnectionType::RCSocket => { + let (mut stream, _) = self + .tcp_listener + .lock() + .await + .as_ref() + .map_or_else( + || Err(io::Error::new(io::ErrorKind::Other, "tcp_listener is None")), + |listener| Ok(async { listener.accept().await }), + )? 
+ .await?; + let mut rdma = self.clone()?; + let endpoint_size = bincode::serialized_size(&rdma.endpoint()).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Endpoint serialization failed, {:?}", e), + ) + })?; + let mut remote = vec![0_u8; endpoint_size.cast()]; + // the byte number is not important, as read_exact will fill the buffer + let _ = stream.read_exact(remote.as_mut()).await?; + let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("failed to deserialize remote endpoint, {:?}", e), + ) + })?; + let local = bincode::serialize(&rdma.endpoint()).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("failed to deserialize remote endpoint, {:?}", e), + ) + })?; + stream.write_all(&local).await?; + rdma.qp_handshake(remote)?; + debug!("handshake done"); + if !rdma.raw { + #[allow(clippy::unreachable)] + let max_sr_data_len = self.agent.as_ref().map_or_else( + || { + unreachable!("agent of parent rdma is None"); + }, + |agent| agent.max_sr_data_len, + ); + + let agent = Arc::new(Agent::new( + Arc::::clone(&rdma.qp), + Arc::::clone(&rdma.allocator), + max_sr_data_len, + )?); + rdma.agent = Some(agent); + // wait for the remote agent to prepare + tokio::time::sleep(Duration::from_secs(1)).await; + } + Ok(rdma) + } + ConnectionType::RCCM => Err(io::Error::new( + io::ErrorKind::Other, + "ConnectionType should be XXSocket", + )), + } + } + + /// Establish new connections with RDMA server using the same `mr_allocator` and `event_listener` as parent `Rdma` + /// + /// Used with `listen` + /// + /// # Examples + /// + /// ``` + /// use async_rdma::RdmaBuilder; + /// use portpicker::pick_unused_port; + /// use std::{ + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// for _ in 0..3 { + /// let _new_rdma = 
rdma.new_connect(addr).await?; + /// } + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// for _ in 0..3 { + /// let _new_rdma = rdma.listen().await?; + /// } + /// Ok(()) + /// } + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + pub async fn new_connect(&self, addr: A) -> io::Result { + match self.conn_type { + ConnectionType::RCSocket => { + let mut rdma = self.clone()?; + let mut stream = TcpStream::connect(addr).await?; + let mut endpoint = bincode::serialize(&rdma.endpoint()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("failed to serailize the endpoint, {:?}", e), + ) + })?; + stream.write_all(&endpoint).await?; + // the byte number is not important, as read_exact will fill the buffer + let _ = stream.read_exact(endpoint.as_mut()).await?; + let remote: QueuePairEndpoint = bincode::deserialize(&endpoint).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("failed to deserailize the endpoint, {:?}", e), + ) + })?; + rdma.qp_handshake(remote)?; + if !rdma.raw { + #[allow(clippy::unreachable)] + let max_sr_data_len = self.agent.as_ref().map_or_else( + || { + unreachable!("agent of parent rdma is None"); + }, + |agent| agent.max_sr_data_len, + ); + let agent = Arc::new(Agent::new( + Arc::::clone(&rdma.qp), + Arc::::clone(&rdma.allocator), + max_sr_data_len, + )?); + rdma.agent = Some(agent); + // wait for remote agent to initialize + tokio::time::sleep(Duration::from_secs(1)).await; + } + Ok(rdma) + } + ConnectionType::RCCM => Err(io::Error::new( + io::ErrorKind::Other, + 
"ConnectionType should be XXSocket", + )), + } + } + /// Send the content in the `lm` /// /// Used with `receive`. diff --git a/src/queue_pair.rs b/src/queue_pair.rs index d5c9b93..7505851 100644 --- a/src/queue_pair.rs +++ b/src/queue_pair.rs @@ -42,9 +42,10 @@ pub(crate) static MAX_SEND_SGE: u32 = 10; pub(crate) static MAX_RECV_SGE: u32 = 10; /// Queue pair initialized attribute -struct QueuePairInitAttr { +#[derive(Clone)] +pub(crate) struct QueuePairInitAttr { /// Internal `ibv_qp_init_attr` structure - qp_init_attr_inner: ibv_qp_init_attr, + pub(crate) qp_init_attr_inner: ibv_qp_init_attr, } impl Default for QueuePairInitAttr { @@ -129,7 +130,6 @@ impl QueuePairBuilder { /// `EPERM` Not enough permissions to create a QP with this Transport Service Type pub(crate) fn build(mut self) -> io::Result { // SAFETY: ffi - // TODO: check safety let inner_qp = NonNull::new(unsafe { rdma_sys::ibv_create_qp(self.pd.as_ptr(), &mut self.qp_init_attr.qp_init_attr_inner) }) @@ -137,14 +137,16 @@ impl QueuePairBuilder { Ok(QueuePair { pd: Arc::::clone(&self.pd), inner_qp, - event_listener: self.event_listener.ok_or_else(|| { + event_listener: Arc::new(self.event_listener.ok_or_else(|| { io::Error::new( io::ErrorKind::Other, "event channel is not set for the queue pair builder", ) - })?, + })?), port_num: self.port_num, gid_index: self.gid_index, + qp_init_attr: self.qp_init_attr, + access: None, }) } @@ -208,15 +210,19 @@ pub(crate) struct QueuePairEndpoint { #[derive(Debug)] pub(crate) struct QueuePair { /// protection domain it belongs to - pd: Arc, + pub(crate) pd: Arc, /// event listener - pub(crate) event_listener: EventListener, + pub(crate) event_listener: Arc, /// internal `ibv_qp` pointer - inner_qp: NonNull, + pub(crate) inner_qp: NonNull, /// port number - port_num: u8, + pub(crate) port_num: u8, /// gid index - gid_index: usize, + pub(crate) gid_index: usize, + /// backup for child qp + pub(crate) qp_init_attr: QueuePairInitAttr, + /// access of this qp + 
pub(crate) access: Option, } impl QueuePair { @@ -243,7 +249,11 @@ impl QueuePair { /// `EINVAL` Invalid value provided in attr or in `attr_mask` /// /// `ENOMEM` Not enough resources to complete this operation - pub(crate) fn modify_to_init(&self, flag: ibv_access_flags, port_num: u8) -> io::Result<()> { + pub(crate) fn modify_to_init( + &mut self, + flag: ibv_access_flags, + port_num: u8, + ) -> io::Result<()> { // SAFETY: POD FFI type let mut attr = unsafe { std::mem::zeroed::() }; attr.pkey_index = 0; @@ -260,6 +270,7 @@ impl QueuePair { if errno != 0_i32 { return Err(log_ret_last_os_err()); } + self.access = Some(flag); Ok(()) } From 8d72e3e5f5b0666c7f848f05347f9d3c74960304 Mon Sep 17 00:00:00 2001 From: why Date: Sun, 21 Aug 2022 15:22:08 +0800 Subject: [PATCH 03/14] code reuse and cleanup --- src/lib.rs | 504 +++++++++++++++++++---------------------------------- 1 file changed, 178 insertions(+), 326 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a1deddc..906e0e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -469,33 +469,9 @@ impl RdmaBuilder { match self.qp_attr.conn_type { ConnectionType::RCSocket => { let mut rdma = self.build()?; - let mut stream = TcpStream::connect(addr).await?; - let mut endpoint = bincode::serialize(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to serailize the endpoint, {:?}", e), - ) - })?; - stream.write_all(&endpoint).await?; - // the byte number is not important, as read_exact will fill the buffer - let _ = stream.read_exact(endpoint.as_mut()).await?; - let remote: QueuePairEndpoint = bincode::deserialize(&endpoint).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to deserailize the endpoint, {:?}", e), - ) - })?; + let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - 
self.agent_attr.max_message_length, - )?); - rdma.agent = Some(agent); - // wait for remote agent to initialize - tokio::time::sleep(Duration::from_secs(1)).await; - } + rdma.init_agent(self.agent_attr.max_message_length).await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -602,86 +578,10 @@ impl RdmaBuilder { "ConnectionType should be XXSocket", )), ConnectionType::RCCM => { - let msg_len = self.agent_attr.max_message_length; + let max_message_length = self.agent_attr.max_message_length; let mut rdma = self.build()?; - - // SAFETY: POD FFI type - let mut hints = unsafe { std::mem::zeroed::() }; - let mut info: *mut rdma_addrinfo = null_mut(); - hints.ai_port_space = rdma_port_space::RDMA_PS_TCP.cast(); - // Safety: ffi - let mut ret = unsafe { - rdma_getaddrinfo( - node.as_ptr().cast(), - service.as_ptr().cast(), - &hints, - &mut info, - ) - }; - if ret != 0_i32 { - return Err(log_ret_last_os_err()); - } - - let mut id: *mut rdma_cm_id = null_mut(); - // Safety: ffi - ret = unsafe { rdma_create_ep(&mut id, info, rdma.pd.as_ptr(), null_mut()) }; - if ret != 0_i32 { - // Safety: ffi - unsafe { - rdma_freeaddrinfo(info); - } - return Err(log_ret_last_os_err()); - } - - // Safety: id was initialized by `rdma_create_ep` - unsafe { - debug!( - "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", - (*id).qp, - (*id).pd, - (*id).verbs, - (*id).recv_cq_channel, - (*id).send_cq_channel, - (*id).recv_cq, - (*id).send_cq - ); - (*id).qp = rdma.qp.as_ptr(); - (*id).pd = rdma.pd.as_ptr(); - (*id).verbs = rdma.ctx.as_ptr(); - (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); - (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); - (*id).recv_cq = rdma.qp.event_listener.cq.as_ptr(); - (*id).send_cq = rdma.qp.event_listener.cq.as_ptr(); - debug!( - "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", - (*id).qp, - (*id).pd, - (*id).verbs, - (*id).recv_cq_channel, - (*id).send_cq_channel, - (*id).recv_cq, - (*id).send_cq - ); - } - // 
Safety: ffi - ret = unsafe { rdma_connect(id, null_mut()) }; - if ret != 0_i32 { - // Safety: ffi - unsafe { - let _ = rdma_disconnect(id); - } - return Err(log_ret_last_os_err()); - } - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - msg_len, - )?); - rdma.agent = Some(agent); - // wait for remote agent to initialize - tokio::time::sleep(Duration::from_secs(1)).await; - } + cm_connect_helper(&mut rdma, node, service)?; + rdma.init_agent(max_message_length).await?; Ok(rdma) } } @@ -727,44 +627,12 @@ impl RdmaBuilder { pub async fn listen(self, addr: A) -> io::Result { match self.qp_attr.conn_type { ConnectionType::RCSocket => { - let tcp_listener = TcpListener::bind(addr).await?; - let (mut stream, _) = tcp_listener.accept().await?; let mut rdma = self.build()?; - - let endpoint_size = bincode::serialized_size(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("Endpoint serialization failed, {:?}", e), - ) - })?; - let mut remote = vec![0_u8; endpoint_size.cast()]; - // the byte number is not important, as read_exact will fill the buffer - let _ = stream.read_exact(remote.as_mut()).await?; - let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("failed to deserialize remote endpoint, {:?}", e), - ) - })?; - let local = bincode::serialize(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("failed to deserialize remote endpoint, {:?}", e), - ) - })?; - stream.write_all(&local).await?; + let tcp_listener = TcpListener::bind(addr).await?; + let remote = tcp_listen(&tcp_listener, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; debug!("handshake done"); - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - self.agent_attr.max_message_length, - )?); - rdma.agent = Some(agent); - // wait for the remote agent 
to prepare - tokio::time::sleep(Duration::from_secs(1)).await; - } + rdma.init_agent(self.agent_attr.max_message_length).await?; rdma.tcp_listener = Arc::new(Mutex::new(Some(tcp_listener))); Ok(rdma) } @@ -898,6 +766,136 @@ impl Debug for RdmaBuilder { } } +/// Exchange metadata through tcp +async fn tcp_connect_helper( + addr: A, + ep: &QueuePairEndpoint, +) -> io::Result { + let mut stream = TcpStream::connect(addr).await?; + let mut endpoint = bincode::serialize(ep).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("failed to serailize the endpoint, {:?}", e), + ) + })?; + stream.write_all(&endpoint).await?; + // the byte number is not important, as read_exact will fill the buffer + let _ = stream.read_exact(endpoint.as_mut()).await?; + bincode::deserialize(&endpoint).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("failed to deserailize the endpoint, {:?}", e), + ) + }) +} + +/// Listen for exchanging metadata through tcp +async fn tcp_listen( + tcp_listener: &TcpListener, + ep: &QueuePairEndpoint, +) -> io::Result { + let (mut stream, _) = tcp_listener.accept().await?; + + let endpoint_size = bincode::serialized_size(ep).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Endpoint serialization failed, {:?}", e), + ) + })?; + let mut remote = vec![0_u8; endpoint_size.cast()]; + // the byte number is not important, as read_exact will fill the buffer + let _ = stream.read_exact(remote.as_mut()).await?; + let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("failed to deserialize remote endpoint, {:?}", e), + ) + })?; + let local = bincode::serialize(ep).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("failed to deserialize remote endpoint, {:?}", e), + ) + })?; + stream.write_all(&local).await?; + Ok(remote) +} + +/// Exchange metadata and setup connection through cm +#[inline] +#[cfg(feature = "cm")] +fn 
cm_connect_helper(rdma: &mut Rdma, node: &str, service: &str) -> io::Result<()> { + // SAFETY: POD FFI type + let mut hints = unsafe { std::mem::zeroed::() }; + let mut info: *mut rdma_addrinfo = null_mut(); + hints.ai_port_space = rdma_port_space::RDMA_PS_TCP.cast(); + // Safety: ffi + let mut ret = unsafe { + rdma_getaddrinfo( + node.as_ptr().cast(), + service.as_ptr().cast(), + &hints, + &mut info, + ) + }; + if ret != 0_i32 { + return Err(log_ret_last_os_err()); + } + + let mut id: *mut rdma_cm_id = null_mut(); + // Safety: ffi + ret = unsafe { rdma_create_ep(&mut id, info, rdma.pd.as_ptr(), null_mut()) }; + if ret != 0_i32 { + // Safety: ffi + unsafe { + rdma_freeaddrinfo(info); + } + return Err(log_ret_last_os_err()); + } + + // Safety: id was initialized by `rdma_create_ep` + unsafe { + debug!( + "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", + (*id).qp, + (*id).pd, + (*id).verbs, + (*id).recv_cq_channel, + (*id).send_cq_channel, + (*id).recv_cq, + (*id).send_cq + ); + (*id).qp = rdma.qp.as_ptr(); + (*id).pd = rdma.pd.as_ptr(); + (*id).verbs = rdma.ctx.as_ptr(); + (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); + (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); + (*id).recv_cq = rdma.qp.event_listener.cq.as_ptr(); + (*id).send_cq = rdma.qp.event_listener.cq.as_ptr(); + debug!( + "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", + (*id).qp, + (*id).pd, + (*id).verbs, + (*id).recv_cq_channel, + (*id).send_cq_channel, + (*id).recv_cq, + (*id).send_cq + ); + } + // Safety: ffi + ret = unsafe { rdma_connect(id, null_mut()) }; + if ret != 0_i32 { + // Safety: ffi + unsafe { + let _ = rdma_disconnect(id); + } + return Err(log_ret_last_os_err()); + } + + Ok(()) +} + /// Method of establishing a connection #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConnectionType { @@ -1028,6 +1026,21 @@ impl Rdma { Ok(()) } + /// Agent init helper + async fn init_agent(&mut self, max_message_length: usize) -> io::Result<()> 
{ + if !self.raw { + let agent = Arc::new(Agent::new( + Arc::::clone(&self.qp), + Arc::::clone(&self.allocator), + max_message_length, + )?); + self.agent = Some(agent); + // wait for the remote agent to prepare + tokio::time::sleep(Duration::from_secs(1)).await; + } + Ok(()) + } + /// Listen for new connections using the same `mr_allocator` and `event_listener` as parent `Rdma` /// /// Used with `connect` and `new_connect` @@ -1074,59 +1087,27 @@ impl Rdma { pub async fn listen(&self) -> io::Result { match self.conn_type { ConnectionType::RCSocket => { - let (mut stream, _) = self + let mut rdma = self.clone()?; + let remote = self .tcp_listener .lock() .await .as_ref() .map_or_else( || Err(io::Error::new(io::ErrorKind::Other, "tcp_listener is None")), - |listener| Ok(async { listener.accept().await }), + |listener| Ok(async { tcp_listen(listener, &rdma.endpoint()).await }), )? .await?; - let mut rdma = self.clone()?; - let endpoint_size = bincode::serialized_size(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("Endpoint serialization failed, {:?}", e), - ) - })?; - let mut remote = vec![0_u8; endpoint_size.cast()]; - // the byte number is not important, as read_exact will fill the buffer - let _ = stream.read_exact(remote.as_mut()).await?; - let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("failed to deserialize remote endpoint, {:?}", e), - ) - })?; - let local = bincode::serialize(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("failed to deserialize remote endpoint, {:?}", e), - ) - })?; - stream.write_all(&local).await?; rdma.qp_handshake(remote)?; debug!("handshake done"); - if !rdma.raw { - #[allow(clippy::unreachable)] - let max_sr_data_len = self.agent.as_ref().map_or_else( - || { - unreachable!("agent of parent rdma is None"); - }, - |agent| agent.max_sr_data_len, - ); - - let agent = Arc::new(Agent::new( 
- Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - max_sr_data_len, - )?); - rdma.agent = Some(agent); - // wait for the remote agent to prepare - tokio::time::sleep(Duration::from_secs(1)).await; - } + #[allow(clippy::unreachable)] + let max_message_length = self.agent.as_ref().map_or_else( + || { + unreachable!("agent of parent rdma is None"); + }, + |agent| agent.max_sr_data_len, + ); + rdma.init_agent(max_message_length).await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -1183,40 +1164,16 @@ impl Rdma { match self.conn_type { ConnectionType::RCSocket => { let mut rdma = self.clone()?; - let mut stream = TcpStream::connect(addr).await?; - let mut endpoint = bincode::serialize(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to serailize the endpoint, {:?}", e), - ) - })?; - stream.write_all(&endpoint).await?; - // the byte number is not important, as read_exact will fill the buffer - let _ = stream.read_exact(endpoint.as_mut()).await?; - let remote: QueuePairEndpoint = bincode::deserialize(&endpoint).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to deserailize the endpoint, {:?}", e), - ) - })?; + let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; - if !rdma.raw { - #[allow(clippy::unreachable)] - let max_sr_data_len = self.agent.as_ref().map_or_else( - || { - unreachable!("agent of parent rdma is None"); - }, - |agent| agent.max_sr_data_len, - ); - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - max_sr_data_len, - )?); - rdma.agent = Some(agent); - // wait for remote agent to initialize - tokio::time::sleep(Duration::from_secs(1)).await; - } + #[allow(clippy::unreachable)] + let max_message_length = self.agent.as_ref().map_or_else( + || { + unreachable!("agent of parent rdma is None"); + }, + |agent| agent.max_sr_data_len, + ); + 
rdma.init_agent(max_message_length).await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -2088,31 +2045,9 @@ impl Rdma { rdma.conn_type == ConnectionType::RCSocket, "should set connection type to RCSocket" ); - let mut stream = TcpStream::connect(addr).await?; - let mut endpoint = bincode::serialize(&rdma.endpoint()).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to serailize the endpoint, {:?}", e), - ) - })?; - stream.write_all(&endpoint).await?; - // the byte number is not important, as read_exact will fill the buffer - let _ = stream.read_exact(endpoint.as_mut()).await?; - let remote: QueuePairEndpoint = bincode::deserialize(&endpoint).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("failed to deserailize the endpoint, {:?}", e), - ) - })?; + let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - max_message_length, - )?); - rdma.agent = Some(agent); - } + rdma.init_agent(max_message_length).await?; // wait for server to initialize tokio::time::sleep(Duration::from_secs(1)).await; Ok(rdma) @@ -2221,82 +2156,8 @@ impl Rdma { rdma.conn_type == ConnectionType::RCCM, "should set connection type to RCSocket" ); - - // SAFETY: POD FFI type - let mut hints = unsafe { std::mem::zeroed::() }; - let mut info: *mut rdma_addrinfo = null_mut(); - hints.ai_port_space = rdma_port_space::RDMA_PS_TCP.cast(); - // Safety: ffi - let mut ret = unsafe { - rdma_getaddrinfo( - node.as_ptr().cast(), - service.as_ptr().cast(), - &hints, - &mut info, - ) - }; - if ret != 0_i32 { - return Err(log_ret_last_os_err()); - } - - let mut id: *mut rdma_cm_id = null_mut(); - // Safety: ffi - ret = unsafe { rdma_create_ep(&mut id, info, rdma.pd.as_ptr(), null_mut()) }; - if ret != 0_i32 { - // Safety: ffi - unsafe { - rdma_freeaddrinfo(info); - } - return 
Err(log_ret_last_os_err()); - } - // Safety: id was initialized by `rdma_create_ep` - unsafe { - debug!( - "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", - (*id).qp, - (*id).pd, - (*id).verbs, - (*id).recv_cq_channel, - (*id).send_cq_channel, - (*id).recv_cq, - (*id).send_cq - ); - (*id).qp = rdma.qp.as_ptr(); - (*id).pd = rdma.pd.as_ptr(); - (*id).verbs = rdma.ctx.as_ptr(); - (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); - (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr(); - (*id).recv_cq = rdma.qp.event_listener.cq.as_ptr(); - (*id).send_cq = rdma.qp.event_listener.cq.as_ptr(); - debug!( - "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}", - (*id).qp, - (*id).pd, - (*id).verbs, - (*id).recv_cq_channel, - (*id).send_cq_channel, - (*id).recv_cq, - (*id).send_cq - ); - } - - // Safety: ffi - ret = unsafe { rdma_connect(id, null_mut()) }; - if ret != 0_i32 { - // Safety: ffi - unsafe { - let _ = rdma_disconnect(id); - } - return Err(log_ret_last_os_err()); - } - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - max_message_length, - )?); - rdma.agent = Some(agent); - } + cm_connect_helper(&mut rdma, node, service)?; + rdma.init_agent(max_message_length).await?; // wait for server to initialize tokio::time::sleep(Duration::from_secs(1)).await; Ok(rdma) @@ -3105,16 +2966,7 @@ impl RdmaListener { stream.write_all(&local).await?; rdma.qp_handshake(remote)?; debug!("handshake done"); - if !rdma.raw { - let agent = Arc::new(Agent::new( - Arc::::clone(&rdma.qp), - Arc::::clone(&rdma.allocator), - max_message_length, - )?); - rdma.agent = Some(agent); - // wait for the remote agent to prepare - tokio::time::sleep(Duration::from_secs(1)).await; - } + rdma.init_agent(max_message_length).await?; Ok(rdma) } } From 56d848d43b8d013be27d8b93bf65e7504c68ba18 Mon Sep 17 00:00:00 2001 From: why Date: Sun, 21 Aug 2022 19:31:56 +0800 Subject: [PATCH 04/14] add `CloneAttr` for 
`Rdma` Add `CloneAttr` to define the behaviors that occur during `clone` before `listen` or `new_connect`. --- src/lib.rs | 193 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 173 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 906e0e1..72ba35c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -633,7 +633,7 @@ impl RdmaBuilder { rdma.qp_handshake(remote)?; debug!("handshake done"); rdma.init_agent(self.agent_attr.max_message_length).await?; - rdma.tcp_listener = Arc::new(Mutex::new(Some(tcp_listener))); + rdma.clone_attr = CloneAttr::default().set_tcp_listener(tcp_listener); Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -905,14 +905,53 @@ pub enum ConnectionType { RCCM, } +/// Attributes for creating new `Rdma`s through `clone` +#[derive(Debug, Clone, Default)] +pub(crate) struct CloneAttr { + /// Tcp listener used for new connections + pub(crate) tcp_listener: Option>>, + /// Clone `Rdma` with new `ProtectionDomain` + #[allow(dead_code)] // not yet fully implemented + pub(crate) pd: Option>, + /// Clone `Rdma` with new `Port number` + pub(crate) port_num: Option, + /// Clone `Rdma` with new qp access + pub(crate) qp_access: Option, +} + +impl CloneAttr { + /// Set `TcpListener` + fn set_tcp_listener(mut self, tcp_listener: TcpListener) -> Self { + self.tcp_listener = Some(Arc::new(Mutex::new(tcp_listener))); + self + } + + /// Set `ProtectionDomain` + #[allow(dead_code)] // not yet fully implemented + fn set_pd(mut self, pd: ProtectionDomain) -> Self { + self.pd = Some(Arc::new(pd)); + self + } + + /// Set port number + fn set_port_num(mut self, port_num: u8) -> Self { + self.port_num = Some(port_num); + self + } + + /// Set qp access + fn set_qp_access(mut self, access: ibv_access_flags) -> Self { + self.qp_access = Some(access); + self + } +} + /// Rdma handler, the only interface that the users deal with rdma #[derive(Debug)] pub struct Rdma { /// device context - #[allow(dead_code)] ctx: Arc, /// protection 
domain - #[allow(dead_code)] pd: Arc, /// Memory region allocator allocator: Arc, @@ -924,8 +963,8 @@ pub struct Rdma { conn_type: ConnectionType, /// If send/recv raw data raw: bool, - /// Tcp listener used for new connections - tcp_listener: Arc>>, + /// Attributes for creating new `Rdma`s through `clone` + clone_attr: CloneAttr, } impl Rdma { @@ -969,21 +1008,34 @@ impl Rdma { allocator, conn_type: qp_attr.conn_type, raw: qp_attr.raw, - tcp_listener: Arc::new(Mutex::new(None)), + clone_attr: CloneAttr::default(), }) } /// Create a new `Rdma` that has the same `mr_allocator` and `event_listener` as parent. fn clone(&self) -> io::Result { - let access = self.qp.access.map_or_else( + let qp_access = self.clone_attr.qp_access.map_or_else( || { - Err(io::Error::new( - io::ErrorKind::Other, - "parent qp access is none", - )) + self.qp.access.map_or_else( + || { + Err(io::Error::new( + io::ErrorKind::Other, + "parent qp access is none", + )) + }, + Ok, + ) }, Ok, )?; + let port_num = self.clone_attr.port_num.unwrap_or(self.qp.port_num); + // TODO: add multi-pd support for allocator + let pd = self + .clone_attr + .pd + .as_ref() + .map_or_else(|| &self.pd, |pd| pd); + let mut qp_init_attr = self.qp.qp_init_attr.clone(); // SAFETY: ffi let inner_qp = NonNull::new(unsafe { @@ -991,24 +1043,24 @@ impl Rdma { }) .ok_or_else(log_ret_last_os_err)?; let mut qp = QueuePair { - pd: Arc::clone(&self.pd), + pd: Arc::clone(pd), event_listener: Arc::clone(&self.qp.event_listener), inner_qp, - port_num: self.qp.port_num, + port_num, gid_index: self.qp.gid_index, qp_init_attr, - access: self.qp.access, + access: Some(qp_access), }; - qp.modify_to_init(access, self.qp.port_num)?; + qp.modify_to_init(qp_access, self.qp.port_num)?; Ok(Self { ctx: Arc::clone(&self.ctx), - pd: Arc::clone(&self.pd), + pd: Arc::clone(pd), qp: Arc::new(qp), agent: None, allocator: Arc::clone(&self.allocator), conn_type: self.conn_type, raw: self.raw, - tcp_listener: Arc::clone(&self.tcp_listener), + 
clone_attr: self.clone_attr.clone(), }) } @@ -1089,13 +1141,17 @@ impl Rdma { ConnectionType::RCSocket => { let mut rdma = self.clone()?; let remote = self + .clone_attr .tcp_listener - .lock() - .await .as_ref() .map_or_else( || Err(io::Error::new(io::ErrorKind::Other, "tcp_listener is None")), - |listener| Ok(async { tcp_listen(listener, &rdma.endpoint()).await }), + |tcp_listener| { + Ok(async { + let tcp_listener = tcp_listener.lock().await; + tcp_listen(&tcp_listener, &rdma.endpoint()).await + }) + }, )? .await?; rdma.qp_handshake(remote)?; @@ -2841,6 +2897,103 @@ impl Rdma { )) } } + + /// Set qp access for new `Rdma` that created by `clone` + /// + /// Used with `listen`, `new_connect` + /// + /// # Examples + /// + /// ``` + /// use async_rdma::{AccessFlag, RdmaBuilder}; + /// use portpicker::pick_unused_port; + /// use std::{ + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + /// let rdma = rdma.set_new_qp_access(access); + /// let _new_rdma = rdma.new_connect(addr).await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// let _new_rdma = rdma.listen().await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + #[must_use] + pub fn set_new_qp_access(mut self, qp_access: BitFlags) -> Self { + self.clone_attr = self + .clone_attr + 
.set_qp_access(flags_into_ibv_access(qp_access)); + self + } + + /// Set qp access for new `Rdma` that created by `clone` + /// + /// Used with `listen`, `new_connect` + /// + /// # Examples + /// + /// ``` + /// use async_rdma::RdmaBuilder; + /// use portpicker::pick_unused_port; + /// use std::{ + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// let rdma = rdma.set_new_port_num(1_u8); + /// let _new_rdma = rdma.new_connect(addr).await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// let _new_rdma = rdma.listen().await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + #[must_use] + pub fn set_new_port_num(mut self, port_num: u8) -> Self { + self.clone_attr = self.clone_attr.set_port_num(port_num); + self + } } /// Rdma Listener is the wrapper of a `TcpListener`, which is used to From 99b67a5e49195e81ed00d693adc6783108ccbb00 Mon Sep 17 00:00:00 2001 From: why Date: Tue, 23 Aug 2022 21:42:22 +0800 Subject: [PATCH 05/14] disable `tcache` by default `tcache` is a feature of `Jemalloc` to speed up memory allocation. However `Jemalloc` may alloc `MR` with wrong `arena_index` from `tcache` when we create more than one `Jemalloc` enabled `mr_allocator`s. So we disable `tcache` by default. 
--- .cargo/config.toml | 9 ++++++++- src/mr_allocator.rs | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index c9dec75..ba80378 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -2,4 +2,11 @@ # `narenas` is the maximum number of arenas to use for automatic multiplexing of threads and arenas. # The default is four times the number of CPUs, or one if there is a single CPU. # `async-rdma` doesn't need that many antomatic arenas so we set it to 1. -JEMALLOC_SYS_WITH_MALLOC_CONF = "narenas:1" + +# `tcache` is a feature of `Jemalloc` to speed up memory allocation. +# However `Jemalloc` may alloc `MR` with wrong `arena_index` from `tcache` +# when we create more than one `Jemalloc` enabled `mr_allocator`s. +# So we disable `tcache` by default. +# If you want to enable `tcache` and make sure safety by yourself, change +# `JEMALLOC_SYS_WITH_MALLOC_CONF` from `tcache:false` to `tcache:true`. +JEMALLOC_SYS_WITH_MALLOC_CONF = "narenas:1,tcache:false" diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index 28560cf..5d237f1 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -974,4 +974,19 @@ mod tests { } Ok(()) } + + /// `Jemalloc` may alloc `MR` with wrong `arena_index` from `tcache` when we + /// create more than one `Jemalloc` enabled `mr_allocator`s. + /// So we disable `tcache` by default and make this test. + /// If you want to enable `tcache` and make sure safety by yourself, change + /// `JEMALLOC_SYS_WITH_MALLOC_CONF` from `tcache:false` to `tcache:true`. 
+ #[tokio::test] + async fn multi_je_allocator() -> io::Result<()> { + let rdma_1 = RdmaBuilder::default().build()?; + let layout = Layout::new::(); + let _mr_1 = rdma_1.alloc_local_mr(layout); + let rdma_2 = RdmaBuilder::default().build()?; + let _mr_2 = rdma_2.alloc_local_mr(layout); + Ok(()) + } } From 7ad3049473cc3ebac6d0c57ee39ff6fee9bbdb7c Mon Sep 17 00:00:00 2001 From: why Date: Wed, 24 Aug 2022 20:51:40 +0800 Subject: [PATCH 06/14] add multi-pd support for allocator Prepare for multi-connection with multi-pd. --- src/lib.rs | 5 +- src/memory_region/local.rs | 7 +++ src/memory_region/raw.rs | 6 +++ src/mr_allocator.rs | 98 ++++++++++++++++++++++++++++++-------- src/protection_domain.rs | 2 +- 5 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 72ba35c..7fe1d89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2374,7 +2374,7 @@ impl Rdma { access: BitFlags, ) -> io::Result { self.allocator - .alloc_zeroed(&layout, flags_into_ibv_access(access)) + .alloc_zeroed(&layout, flags_into_ibv_access(access), &self.pd) } /// Allocate a local memory region with specified access that has not been initialized @@ -2414,7 +2414,8 @@ impl Rdma { layout: Layout, access: BitFlags, ) -> io::Result { - self.allocator.alloc(&layout, flags_into_ibv_access(access)) + self.allocator + .alloc(&layout, flags_into_ibv_access(access), &self.pd) } /// Request a remote memory region with default timeout value. 
diff --git a/src/memory_region/local.rs b/src/memory_region/local.rs index b79ad03..e059359 100644 --- a/src/memory_region/local.rs +++ b/src/memory_region/local.rs @@ -1,6 +1,7 @@ use super::{raw::RawMemoryRegion, MrAccess, MrToken}; use crate::{ lock_utilities::{MappedRwLockReadGuard, MappedRwLockWriteGuard}, + protection_domain::ProtectionDomain, MRManageStrategy, }; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; @@ -540,6 +541,12 @@ impl LocalMrInner { fn lkey(&self) -> u32 { self.raw.lkey() } + + /// Get pd of this memory region + #[allow(dead_code)] // used by test + pub(crate) fn pd(&self) -> &Arc { + self.raw.pd() + } } impl MrAccess for &LocalMr { diff --git a/src/memory_region/raw.rs b/src/memory_region/raw.rs index 25fef88..db0e0b0 100644 --- a/src/memory_region/raw.rs +++ b/src/memory_region/raw.rs @@ -76,6 +76,12 @@ impl RawMemoryRegion { // TODO: check safety unsafe { self.inner_mr.as_ref().lkey } } + + /// Get pd of this memory region + #[allow(dead_code)] // used by test + pub(crate) fn pd(&self) -> &Arc { + &self.pd + } } impl Debug for RawMemoryRegion { diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index 5d237f1..ce9fd58 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -72,12 +72,24 @@ static mut RDMA_EXTENT_HOOKS: extent_hooks_t = extent_hooks_t { merge: Some(RDMA_MERGE_EXTENT_HOOK), }; +/// Combination of `ProtectionDomain`'s ptr and `ibv_access_falgs` as a key. 
+#[derive(PartialEq, Eq, Hash)] +pub(crate) struct AccessPDKey(u128); + +impl AccessPDKey { + /// Cast the `ProtectionDomain`'s ptr and `ibv_access_flags` into a `u128` + #[allow(clippy::as_conversions)] + pub(crate) fn new(pd: &Arc, access: ibv_access_flags) -> Self { + Self((u128::from(access.0)).wrapping_shl(64) | (pd.as_ptr() as u128)) + } +} + /// The map that records correspondence between extent metadata and `raw_mr` type ExtentTokenMap = Arc>>; /// The map that records correspondence between `arena_ind` and `ProtectionDomain` type ArenaPdMap = Arc, ibv_access_flags)>>>; -/// The map that records correspondence between `ibv_access_flags` and `arena_ind` -type AccessArenaMap = Arc>>; +/// The map that records correspondence between (`ibv_access_flags`, &pd) and `arena_ind` +type AccessPDArenaMap = Arc>>; lazy_static! { /// Default extent hooks of jemalloc @@ -88,7 +100,7 @@ lazy_static! { /// The correspondence between `arena_ind` and `ProtectionDomain` pub(crate) static ref ARENA_PD_MAP: ArenaPdMap = Arc::new(Mutex::new(HashMap::new())); /// The correspondence between `ibv_access_flags` and `arena_ind` - pub(crate) static ref ACCESS_ARENA_MAP: AccessArenaMap = Arc::new(Mutex::new(HashMap::new())); + pub(crate) static ref ACCESS_PD_ARENA_MAP: AccessPDArenaMap = Arc::new(Mutex::new(HashMap::new())); } /// Combination between extent metadata and `raw_mr` @@ -117,7 +129,7 @@ pub enum MRManageStrategy { #[derive(Debug)] pub(crate) struct MrAllocator { /// Protection domain that holds the allocator - pd: Arc, + default_pd: Arc, /// Default arena index default_arena_ind: u32, /// Strategy to manage `MR`s @@ -137,7 +149,7 @@ impl MrAllocator { init_je_statics(Arc::::clone(&pd), mr_attr.access) .expect("init je statics failed"); Self { - pd, + default_pd: pd, default_arena_ind: arena_ind, strategy: mr_attr.strategy, default_access: mr_attr.access, @@ -146,7 +158,7 @@ impl MrAllocator { MRManageStrategy::Raw => { debug!("new mr_allocator using raw strategy"); Self { - 
pd, + default_pd: pd, default_arena_ind: 0_u32, strategy: mr_attr.strategy, default_access: mr_attr.access, @@ -166,8 +178,9 @@ impl MrAllocator { self: &Arc, layout: &Layout, access: ibv_access_flags, + pd: &Arc, ) -> io::Result { - let inner = self.alloc_inner(layout, 0, access)?; + let inner = self.alloc_inner(layout, 0, access, pd)?; Ok(LocalMr::new(inner)) } @@ -177,9 +190,10 @@ impl MrAllocator { self: &Arc, layout: &Layout, access: ibv_access_flags, + pd: &Arc, ) -> io::Result { // SAFETY: alloc zeroed memory is safe - let inner = unsafe { self.alloc_inner(layout, MALLOCX_ZERO, access)? }; + let inner = unsafe { self.alloc_inner(layout, MALLOCX_ZERO, access, pd)? }; Ok(LocalMr::new(inner)) } @@ -191,13 +205,13 @@ impl MrAllocator { /// Initialize it before using to make it safe. #[allow(clippy::as_conversions)] pub(crate) unsafe fn alloc_default(self: &Arc, layout: &Layout) -> io::Result { - self.alloc(layout, self.default_access) + self.alloc(layout, self.default_access, &self.default_pd) } /// Allocate a `LocalMr` according to the `layout` with default access #[allow(clippy::as_conversions)] pub(crate) fn alloc_zeroed_default(self: &Arc, layout: &Layout) -> io::Result { - self.alloc_zeroed(layout, self.default_access) + self.alloc_zeroed(layout, self.default_access, &self.default_pd) } /// Allocate a `LocalMrInner` according to the `layout` @@ -211,6 +225,7 @@ impl MrAllocator { layout: &Layout, flag: i32, access: ibv_access_flags, + pd: &Arc, ) -> io::Result { // check to ensure the safety requirements of `slice` methords if layout.size() > isize::MAX.cast() { @@ -222,7 +237,7 @@ impl MrAllocator { match self.strategy { MRManageStrategy::Jemalloc => { - if access == self.default_access{ + if access == self.default_access && pd.inner_pd == self.default_pd.inner_pd { alloc_from_je(self.default_arena_ind, layout, flag).map_or_else(||{ Err(io::Error::new(io::ErrorKind::OutOfMemory, "insufficient contiguous memory was available to service the allocation 
request")) }, |addr|{ @@ -236,11 +251,12 @@ impl MrAllocator { Ok(LocalMrInner::new(addr as usize, *layout, raw_mr, self.strategy)) }) }else { - let arena_id = ACCESS_ARENA_MAP.lock().get(&access).copied(); + let access_pd_key = AccessPDKey::new(pd, access); + let arena_id = ACCESS_PD_ARENA_MAP.lock().get(&access_pd_key).copied(); let arena_id = arena_id.map_or_else(||{ // didn't use this access before, so create an arena to mange this kind of MR - let ind = init_je_statics(Arc::::clone(&self.pd), access)?; - if ACCESS_ARENA_MAP.lock().insert(access, ind).is_some(){ + let ind = init_je_statics(Arc::::clone(pd), access)?; + if ACCESS_PD_ARENA_MAP.lock().insert(access_pd_key, ind).is_some(){ return Err(io::Error::new( io::ErrorKind::Other, format!("this is a bug: insert ACCESS_ARENA_MAP failed, access:{:?} has been recorded", access), @@ -267,7 +283,7 @@ impl MrAllocator { alloc_raw_mem(layout, flag).map_or_else(||{ Err(io::Error::new(io::ErrorKind::OutOfMemory, "insufficient contiguous memory was available to service the allocation request")) }, |addr|{ - let raw_mr = Arc::new(RawMemoryRegion::register_from_pd(&self.pd, addr, layout.size(), access)?); + let raw_mr = Arc::new(RawMemoryRegion::register_from_pd(pd, addr, layout.size(), access)?); Ok(LocalMrInner::new(addr as usize, *layout, raw_mr, self.strategy)) }) }, @@ -281,12 +297,12 @@ impl MrAllocator { /// /// This func is safe when `zeroed == true`, otherwise newly allocated memory is uninitialized. 
unsafe fn alloc_from_je(arena_ind: u32, layout: &Layout, flag: i32) -> Option<*mut u8> { - let addr = { - tikv_jemalloc_sys::mallocx( - layout.size(), - (MALLOCX_ALIGN(layout.align()) | MALLOCX_ARENA(arena_ind.cast()) | flag).cast(), - ) - }; + let flags = (MALLOCX_ALIGN(layout.align()) | MALLOCX_ARENA(arena_ind.cast()) | flag).cast(); + debug!( + "alloc mr from je, arena_ind: {:?}, layout: {:?}, flags: {:?}", + arena_ind, layout, flags + ); + let addr = { tikv_jemalloc_sys::mallocx(layout.size(), flags) }; if addr.is_null() { None } else { @@ -989,4 +1005,44 @@ mod tests { let _mr_2 = rdma_2.alloc_local_mr(layout); Ok(()) } + + #[test] + fn alloc_raw_multi_pd_mr() -> io::Result<()> { + let ctx = Arc::new(Context::open(None, 1, 1)?); + let pd_1 = Arc::new(ctx.create_protection_domain()?); + let pd_2 = Arc::new(ctx.create_protection_domain()?); + let mr_attr = MRInitAttr { + access: *DEFAULT_ACCESS, + strategy: MRManageStrategy::Raw, + }; + let allocator = Arc::new(MrAllocator::new(Arc::clone(&pd_1), mr_attr)); + let layout = Layout::new::(); + let mr_1 = allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_1)?; + println!("{:?}", mr_1); + assert_eq!(&mr_1.read_inner().pd().inner_pd, &pd_1.inner_pd); + let mr_2 = allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_2)?; + println!("{:?}", mr_2); + assert_eq!(&mr_2.read_inner().pd().inner_pd, &pd_2.inner_pd); + Ok(()) + } + + #[test] + fn alloc_multi_pd_mr_from_je() -> io::Result<()> { + let ctx = Arc::new(Context::open(None, 1, 1)?); + let pd_1 = Arc::new(ctx.create_protection_domain()?); + let pd_2 = Arc::new(ctx.create_protection_domain()?); + let mr_attr = MRInitAttr { + access: *DEFAULT_ACCESS, + strategy: MRManageStrategy::Jemalloc, + }; + let allocator = Arc::new(MrAllocator::new(Arc::clone(&pd_1), mr_attr)); + let layout = Layout::new::(); + let mr_1 = allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_1)?; + println!("{:?}", mr_1); + assert_eq!(&mr_1.read_inner().pd().inner_pd, &pd_1.inner_pd); + let mr_2 
= allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_2)?; + println!("{:?}", mr_2); + assert_eq!(&mr_2.read_inner().pd().inner_pd, &pd_2.inner_pd); + Ok(()) + } } diff --git a/src/protection_domain.rs b/src/protection_domain.rs index a5d3ad1..8b37485 100644 --- a/src/protection_domain.rs +++ b/src/protection_domain.rs @@ -12,7 +12,7 @@ pub(crate) struct ProtectionDomain { /// The device context pub(crate) ctx: Arc, /// Internal `ibv_pd` pointer - inner_pd: NonNull, + pub(crate) inner_pd: NonNull, } impl ProtectionDomain { From cb19b228d8acd26d4241a04ebfd3d3576d1e7dba Mon Sep 17 00:00:00 2001 From: why Date: Thu, 25 Aug 2022 21:45:34 +0800 Subject: [PATCH 07/14] add multi-pd support for multi-connection Set new `ProtectionDomain` for new `Rdma` to provide isolation. --- src/lib.rs | 56 ++++++++++++++++++-- src/memory_region/local.rs | 6 ++- src/memory_region/raw.rs | 2 +- src/mr_allocator.rs | 102 ++++++++++++++++++++++++++++++++++--- 4 files changed, 151 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7fe1d89..78ad99e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -911,7 +911,6 @@ pub(crate) struct CloneAttr { /// Tcp listener used for new connections pub(crate) tcp_listener: Option>>, /// Clone `Rdma` with new `ProtectionDomain` - #[allow(dead_code)] // not yet fully implemented pub(crate) pd: Option>, /// Clone `Rdma` with new `Port number` pub(crate) port_num: Option, @@ -969,7 +968,6 @@ pub struct Rdma { impl Rdma { /// create a new `Rdma` instance - #[allow(clippy::too_many_arguments)] // TODO: fix with builder pattern fn new( dev_attr: &DeviceInitAttr, cq_attr: CQInitAttr, @@ -1029,7 +1027,6 @@ impl Rdma { Ok, )?; let port_num = self.clone_attr.port_num.unwrap_or(self.qp.port_num); - // TODO: add multi-pd support for allocator let pd = self .clone_attr .pd @@ -2276,7 +2273,8 @@ impl Rdma { /// ``` #[inline] pub fn alloc_local_mr(&self, layout: Layout) -> io::Result { - self.allocator.alloc_zeroed_default(&layout) + self.allocator + 
.alloc_zeroed_default_access(&layout, &self.pd) } /// Allocate a local memory region that has not been initialized @@ -2341,7 +2339,7 @@ impl Rdma { /// ``` #[inline] pub unsafe fn alloc_local_mr_uninit(&self, layout: Layout) -> io::Result { - self.allocator.alloc_default(&layout) + self.allocator.alloc_default_access(&layout, &self.pd) } /// Allocate a local memory region with specified access @@ -2995,6 +2993,54 @@ impl Rdma { self.clone_attr = self.clone_attr.set_port_num(port_num); self } + + /// Set new `ProtectionDomain` for new `Rdma` that created by `clone` to provide isolation. + /// + /// Used with `listen`, `new_connect` + /// + /// # Examples + /// + /// ``` + /// use async_rdma::RdmaBuilder; + /// use portpicker::pick_unused_port; + /// use std::{ + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// let rdma = rdma.set_new_pd()?; + /// // then the `Rdma`s created by `new_connect` will have a new `ProtectionDomain` + /// let _new_rdma = rdma.new_connect(addr).await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// let _new_rdma = rdma.listen().await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + pub fn set_new_pd(mut self) -> io::Result { + let new_pd = self.ctx.create_protection_domain()?; + self.clone_attr = self.clone_attr.set_pd(new_pd); + Ok(self) + } } /// Rdma Listener is the wrapper of a `TcpListener`, which is used to 
diff --git a/src/memory_region/local.rs b/src/memory_region/local.rs index e059359..f9e637a 100644 --- a/src/memory_region/local.rs +++ b/src/memory_region/local.rs @@ -1,9 +1,11 @@ use super::{raw::RawMemoryRegion, MrAccess, MrToken}; +#[cfg(test)] +use crate::protection_domain::ProtectionDomain; use crate::{ lock_utilities::{MappedRwLockReadGuard, MappedRwLockWriteGuard}, - protection_domain::ProtectionDomain, MRManageStrategy, }; + use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use rdma_sys::ibv_access_flags; use sealed::sealed; @@ -543,7 +545,7 @@ impl LocalMrInner { } /// Get pd of this memory region - #[allow(dead_code)] // used by test + #[cfg(test)] pub(crate) fn pd(&self) -> &Arc { self.raw.pd() } diff --git a/src/memory_region/raw.rs b/src/memory_region/raw.rs index db0e0b0..1ef5a3b 100644 --- a/src/memory_region/raw.rs +++ b/src/memory_region/raw.rs @@ -78,7 +78,7 @@ impl RawMemoryRegion { } /// Get pd of this memory region - #[allow(dead_code)] // used by test + #[cfg(test)] pub(crate) fn pd(&self) -> &Arc { &self.pd } diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index ce9fd58..76932a5 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -197,6 +197,12 @@ impl MrAllocator { Ok(LocalMr::new(inner)) } + /// Allocate a `LocalMr` according to the `layout` with default access and pd + #[allow(clippy::as_conversions)] + pub(crate) fn alloc_zeroed_default(self: &Arc, layout: &Layout) -> io::Result { + self.alloc_zeroed(layout, self.default_access, &self.default_pd) + } + /// Allocate an uninitialized `LocalMr` according to the `layout` with default access /// /// # Safety @@ -204,14 +210,22 @@ impl MrAllocator { /// The newly allocated memory in this `LocalMr` is uninitialized. /// Initialize it before using to make it safe. 
#[allow(clippy::as_conversions)] - pub(crate) unsafe fn alloc_default(self: &Arc, layout: &Layout) -> io::Result { - self.alloc(layout, self.default_access, &self.default_pd) + pub(crate) unsafe fn alloc_default_access( + self: &Arc, + layout: &Layout, + pd: &Arc, + ) -> io::Result { + self.alloc(layout, self.default_access, pd) } /// Allocate a `LocalMr` according to the `layout` with default access #[allow(clippy::as_conversions)] - pub(crate) fn alloc_zeroed_default(self: &Arc, layout: &Layout) -> io::Result { - self.alloc_zeroed(layout, self.default_access, &self.default_pd) + pub(crate) fn alloc_zeroed_default_access( + self: &Arc, + layout: &Layout, + pd: &Arc, + ) -> io::Result { + self.alloc_zeroed(layout, self.default_access, pd) } /// Allocate a `LocalMrInner` according to the `layout` @@ -777,7 +791,6 @@ mod tests { #[test] #[allow(clippy::unwrap_used)] fn test_extent_hooks() -> io::Result<()> { - tracing_subscriber::fmt::init(); let ctx = Arc::new(Context::open(None, 1, 1)?); let pd = Arc::new(ctx.create_protection_domain()?); let allocator = Arc::new(MrAllocator::new(pd, MRInitAttr::default())); @@ -1038,11 +1051,86 @@ mod tests { let allocator = Arc::new(MrAllocator::new(Arc::clone(&pd_1), mr_attr)); let layout = Layout::new::(); let mr_1 = allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_1)?; - println!("{:?}", mr_1); assert_eq!(&mr_1.read_inner().pd().inner_pd, &pd_1.inner_pd); let mr_2 = allocator.alloc_zeroed(&layout, *DEFAULT_ACCESS, &pd_2)?; - println!("{:?}", mr_2); assert_eq!(&mr_2.read_inner().pd().inner_pd, &pd_2.inner_pd); Ok(()) } } + +/// Test `LocalMr` APIs with multi-pd & multi-connection +#[cfg(test)] +mod mr_with_multi_pd_test { + use crate::AccessFlag; + use crate::{LocalMrReadAccess, RdmaBuilder}; + use portpicker::pick_unused_port; + use std::{ + alloc::Layout, + io, + net::{Ipv4Addr, SocketAddrV4}, + time::Duration, + }; + + async fn client(addr: SocketAddrV4) -> io::Result<()> { + let rdma = 
RdmaBuilder::default().connect(addr).await?; + let rdma = rdma.set_new_pd()?; + let layout = Layout::new::(); + // then the `Rdma`s created by `new_connect` will have a new `ProtectionDomain` + let new_rdma = rdma.new_connect(addr).await?; + + let mr_1 = rdma.alloc_local_mr(layout)?; + let mr_2 = new_rdma.alloc_local_mr(layout)?; + assert_ne!( + &mr_1.read_inner().pd().inner_pd, + &mr_2.read_inner().pd().inner_pd + ); + + // SAFETY: test without memory access + let mr_3 = unsafe { rdma.alloc_local_mr_uninit(layout)? }; + // SAFETY: test without memory access + let mr_4 = unsafe { new_rdma.alloc_local_mr_uninit(layout)? }; + assert_ne!( + &mr_3.read_inner().pd().inner_pd, + &mr_4.read_inner().pd().inner_pd + ); + + let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead | AccessFlag::RemoteWrite; + let mr_5 = rdma.alloc_local_mr_with_access(layout, access)?; + let mr_6 = new_rdma.alloc_local_mr_with_access(layout, access)?; + assert_ne!( + &mr_5.read_inner().pd().inner_pd, + &mr_6.read_inner().pd().inner_pd + ); + + // SAFETY: test without memory access + let mr_7 = unsafe { rdma.alloc_local_mr_uninit_with_access(layout, access)? }; + // SAFETY: test without memory access + let mr_8 = unsafe { new_rdma.alloc_local_mr_uninit_with_access(layout, access)? 
}; + assert_ne!( + &mr_7.read_inner().pd().inner_pd, + &mr_8.read_inner().pd().inner_pd + ); + + Ok(()) + } + + #[tokio::main] + async fn server(addr: SocketAddrV4) -> io::Result<()> { + let rdma = RdmaBuilder::default().listen(addr).await?; + let _new_rdma = rdma.listen().await?; + Ok(()) + } + + #[tokio::test] + #[allow(clippy::unwrap_used)] + async fn main() { + let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + let server_handle = std::thread::spawn(move || server(addr)); + tokio::time::sleep(Duration::from_secs(3)).await; + client(addr) + .await + .map_err(|err| println!("{}", err)) + .unwrap(); + server_handle.join().unwrap().unwrap(); + } +} From 6661857bc25e1f836edc15ef35edb30739d43656 Mon Sep 17 00:00:00 2001 From: why Date: Tue, 30 Aug 2022 22:56:18 +0800 Subject: [PATCH 08/14] add `RemoteMr` access control We can use `set_max_rmr_access` to set the maximum permission on `RemoteMr` that the remote end can request. --- src/agent.rs | 29 ++++++--- src/lib.rs | 113 ++++++++++++++++++++++++++++++------ src/memory_region/local.rs | 22 +++---- src/memory_region/mod.rs | 15 ++++- src/memory_region/raw.rs | 2 +- src/memory_region/remote.rs | 30 +++++----- src/mr_allocator.rs | 8 +-- src/rmr_manager.rs | 11 +++- tests/rmr_access.rs | 77 ++++++++++++++++++++++++ 9 files changed, 247 insertions(+), 60 deletions(-) create mode 100644 tests/rmr_access.rs diff --git a/src/agent.rs b/src/agent.rs index 7651628..c9ffd8a 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -13,6 +13,7 @@ use crate::{ queue_pair::QueuePair, }; use clippy_utilities::Cast; +use rdma_sys::ibv_access_flags; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::SystemTime; @@ -59,8 +60,6 @@ pub(crate) struct Agent { /// Agent thread resource #[allow(dead_code)] agent_thread: Arc, - /// Max message length - pub(crate) max_sr_data_len: usize, } impl Drop for Agent { @@ -77,9 +76,10 @@ impl Agent { qp: Arc, allocator: Arc, 
max_sr_data_len: usize, + max_rmr_access: ibv_access_flags, ) -> io::Result { let response_waits = Arc::new(parking_lot::Mutex::new(HashMap::new())); - let rmr_manager = RemoteMrManager::new(); + let rmr_manager = RemoteMrManager::new(Arc::clone(&qp.pd), max_rmr_access); let (local_mr_send, local_mr_recv) = channel(1024); let (remote_mr_send, remote_mr_recv) = channel(1024); let (data_send, data_recv) = channel(1024); @@ -112,7 +112,6 @@ impl Agent { imm_recv, handles, agent_thread, - max_sr_data_len, }) } @@ -209,7 +208,7 @@ impl Agent { let mut start = 0; let lm_len = lm.length(); while start < lm_len { - let end = (start.saturating_add(self.max_sr_data_len)).min(lm_len); + let end = (start.saturating_add(self.max_msg_len())).min(lm_len); let kind = RequestKind::SendData(SendDataRequest { len: end.wrapping_sub(start), }); @@ -266,6 +265,16 @@ impl Agent { .await .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "imm data channel closed")) } + + /// Get the max length of message for send/recv + pub(crate) fn max_msg_len(&self) -> usize { + self.inner.max_sr_data_len + } + + /// Get the max access permission for remote mr requests + pub(crate) fn max_rmr_access(&self) -> ibv_access_flags { + self.inner.rmr_manager.max_rmr_access + } } /// Agent thread data structure, actually it spawn a task on the tokio thread pool @@ -426,9 +435,11 @@ impl AgentThread { let response = match request.kind { RequestKind::AllocMR(param) => { // TODO: error handling - let mr = self.inner.allocator.alloc_zeroed_default( + let mr = self.inner.allocator.alloc_zeroed( &Layout::from_size_align(param.size, param.align) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, + self.inner.rmr_manager.max_rmr_access, + &self.inner.rmr_manager.pd, )?; // SAFETY: no date race here let token = unsafe { mr.token_with_timeout_unchecked(param.timeout) }.map_or_else( @@ -718,7 +729,7 @@ lazy_static! { }; } /// Used for checking if `header_buf` is clean. 
-const CLEAN_STATE: [u8; 52] = [0_u8; 52]; +const CLEAN_STATE: [u8; 56] = [0_u8; 56]; lazy_static! { static ref REQUEST_HEADER_MAX_LEN: usize = { @@ -734,6 +745,7 @@ lazy_static! { len: 0, rkey: 0, ddl: SystemTime::now(), + access:0, }, }), RequestKind::SendMR(SendMRRequest { @@ -742,6 +754,7 @@ lazy_static! { len: 0, rkey: 0, ddl: SystemTime::now(), + access:0, }), }), RequestKind::SendMR(SendMRRequest { @@ -750,6 +763,7 @@ lazy_static! { len: 0, rkey: 0, ddl: SystemTime::now(), + access:0, }), }), RequestKind::SendData(SendDataRequest { len: 0 }), @@ -781,6 +795,7 @@ lazy_static! { len: 0, rkey: 0, ddl: SystemTime::now(), + access:0, }, }), ResponseKind::ReleaseMR(ReleaseMRResponse { status: 0 }), diff --git a/src/lib.rs b/src/lib.rs index 78ad99e..b85d301 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,7 +198,7 @@ extern crate lazy_static; /// A wrapper for ibv_access_flag, hide the ibv binding types #[bitflags] #[repr(u64)] -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum AccessFlag { /// local write permission LocalWrite, @@ -234,7 +234,7 @@ pub enum AccessFlag { /// let layout = Layout::new::<[u8; 4096]>(); /// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; /// let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); -/// assert_eq!(mr.access(), flags_into_ibv_access(access)); +/// assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); /// } /// /// ``` @@ -272,6 +272,60 @@ pub fn flags_into_ibv_access(flags: BitFlags) -> ibv_access_flags { ret } +/// Convert `ibv_access_flags` into `BitFlags` +/// +/// # Example +/// +/// ``` +/// use async_rdma::{ibv_access_into_flags, AccessFlag, MrAccess, RdmaBuilder}; +/// use std::alloc::Layout; +/// +/// #[tokio::main] +/// async fn main() { +/// let rdma = RdmaBuilder::default().build().unwrap(); +/// let layout = Layout::new::<[u8; 4096]>(); +/// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; +/// let mr = 
rdma.alloc_local_mr_with_access(layout, access).unwrap(); +/// assert_eq!(access, ibv_access_into_flags(mr.ibv_access())); +/// } +/// +/// ``` +#[inline] +#[must_use] +pub fn ibv_access_into_flags(access: ibv_access_flags) -> BitFlags { + let mut ret = BitFlags::::empty(); + if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { + ret |= AccessFlag::LocalWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { + ret |= AccessFlag::LocalWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_READ).0 != 0 { + ret |= AccessFlag::RemoteRead; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_WRITE).0 != 0 { + ret |= AccessFlag::RemoteWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC).0 != 0 { + ret |= AccessFlag::RemoteAtomic; + } + if (access & ibv_access_flags::IBV_ACCESS_MW_BIND).0 != 0 { + ret |= AccessFlag::MwBind; + } + if (access & ibv_access_flags::IBV_ACCESS_ZERO_BASED).0 != 0 { + ret |= AccessFlag::ZeroBased; + } + if (access & ibv_access_flags::IBV_ACCESS_ON_DEMAND).0 != 0 { + ret |= AccessFlag::OnDemand; + } + if (access & ibv_access_flags::IBV_ACCESS_HUGETLB).0 != 0 { + ret |= AccessFlag::HugeTlb; + } + if (access & ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING).0 != 0 { + ret |= AccessFlag::RelaxOrder; + } + ret +} /// initial device attributes #[derive(Debug)] pub struct DeviceInitAttr { @@ -380,6 +434,8 @@ impl Default for MRInitAttr { pub struct AgentInitAttr { /// Max length of message send/recv by Agent max_message_length: usize, + /// Max access permission for remote mr requests + max_rmr_access: ibv_access_flags, } impl Default for AgentInitAttr { @@ -387,6 +443,7 @@ impl Default for AgentInitAttr { fn default() -> Self { Self { max_message_length: MAX_MSG_LEN, + max_rmr_access: *DEFAULT_ACCESS, } } } @@ -471,7 +528,11 @@ impl RdmaBuilder { let mut rdma = self.build()?; let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; - 
rdma.init_agent(self.agent_attr.max_message_length).await?; + rdma.init_agent( + self.agent_attr.max_message_length, + self.agent_attr.max_rmr_access, + ) + .await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -579,9 +640,10 @@ impl RdmaBuilder { )), ConnectionType::RCCM => { let max_message_length = self.agent_attr.max_message_length; + let max_rmr_access = self.agent_attr.max_rmr_access; let mut rdma = self.build()?; cm_connect_helper(&mut rdma, node, service)?; - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, max_rmr_access).await?; Ok(rdma) } } @@ -632,7 +694,11 @@ impl RdmaBuilder { let remote = tcp_listen(&tcp_listener, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; debug!("handshake done"); - rdma.init_agent(self.agent_attr.max_message_length).await?; + rdma.init_agent( + self.agent_attr.max_message_length, + self.agent_attr.max_rmr_access, + ) + .await?; rdma.clone_attr = CloneAttr::default().set_tcp_listener(tcp_listener); Ok(rdma) } @@ -754,6 +820,14 @@ impl RdmaBuilder { self.agent_attr.max_message_length = max_msg_len; self } + + /// Set max access permission for remote mr requests + #[inline] + #[must_use] + pub fn set_max_rmr_access(mut self, flags: BitFlags) -> Self { + self.agent_attr.max_rmr_access = flags_into_ibv_access(flags); + self + } } impl Debug for RdmaBuilder { @@ -1076,12 +1150,17 @@ impl Rdma { } /// Agent init helper - async fn init_agent(&mut self, max_message_length: usize) -> io::Result<()> { + async fn init_agent( + &mut self, + max_message_length: usize, + max_rmr_access: ibv_access_flags, + ) -> io::Result<()> { if !self.raw { let agent = Arc::new(Agent::new( Arc::::clone(&self.qp), Arc::::clone(&self.allocator), max_message_length, + max_rmr_access, )?); self.agent = Some(agent); // wait for the remote agent to prepare @@ -1154,13 +1233,13 @@ impl Rdma { rdma.qp_handshake(remote)?; debug!("handshake done"); #[allow(clippy::unreachable)] - let max_message_length = 
self.agent.as_ref().map_or_else( + let (max_message_length, max_rmr_access) = self.agent.as_ref().map_or_else( || { unreachable!("agent of parent rdma is None"); }, - |agent| agent.max_sr_data_len, + |agent| (agent.max_msg_len(), agent.max_rmr_access()), ); - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, max_rmr_access).await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -1220,13 +1299,13 @@ impl Rdma { let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; #[allow(clippy::unreachable)] - let max_message_length = self.agent.as_ref().map_or_else( + let (max_message_length, max_rmr_access) = self.agent.as_ref().map_or_else( || { unreachable!("agent of parent rdma is None"); }, - |agent| agent.max_sr_data_len, + |agent| (agent.max_msg_len(), agent.max_rmr_access()), ); - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, max_rmr_access).await?; Ok(rdma) } ConnectionType::RCCM => Err(io::Error::new( @@ -2100,7 +2179,7 @@ impl Rdma { ); let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?; rdma.qp_handshake(remote)?; - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?; // wait for server to initialize tokio::time::sleep(Duration::from_secs(1)).await; Ok(rdma) @@ -2210,7 +2289,7 @@ impl Rdma { "should set connection type to RCSocket" ); cm_connect_helper(&mut rdma, node, service)?; - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?; // wait for server to initialize tokio::time::sleep(Duration::from_secs(1)).await; Ok(rdma) @@ -2361,7 +2440,7 @@ impl Rdma { /// let layout = Layout::new::<[u8; 4096]>(); /// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; /// let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); - /// assert_eq!(mr.access(), flags_into_ibv_access(access)); + /// assert_eq!(mr.access(), access); 
/// } /// /// ``` @@ -2402,7 +2481,7 @@ impl Rdma { /// rdma.alloc_local_mr_uninit_with_access(layout, access) /// .unwrap() /// }; - /// assert_eq!(mr.access(), flags_into_ibv_access(access)); + /// assert_eq!(mr.access(), access); /// } /// /// ``` @@ -3166,7 +3245,7 @@ impl RdmaListener { stream.write_all(&local).await?; rdma.qp_handshake(remote)?; debug!("handshake done"); - rdma.init_agent(max_message_length).await?; + rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?; Ok(rdma) } } diff --git a/src/memory_region/local.rs b/src/memory_region/local.rs index f9e637a..8467121 100644 --- a/src/memory_region/local.rs +++ b/src/memory_region/local.rs @@ -154,6 +154,7 @@ pub unsafe trait LocalMrReadAccess: MrAccess { len: self.length(), rkey: self.rkey(), ddl, + access: self.ibv_access().0, }) }, ) @@ -178,6 +179,7 @@ pub unsafe trait LocalMrReadAccess: MrAccess { len: self.length(), rkey: self.rkey_unchecked(), ddl, + access: self.ibv_access().0, }) }, ) @@ -326,8 +328,8 @@ impl MrAccess for LocalMr { } #[inline] - fn access(&self) -> ibv_access_flags { - self.read_inner().access() + fn ibv_access(&self) -> ibv_access_flags { + self.read_inner().ibv_access() } } @@ -518,8 +520,8 @@ impl MrAccess for LocalMrInner { } #[inline] - fn access(&self) -> ibv_access_flags { - self.raw.access() + fn ibv_access(&self) -> ibv_access_flags { + self.raw.ibv_access() } } @@ -568,8 +570,8 @@ impl MrAccess for &LocalMr { } #[inline] - fn access(&self) -> ibv_access_flags { - self.read_inner().access() + fn ibv_access(&self) -> ibv_access_flags { + self.read_inner().ibv_access() } } @@ -616,8 +618,8 @@ impl MrAccess for LocalMrSlice<'_> { } #[inline] - fn access(&self) -> ibv_access_flags { - self.read_inner().access() + fn ibv_access(&self) -> ibv_access_flags { + self.read_inner().ibv_access() } } @@ -697,8 +699,8 @@ impl MrAccess for LocalMrSliceMut<'_> { } #[inline] - fn access(&self) -> ibv_access_flags { - self.read_inner().access() + fn ibv_access(&self) -> 
ibv_access_flags { + self.read_inner().ibv_access() } } diff --git a/src/memory_region/mod.rs b/src/memory_region/mod.rs index 6138ecf..6ab593d 100644 --- a/src/memory_region/mod.rs +++ b/src/memory_region/mod.rs @@ -4,11 +4,14 @@ pub(crate) mod local; mod raw; /// Remote Memory Region pub(crate) mod remote; +use enumflags2::BitFlags; pub(crate) use raw::RawMemoryRegion; use rdma_sys::ibv_access_flags; use serde::{Deserialize, Serialize}; use std::{fmt::Debug, time::SystemTime}; +use crate::{ibv_access_into_flags, AccessFlag}; + /// Rdma Memory Region Access pub trait MrAccess: Sync + Send + Debug { /// Get the start addr @@ -20,8 +23,14 @@ pub trait MrAccess: Sync + Send + Debug { /// Get the remote key fn rkey(&self) -> u32; - /// Get the access of mr - fn access(&self) -> ibv_access_flags; + /// Get the `ibv_access_flags` of mr + fn ibv_access(&self) -> ibv_access_flags; + + /// Get the `BitFlags` of mr + #[inline] + fn access(&self) -> BitFlags { + ibv_access_into_flags(self.ibv_access()) + } } /// Memory region token used for the remote access @@ -35,4 +44,6 @@ pub struct MrToken { pub rkey: u32, /// Deadline for timeout pub ddl: SystemTime, + /// Remote mr `ibv_access_flags` inner + pub access: u32, } diff --git a/src/memory_region/raw.rs b/src/memory_region/raw.rs index 1ef5a3b..3417612 100644 --- a/src/memory_region/raw.rs +++ b/src/memory_region/raw.rs @@ -37,7 +37,7 @@ impl MrAccess for RawMemoryRegion { unsafe { self.inner_mr.as_ref().rkey } } - fn access(&self) -> ibv_access_flags { + fn ibv_access(&self) -> ibv_access_flags { self.access } } diff --git a/src/memory_region/remote.rs b/src/memory_region/remote.rs index 5385513..79634c7 100644 --- a/src/memory_region/remote.rs +++ b/src/memory_region/remote.rs @@ -1,7 +1,7 @@ use rdma_sys::ibv_access_flags; use super::{MrAccess, MrToken}; -use crate::{agent::AgentInner, DEFAULT_ACCESS}; +use crate::agent::AgentInner; use std::{io, ops::Range, sync::Arc, time::SystemTime}; /// Remote Memory Region Accrss 
@@ -45,9 +45,8 @@ impl MrAccess for RemoteMr { } #[inline] - fn access(&self) -> ibv_access_flags { - // TODO: add access control for rmr - *DEFAULT_ACCESS + fn ibv_access(&self) -> ibv_access_flags { + ibv_access_flags(self.token.access) } } @@ -88,6 +87,7 @@ impl RemoteMr { len: i.end.wrapping_sub(i.start), rkey: self.rkey(), ddl: self.token.ddl, + access: self.ibv_access().0, }; Ok(RemoteMrSlice::new_from_token(self, slice_token)) } @@ -105,6 +105,7 @@ impl RemoteMr { len: i.end.wrapping_sub(i.start), rkey: self.rkey(), ddl: self.token.ddl, + access: self.ibv_access().0, }; Ok(RemoteMrSliceMut::new_from_token(self, slice_token)) } @@ -128,12 +129,10 @@ impl MrAccess for &RemoteMr { } #[inline] - fn access(&self) -> ibv_access_flags { - // TODO: add access control for rmr - *DEFAULT_ACCESS + fn ibv_access(&self) -> ibv_access_flags { + ibv_access_flags(self.token.access) } } - impl RemoteMrReadAccess for &RemoteMr { #[inline] fn token(&self) -> MrToken { @@ -158,9 +157,8 @@ impl MrAccess for &mut RemoteMr { } #[inline] - fn access(&self) -> ibv_access_flags { - // TODO: add access control for rmr - *DEFAULT_ACCESS + fn ibv_access(&self) -> ibv_access_flags { + ibv_access_flags(self.token.access) } } impl RemoteMrReadAccess for &mut RemoteMr { @@ -207,9 +205,8 @@ impl MrAccess for RemoteMrSlice<'_> { } #[inline] - fn access(&self) -> ibv_access_flags { - // TODO: add access control for rmr - *DEFAULT_ACCESS + fn ibv_access(&self) -> ibv_access_flags { + ibv_access_flags(self.token.access) } } @@ -246,9 +243,8 @@ impl MrAccess for RemoteMrSliceMut<'_> { } #[inline] - fn access(&self) -> ibv_access_flags { - // TODO: add access control for rmr - *DEFAULT_ACCESS + fn ibv_access(&self) -> ibv_access_flags { + ibv_access_flags(self.token.access) } } diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index 76932a5..18fcef8 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -942,8 +942,8 @@ mod tests { let layout = Layout::new::<[u8; 4096]>(); let access = 
AccessFlag::LocalWrite | AccessFlag::RemoteRead; let mr = rdma.alloc_local_mr_with_access(layout, access)?; - assert_eq!(mr.access(), flags_into_ibv_access(access)); - assert_ne!(mr.access(), *DEFAULT_ACCESS); + assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); + assert_ne!(mr.ibv_access(), *DEFAULT_ACCESS); Ok(()) } @@ -953,8 +953,8 @@ mod tests { let layout = Layout::new::<[u8; 4096]>(); let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; let mr = rdma.alloc_local_mr_with_access(layout, access)?; - assert_eq!(mr.access(), flags_into_ibv_access(access)); - assert_ne!(mr.access(), *DEFAULT_ACCESS); + assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); + assert_ne!(mr.ibv_access(), *DEFAULT_ACCESS); Ok(()) } diff --git a/src/rmr_manager.rs b/src/rmr_manager.rs index b5ab9d1..1af36a2 100644 --- a/src/rmr_manager.rs +++ b/src/rmr_manager.rs @@ -1,5 +1,6 @@ -use crate::{memory_region::MrToken, LocalMr}; +use crate::{memory_region::MrToken, protection_domain::ProtectionDomain, LocalMr}; use parking_lot::Mutex; +use rdma_sys::ibv_access_flags; use std::{collections::HashMap, io, sync::Arc, time::Duration}; use tokio::{ sync::{mpsc, oneshot}, @@ -27,11 +28,15 @@ pub(crate) struct RemoteMrManager { timer_tx: mpsc::Sender, /// `RemoteMrManager` task handler handler: JoinHandle<()>, + /// `Protection Domain` of remote mr + pub(crate) pd: Arc, + /// Max access permission for remote mr requests + pub(crate) max_rmr_access: ibv_access_flags, } impl RemoteMrManager { /// New a `RemoteMrManager` - pub(crate) fn new() -> Self { + pub(crate) fn new(pd: Arc, max_rmr_access: ibv_access_flags) -> Self { let mr_own = Arc::new(Mutex::new(HashMap::new())); let (timer_tx, timer_rx) = mpsc::channel::(TIMER_CHANNEL_BUF_SIZE); let handler = tokio::task::spawn(timeout_monitor(timer_rx, RmrMap::clone(&mr_own))); @@ -39,6 +44,8 @@ impl RemoteMrManager { mr_map: mr_own, timer_tx, handler, + pd, + max_rmr_access, } } diff --git a/tests/rmr_access.rs 
b/tests/rmr_access.rs new file mode 100644 index 0000000..9f6011e --- /dev/null +++ b/tests/rmr_access.rs @@ -0,0 +1,77 @@ +mod test_utilities; +use async_rdma::{LocalMrReadAccess, Rdma}; +use std::{alloc::Layout, io, time::Duration}; +use test_utilities::test_server_client; +use tokio::time::sleep; + +mod send_local_mr { + use async_rdma::AccessFlag; + + use super::*; + + async fn server(rdma: Rdma) -> io::Result<()> { + let lmr = rdma.alloc_local_mr(Layout::new::())?; + let mut rmr = rdma.receive_remote_mr().await?; + // wrong mr access, should panic + rdma.write(&lmr, &mut rmr).await.unwrap(); + dbg!(unsafe { *(*lmr.as_ptr() as *const char) }); + Ok(()) + } + + async fn client(rdma: Rdma) -> io::Result<()> { + let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + let lmr = rdma.alloc_local_mr_with_access(Layout::new::(), access)?; + rdma.send_local_mr(lmr).await?; + // wait for panic + sleep(Duration::from_secs(3)).await; + Ok(()) + } + + #[should_panic] + #[test] + fn main() { + test_server_client(server, client); + } +} + +mod request_remote_mr { + use std::net::{Ipv4Addr, SocketAddrV4}; + + use async_rdma::{AccessFlag, RdmaBuilder}; + use portpicker::pick_unused_port; + + use super::*; + async fn client(addr: SocketAddrV4) -> io::Result<()> { + let rdma = RdmaBuilder::default().connect(addr).await?; + let lmr = rdma.alloc_local_mr(Layout::new::())?; + let mut rmr = rdma.request_remote_mr(Layout::new::()).await?; + // wrong mr access, should panic + rdma.write(&lmr, &mut rmr).await.unwrap(); + rdma.send_remote_mr(rmr).await?; + Ok(()) + } + /// + #[tokio::main] + async fn server(addr: SocketAddrV4) -> io::Result<()> { + let flags = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + let rdma = RdmaBuilder::default() + .set_max_rmr_access(flags) + .listen(addr) + .await?; + rdma.receive_local_mr().await?; + Ok(()) + } + + #[should_panic] + #[tokio::main] + #[test] + async fn main() { + let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 
pick_unused_port().unwrap()); + std::thread::spawn(move || server(addr)); + tokio::time::sleep(Duration::from_secs(3)).await; + client(addr) + .await + .map_err(|err| println!("{}", err)) + .unwrap(); + } +} From ae21577138360c2da0a1f59167756cd033ffb3e1 Mon Sep 17 00:00:00 2001 From: why Date: Wed, 31 Aug 2022 21:06:00 +0800 Subject: [PATCH 09/14] suppot new max remote mr access for cloned `Rdma` Set new max access permission for remote mr requests for new `Rdma` by `set_new_max_rmr_access` --- src/lib.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index b85d301..3d4946e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -990,6 +990,8 @@ pub(crate) struct CloneAttr { pub(crate) port_num: Option, /// Clone `Rdma` with new qp access pub(crate) qp_access: Option, + /// Clone `Rdma` with new max access permission for remote mr requests + pub(crate) max_rmr_access: Option, } impl CloneAttr { @@ -1017,6 +1019,12 @@ impl CloneAttr { self.qp_access = Some(access); self } + + /// Set max access permission for remote mr requests with new agent + fn set_max_rmr_access(mut self, access: ibv_access_flags) -> Self { + self.max_rmr_access = Some(access); + self + } } /// Rdma handler, the only interface that the users deal with rdma @@ -1239,6 +1247,10 @@ impl Rdma { }, |agent| (agent.max_msg_len(), agent.max_rmr_access()), ); + let max_rmr_access = self + .clone_attr + .max_rmr_access + .map_or(max_rmr_access, |new_access| new_access); rdma.init_agent(max_message_length, max_rmr_access).await?; Ok(rdma) } @@ -1305,6 +1317,10 @@ impl Rdma { }, |agent| (agent.max_msg_len(), agent.max_rmr_access()), ); + let max_rmr_access = self + .clone_attr + .max_rmr_access + .map_or(max_rmr_access, |new_access| new_access); rdma.init_agent(max_message_length, max_rmr_access).await?; Ok(rdma) } @@ -3026,6 +3042,67 @@ impl Rdma { self } + /// Set max access permission for remote mr requests for new `Rdma` that created 
by `clone` + /// + /// Used with `listen`, `new_connect` + /// + /// # Examples + /// + /// ``` + /// use async_rdma::{AccessFlag, RdmaBuilder, MrAccess}; + /// use portpicker::pick_unused_port; + /// use std::{ + /// alloc::Layout, + /// io, + /// net::{Ipv4Addr, SocketAddrV4}, + /// time::Duration, + /// }; + /// + /// async fn client(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().connect(addr).await?; + /// let rmr = rdma.request_remote_mr(Layout::new::()).await?; + /// let new_rdma = rdma.new_connect(addr).await?; + /// let new_rmr = new_rdma.request_remote_mr(Layout::new::()).await?; + /// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + /// assert_eq!(new_rmr.access(), access); + /// assert_ne!(rmr.access(), new_rmr.access()); + /// new_rdma.send_remote_mr(new_rmr).await?; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn server(addr: SocketAddrV4) -> io::Result<()> { + /// let rdma = RdmaBuilder::default().listen(addr).await?; + /// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + /// let rdma = rdma.set_new_max_rmr_access(access); + /// let new_rdma = rdma.listen().await?; + /// // receive the metadata of the lmr that had been requested by client + /// let _lmr = new_rdma.receive_local_mr().await?; + /// // wait for the agent thread to send all reponses to the remote. 
+ /// tokio::time::sleep(Duration::from_secs(1)).await; + /// Ok(()) + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), pick_unused_port().unwrap()); + /// std::thread::spawn(move || server(addr)); + /// tokio::time::sleep(Duration::from_secs(3)).await; + /// client(addr) + /// .await + /// .map_err(|err| println!("{}", err)) + /// .unwrap(); + /// } + /// ``` + #[inline] + #[must_use] + pub fn set_new_max_rmr_access(mut self, max_rmr_access: BitFlags) -> Self { + self.clone_attr = self + .clone_attr + .set_max_rmr_access(flags_into_ibv_access(max_rmr_access)); + self + } + /// Set qp access for new `Rdma` that created by `clone` /// /// Used with `listen`, `new_connect` From 320afd3ed3900a61560b7d0f579f064a2b45090a Mon Sep 17 00:00:00 2001 From: why Date: Thu, 1 Sep 2022 13:34:25 +0800 Subject: [PATCH 10/14] udpate `AccessPDKey` --- src/mr_allocator.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index 18fcef8..22fb1c4 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -72,15 +72,15 @@ static mut RDMA_EXTENT_HOOKS: extent_hooks_t = extent_hooks_t { merge: Some(RDMA_MERGE_EXTENT_HOOK), }; -/// Combination of `ProtectionDomain`'s ptr and `ibv_access_falgs` as a key. +/// The tuple of `ibv_access_falgs` and `ProtectionDomain` as a key. 
#[derive(PartialEq, Eq, Hash)] -pub(crate) struct AccessPDKey(u128); +pub(crate) struct AccessPDKey(u32, usize); impl AccessPDKey { - /// Cast the `ProtectionDomain`'s ptr and `ibv_access_flags` into a `u128` + /// Create a new `AccessPDKey` #[allow(clippy::as_conversions)] pub(crate) fn new(pd: &Arc, access: ibv_access_flags) -> Self { - Self((u128::from(access.0)).wrapping_shl(64) | (pd.as_ptr() as u128)) + Self(access.0, pd.as_ptr() as usize) } } From b525adcde9c5a2724a2ebd057e099ddffa74b6aa Mon Sep 17 00:00:00 2001 From: why Date: Mon, 5 Sep 2022 17:13:13 +0800 Subject: [PATCH 11/14] share some code to deduplicate --- src/mr_allocator.rs | 46 +++++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index 22fb1c4..d061ed2 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -251,23 +251,12 @@ impl MrAllocator { match self.strategy { MRManageStrategy::Jemalloc => { - if access == self.default_access && pd.inner_pd == self.default_pd.inner_pd { - alloc_from_je(self.default_arena_ind, layout, flag).map_or_else(||{ - Err(io::Error::new(io::ErrorKind::OutOfMemory, "insufficient contiguous memory was available to service the allocation request")) - }, |addr|{ - #[allow(clippy::unreachable)] - let raw_mr = lookup_raw_mr(self.default_arena_ind, addr as usize).map_or_else( - || { - unreachable!("can not find raw mr by addr {}", addr as usize); - }, - |raw_mr| raw_mr, - ); - Ok(LocalMrInner::new(addr as usize, *layout, raw_mr, self.strategy)) - }) + let arena_id = if access == self.default_access && pd.inner_pd == self.default_pd.inner_pd { + self.default_arena_ind }else { let access_pd_key = AccessPDKey::new(pd, access); let arena_id = ACCESS_PD_ARENA_MAP.lock().get(&access_pd_key).copied(); - let arena_id = arena_id.map_or_else(||{ + arena_id.map_or_else(||{ // didn't use this access before, so create an arena to mange this kind of MR let ind = 
init_je_statics(Arc::::clone(pd), access)?; if ACCESS_PD_ARENA_MAP.lock().insert(access_pd_key, ind).is_some(){ @@ -277,21 +266,20 @@ impl MrAllocator { )) }; Ok(ind) - }, |ind|{Ok(ind)})?; - - alloc_from_je(arena_id, layout, flag).map_or_else(||{ - Err(io::Error::new(io::ErrorKind::OutOfMemory, "insufficient contiguous memory was available to service the allocation request")) - }, |addr|{ - #[allow(clippy::unreachable)] - let raw_mr = lookup_raw_mr(arena_id, addr as usize).map_or_else( - || { - unreachable!("can not find raw mr with arena_id: {} by addr: {}",arena_id, addr as usize); - }, - |raw_mr| raw_mr, - ); - Ok(LocalMrInner::new(addr as usize, *layout, raw_mr, self.strategy)) - }) - } + }, |ind|{Ok(ind)})? + }; + alloc_from_je(arena_id, layout, flag).map_or_else(||{ + Err(io::Error::new(io::ErrorKind::OutOfMemory, "insufficient contiguous memory was available to service the allocation request")) + }, |addr|{ + #[allow(clippy::unreachable)] + let raw_mr = lookup_raw_mr(arena_id, addr as usize).map_or_else( + || { + unreachable!("can not find raw mr with arena_id: {} by addr: {}",arena_id, addr as usize); + }, + |raw_mr| raw_mr, + ); + Ok(LocalMrInner::new(addr as usize, *layout, raw_mr, self.strategy)) + }) }, MRManageStrategy::Raw => { alloc_raw_mem(layout, flag).map_or_else(||{ From 308aa88426f0544bc9da4b03b9087c2769048ee1 Mon Sep 17 00:00:00 2001 From: why Date: Tue, 6 Sep 2022 13:34:55 +0800 Subject: [PATCH 12/14] cleanup the code about access Place access in a separate file and hide `ibv_access_flags` for users. Users should not touch the underlying details about `ibv_access_flags`. Because it may change with `rdma-sys` bindings. 
--- src/access.rs | 127 ++++++++++++++++++++++++++++++++ src/lib.rs | 141 ++---------------------------------- src/memory_region/local.rs | 12 ++- src/memory_region/mod.rs | 13 ++-- src/memory_region/raw.rs | 5 +- src/memory_region/remote.rs | 17 ++++- src/mr_allocator.rs | 10 +-- 7 files changed, 176 insertions(+), 149 deletions(-) create mode 100644 src/access.rs diff --git a/src/access.rs b/src/access.rs new file mode 100644 index 0000000..7746afd --- /dev/null +++ b/src/access.rs @@ -0,0 +1,127 @@ +use enumflags2::{bitflags, BitFlags}; +use rdma_sys::ibv_access_flags; + +/// A wrapper for `ibv_access_flag`, hide the ibv binding types +#[bitflags] +#[repr(u64)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum AccessFlag { + /// local write permission + LocalWrite, + /// remote write permission + RemoteWrite, + /// remote read permission + RemoteRead, + /// remote atomic operation permission + RemoteAtomic, + /// enable memory window binding + MwBind, + /// use byte offset from beginning of MR to access this MR, instead of a pointer address + ZeroBased, + /// create an on-demand paging MR + OnDemand, + /// huge pages are guaranteed to be used for this MR, only used with `OnDemand` + HugeTlb, + /// allow system to reorder accesses to the MR to improve performance + RelaxOrder, +} + +/// Convert `BitFlags` into `ibv_access_flags` +#[inline] +#[must_use] +pub(crate) fn flags_into_ibv_access(flags: BitFlags) -> ibv_access_flags { + let mut ret = ibv_access_flags(0); + if flags.contains(AccessFlag::LocalWrite) { + ret |= ibv_access_flags::IBV_ACCESS_LOCAL_WRITE; + } + if flags.contains(AccessFlag::RemoteWrite) { + ret |= ibv_access_flags::IBV_ACCESS_REMOTE_WRITE; + } + if flags.contains(AccessFlag::RemoteRead) { + ret |= ibv_access_flags::IBV_ACCESS_REMOTE_READ; + } + if flags.contains(AccessFlag::RemoteAtomic) { + ret |= ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC; + } + if flags.contains(AccessFlag::MwBind) { + ret |= ibv_access_flags::IBV_ACCESS_MW_BIND; + } 
+ if flags.contains(AccessFlag::ZeroBased) { + ret |= ibv_access_flags::IBV_ACCESS_ZERO_BASED; + } + if flags.contains(AccessFlag::OnDemand) { + ret |= ibv_access_flags::IBV_ACCESS_ON_DEMAND; + } + if flags.contains(AccessFlag::HugeTlb) { + ret |= ibv_access_flags::IBV_ACCESS_HUGETLB; + } + if flags.contains(AccessFlag::RelaxOrder) { + ret |= ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING; + } + ret +} + +/// Convert `ibv_access_flags` into `BitFlags` +#[inline] +#[must_use] +pub(crate) fn ibv_access_into_flags(access: ibv_access_flags) -> BitFlags { + let mut ret = BitFlags::::empty(); + if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { + ret |= AccessFlag::LocalWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { + ret |= AccessFlag::LocalWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_READ).0 != 0 { + ret |= AccessFlag::RemoteRead; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_WRITE).0 != 0 { + ret |= AccessFlag::RemoteWrite; + } + if (access & ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC).0 != 0 { + ret |= AccessFlag::RemoteAtomic; + } + if (access & ibv_access_flags::IBV_ACCESS_MW_BIND).0 != 0 { + ret |= AccessFlag::MwBind; + } + if (access & ibv_access_flags::IBV_ACCESS_ZERO_BASED).0 != 0 { + ret |= AccessFlag::ZeroBased; + } + if (access & ibv_access_flags::IBV_ACCESS_ON_DEMAND).0 != 0 { + ret |= AccessFlag::OnDemand; + } + if (access & ibv_access_flags::IBV_ACCESS_HUGETLB).0 != 0 { + ret |= AccessFlag::HugeTlb; + } + if (access & ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING).0 != 0 { + ret |= AccessFlag::RelaxOrder; + } + ret +} + +#[cfg(test)] +mod access_test { + use super::*; + use crate::{memory_region::IbvAccess, RdmaBuilder}; + use std::alloc::Layout; + + #[tokio::test] + #[allow(clippy::unwrap_used)] + async fn flags_into_ibv_access_test() { + let rdma = RdmaBuilder::default().build().unwrap(); + let layout = Layout::new::<[u8; 4096]>(); + let access = AccessFlag::LocalWrite | 
AccessFlag::RemoteRead; + let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); + assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); + } + + #[tokio::test] + #[allow(clippy::unwrap_used)] + async fn ibv_access_into_flags_test() { + let rdma = RdmaBuilder::default().build().unwrap(); + let layout = Layout::new::<[u8; 4096]>(); + let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; + let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); + assert_eq!(access, ibv_access_into_flags(mr.ibv_access())); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3d4946e..d61be95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,6 +128,8 @@ mod context; /// The rmda device pub mod device; +/// Access of `QP` and `MR` +mod access; /// Error handling utilities mod error_utilities; /// The event channel that notifies the completion or error of a request @@ -157,11 +159,13 @@ mod rmr_manager; /// Work Request wrapper mod work_request; +use access::flags_into_ibv_access; +pub use access::AccessFlag; use agent::{Agent, MAX_MSG_LEN}; use clippy_utilities::Cast; use completion_queue::{DEFAULT_CQ_SIZE, DEFAULT_MAX_CQE}; use context::Context; -use enumflags2::{bitflags, BitFlags}; +use enumflags2::BitFlags; use error_utilities::log_ret_last_os_err; use event_listener::EventListener; pub use memory_region::{ @@ -195,137 +199,6 @@ use tracing::debug; #[macro_use] extern crate lazy_static; -/// A wrapper for ibv_access_flag, hide the ibv binding types -#[bitflags] -#[repr(u64)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum AccessFlag { - /// local write permission - LocalWrite, - /// remote write permission - RemoteWrite, - /// remote read permission - RemoteRead, - /// remote atomic operation permission - RemoteAtomic, - /// enable memory window binding - MwBind, - /// use byte offset from beginning of MR to access this MR, instead of a pointer address - ZeroBased, - /// create an on-demand paging MR - OnDemand, - /// huge pages are guaranteed 
to be used for this MR, only used with `OnDemand` - HugeTlb, - /// allow system to reorder accesses to the MR to improve performance - RelaxOrder, -} - -/// Convert `BitFlags` into `ibv_access_flags` -/// -/// # Example -/// -/// ``` -/// use async_rdma::{flags_into_ibv_access, AccessFlag, MrAccess, RdmaBuilder}; -/// use std::alloc::Layout; -/// -/// #[tokio::main] -/// async fn main() { -/// let rdma = RdmaBuilder::default().build().unwrap(); -/// let layout = Layout::new::<[u8; 4096]>(); -/// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; -/// let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); -/// assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); -/// } -/// -/// ``` -#[inline] -#[must_use] -pub fn flags_into_ibv_access(flags: BitFlags) -> ibv_access_flags { - let mut ret = ibv_access_flags(0); - if flags.contains(AccessFlag::LocalWrite) { - ret |= ibv_access_flags::IBV_ACCESS_LOCAL_WRITE; - } - if flags.contains(AccessFlag::RemoteWrite) { - ret |= ibv_access_flags::IBV_ACCESS_REMOTE_WRITE; - } - if flags.contains(AccessFlag::RemoteRead) { - ret |= ibv_access_flags::IBV_ACCESS_REMOTE_READ; - } - if flags.contains(AccessFlag::RemoteAtomic) { - ret |= ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC; - } - if flags.contains(AccessFlag::MwBind) { - ret |= ibv_access_flags::IBV_ACCESS_MW_BIND; - } - if flags.contains(AccessFlag::ZeroBased) { - ret |= ibv_access_flags::IBV_ACCESS_ZERO_BASED; - } - if flags.contains(AccessFlag::OnDemand) { - ret |= ibv_access_flags::IBV_ACCESS_ON_DEMAND; - } - if flags.contains(AccessFlag::HugeTlb) { - ret |= ibv_access_flags::IBV_ACCESS_HUGETLB; - } - if flags.contains(AccessFlag::RelaxOrder) { - ret |= ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING; - } - ret -} - -/// Convert `ibv_access_flags` into `BitFlags` -/// -/// # Example -/// -/// ``` -/// use async_rdma::{ibv_access_into_flags, AccessFlag, MrAccess, RdmaBuilder}; -/// use std::alloc::Layout; -/// -/// #[tokio::main] -/// async fn main() 
{ -/// let rdma = RdmaBuilder::default().build().unwrap(); -/// let layout = Layout::new::<[u8; 4096]>(); -/// let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; -/// let mr = rdma.alloc_local_mr_with_access(layout, access).unwrap(); -/// assert_eq!(access, ibv_access_into_flags(mr.ibv_access())); -/// } -/// -/// ``` -#[inline] -#[must_use] -pub fn ibv_access_into_flags(access: ibv_access_flags) -> BitFlags { - let mut ret = BitFlags::::empty(); - if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { - ret |= AccessFlag::LocalWrite; - } - if (access & ibv_access_flags::IBV_ACCESS_LOCAL_WRITE).0 != 0 { - ret |= AccessFlag::LocalWrite; - } - if (access & ibv_access_flags::IBV_ACCESS_REMOTE_READ).0 != 0 { - ret |= AccessFlag::RemoteRead; - } - if (access & ibv_access_flags::IBV_ACCESS_REMOTE_WRITE).0 != 0 { - ret |= AccessFlag::RemoteWrite; - } - if (access & ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC).0 != 0 { - ret |= AccessFlag::RemoteAtomic; - } - if (access & ibv_access_flags::IBV_ACCESS_MW_BIND).0 != 0 { - ret |= AccessFlag::MwBind; - } - if (access & ibv_access_flags::IBV_ACCESS_ZERO_BASED).0 != 0 { - ret |= AccessFlag::ZeroBased; - } - if (access & ibv_access_flags::IBV_ACCESS_ON_DEMAND).0 != 0 { - ret |= AccessFlag::OnDemand; - } - if (access & ibv_access_flags::IBV_ACCESS_HUGETLB).0 != 0 { - ret |= AccessFlag::HugeTlb; - } - if (access & ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING).0 != 0 { - ret |= AccessFlag::RelaxOrder; - } - ret -} /// initial device attributes #[derive(Debug)] pub struct DeviceInitAttr { @@ -2447,7 +2320,7 @@ impl Rdma { /// # Example /// /// ``` - /// use async_rdma::{flags_into_ibv_access, AccessFlag, MrAccess, RdmaBuilder}; + /// use async_rdma::{AccessFlag, MrAccess, RdmaBuilder}; /// use std::alloc::Layout; /// /// #[tokio::main] @@ -2485,7 +2358,7 @@ impl Rdma { /// # Example /// /// ``` - /// use async_rdma::{flags_into_ibv_access, AccessFlag, MrAccess, RdmaBuilder}; + /// use async_rdma::{AccessFlag, 
MrAccess, RdmaBuilder}; /// use std::alloc::Layout; /// /// #[tokio::main] diff --git a/src/memory_region/local.rs b/src/memory_region/local.rs index 8467121..4d45342 100644 --- a/src/memory_region/local.rs +++ b/src/memory_region/local.rs @@ -1,4 +1,4 @@ -use super::{raw::RawMemoryRegion, MrAccess, MrToken}; +use super::{raw::RawMemoryRegion, IbvAccess, MrAccess, MrToken}; #[cfg(test)] use crate::protection_domain::ProtectionDomain; use crate::{ @@ -326,7 +326,9 @@ impl MrAccess for LocalMr { fn rkey(&self) -> u32 { self.read_inner().rkey() } +} +impl IbvAccess for LocalMr { #[inline] fn ibv_access(&self) -> ibv_access_flags { self.read_inner().ibv_access() @@ -518,7 +520,9 @@ impl MrAccess for LocalMrInner { fn rkey(&self) -> u32 { self.raw.rkey() } +} +impl IbvAccess for LocalMrInner { #[inline] fn ibv_access(&self) -> ibv_access_flags { self.raw.ibv_access() @@ -568,7 +572,9 @@ impl MrAccess for &LocalMr { fn rkey(&self) -> u32 { self.read_inner().rkey() } +} +impl IbvAccess for &LocalMr { #[inline] fn ibv_access(&self) -> ibv_access_flags { self.read_inner().ibv_access() @@ -616,7 +622,9 @@ impl MrAccess for LocalMrSlice<'_> { fn rkey(&self) -> u32 { self.lmr.rkey() } +} +impl IbvAccess for LocalMrSlice<'_> { #[inline] fn ibv_access(&self) -> ibv_access_flags { self.read_inner().ibv_access() @@ -697,7 +705,9 @@ impl MrAccess for LocalMrSliceMut<'_> { fn rkey(&self) -> u32 { self.lmr.rkey() } +} +impl IbvAccess for LocalMrSliceMut<'_> { #[inline] fn ibv_access(&self) -> ibv_access_flags { self.read_inner().ibv_access() diff --git a/src/memory_region/mod.rs b/src/memory_region/mod.rs index 6ab593d..48881ab 100644 --- a/src/memory_region/mod.rs +++ b/src/memory_region/mod.rs @@ -10,10 +10,10 @@ use rdma_sys::ibv_access_flags; use serde::{Deserialize, Serialize}; use std::{fmt::Debug, time::SystemTime}; -use crate::{ibv_access_into_flags, AccessFlag}; +use crate::{access::ibv_access_into_flags, AccessFlag}; /// Rdma Memory Region Access -pub trait MrAccess: Sync + 
Send + Debug { +pub trait MrAccess: Sync + Send + Debug + IbvAccess { /// Get the start addr fn addr(&self) -> usize; @@ -23,9 +23,6 @@ pub trait MrAccess: Sync + Send + Debug { /// Get the remote key fn rkey(&self) -> u32; - /// Get the `ibv_access_flags` of mr - fn ibv_access(&self) -> ibv_access_flags; - /// Get the `BitFlags` of mr #[inline] fn access(&self) -> BitFlags { @@ -33,6 +30,12 @@ pub trait MrAccess: Sync + Send + Debug { } } +/// Get `ibv_access_flags` of mr +pub trait IbvAccess { + /// Get `ibv_access_flags` of mr + fn ibv_access(&self) -> ibv_access_flags; +} + /// Memory region token used for the remote access #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Copy, Debug)] pub struct MrToken { diff --git a/src/memory_region/raw.rs b/src/memory_region/raw.rs index 3417612..47572c2 100644 --- a/src/memory_region/raw.rs +++ b/src/memory_region/raw.rs @@ -1,4 +1,4 @@ -use super::MrAccess; +use super::{IbvAccess, MrAccess}; use crate::protection_domain::ProtectionDomain; use clippy_utilities::Cast; use rdma_sys::{ibv_access_flags, ibv_dereg_mr, ibv_mr, ibv_reg_mr}; @@ -36,7 +36,10 @@ impl MrAccess for RawMemoryRegion { // TODO: check safety unsafe { self.inner_mr.as_ref().rkey } } +} +impl IbvAccess for RawMemoryRegion { + #[inline] fn ibv_access(&self) -> ibv_access_flags { self.access } diff --git a/src/memory_region/remote.rs b/src/memory_region/remote.rs index 79634c7..042fea6 100644 --- a/src/memory_region/remote.rs +++ b/src/memory_region/remote.rs @@ -1,7 +1,6 @@ -use rdma_sys::ibv_access_flags; - -use super::{MrAccess, MrToken}; +use super::{IbvAccess, MrAccess, MrToken}; use crate::agent::AgentInner; +use rdma_sys::ibv_access_flags; use std::{io, ops::Range, sync::Arc, time::SystemTime}; /// Remote Memory Region Accrss @@ -43,7 +42,9 @@ impl MrAccess for RemoteMr { fn rkey(&self) -> u32 { self.token.rkey } +} +impl IbvAccess for RemoteMr { #[inline] fn ibv_access(&self) -> ibv_access_flags { ibv_access_flags(self.token.access) @@ 
-127,12 +128,15 @@ impl MrAccess for &RemoteMr { fn rkey(&self) -> u32 { self.token.rkey } +} +impl IbvAccess for &RemoteMr { #[inline] fn ibv_access(&self) -> ibv_access_flags { ibv_access_flags(self.token.access) } } + impl RemoteMrReadAccess for &RemoteMr { #[inline] fn token(&self) -> MrToken { @@ -155,12 +159,15 @@ impl MrAccess for &mut RemoteMr { fn rkey(&self) -> u32 { self.token.rkey } +} +impl IbvAccess for &mut RemoteMr { #[inline] fn ibv_access(&self) -> ibv_access_flags { ibv_access_flags(self.token.access) } } + impl RemoteMrReadAccess for &mut RemoteMr { #[inline] fn token(&self) -> MrToken { @@ -203,7 +210,9 @@ impl MrAccess for RemoteMrSlice<'_> { fn rkey(&self) -> u32 { self.token.rkey } +} +impl IbvAccess for RemoteMrSlice<'_> { #[inline] fn ibv_access(&self) -> ibv_access_flags { ibv_access_flags(self.token.access) @@ -241,7 +250,9 @@ impl MrAccess for RemoteMrSliceMut<'_> { fn rkey(&self) -> u32 { self.token.rkey } +} +impl IbvAccess for RemoteMrSliceMut<'_> { #[inline] fn ibv_access(&self) -> ibv_access_flags { ibv_access_flags(self.token.access) diff --git a/src/mr_allocator.rs b/src/mr_allocator.rs index d061ed2..69b95c5 100644 --- a/src/mr_allocator.rs +++ b/src/mr_allocator.rs @@ -662,7 +662,7 @@ pub(crate) fn register_extent_mr(addr: *mut c_void, size: usize, arena_ind: u32) mod tests { use super::*; use crate::{ - context::Context, flags_into_ibv_access, AccessFlag, LocalMrReadAccess, MrAccess, + access::ibv_access_into_flags, context::Context, AccessFlag, LocalMrReadAccess, MrAccess, RdmaBuilder, DEFAULT_ACCESS, }; use std::{alloc::Layout, io, thread}; @@ -930,8 +930,8 @@ mod tests { let layout = Layout::new::<[u8; 4096]>(); let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; let mr = rdma.alloc_local_mr_with_access(layout, access)?; - assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); - assert_ne!(mr.ibv_access(), *DEFAULT_ACCESS); + assert_eq!(mr.access(), access); + assert_ne!(mr.access(), 
ibv_access_into_flags(*DEFAULT_ACCESS)); Ok(()) } @@ -941,8 +941,8 @@ mod tests { let layout = Layout::new::<[u8; 4096]>(); let access = AccessFlag::LocalWrite | AccessFlag::RemoteRead; let mr = rdma.alloc_local_mr_with_access(layout, access)?; - assert_eq!(mr.ibv_access(), flags_into_ibv_access(access)); - assert_ne!(mr.ibv_access(), *DEFAULT_ACCESS); + assert_eq!(mr.access(), access); + assert_ne!(mr.access(), ibv_access_into_flags(*DEFAULT_ACCESS)); Ok(()) } From 5969fc3724e01a84940b4c085f3d83e2520dd946 Mon Sep 17 00:00:00 2001 From: why Date: Tue, 6 Sep 2022 16:58:32 +0800 Subject: [PATCH 13/14] update README with new APIs and safe code --- README.md | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 5cbd1fd..349f22d 100644 --- a/README.md +++ b/README.md @@ -122,36 +122,40 @@ And finally client `send_mr` to make server aware of this memory region. Server `receive_local_mr`, and then get data from this mr. ```rust -use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaListener}; +use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, RdmaBuilder}; use portpicker::pick_unused_port; use std::{ alloc::Layout, - io, + io::{self, Write}, net::{Ipv4Addr, SocketAddrV4}, time::Duration, }; -struct Data(String); - async fn client(addr: SocketAddrV4) -> io::Result<()> { - let rdma = Rdma::connect(addr, 1, 1, 512).await?; - let mut lmr = rdma.alloc_local_mr(Layout::new::())?; - let mut rmr = rdma.request_remote_mr(Layout::new::()).await?; - // then send this mr to server to make server aware of this mr. 
- unsafe { *(*lmr.as_mut_ptr() as *mut Data) = Data("hello world".to_string()) }; - rdma.write(&lmr, &mut rmr).await?; - // send the content of lmr to server + let layout = Layout::new::<[u8; 8]>(); + let rdma = RdmaBuilder::default().connect(addr).await?; + // alloc 8 bytes remote memory + let mut rmr = rdma.request_remote_mr(layout).await?; + // alloc 8 bytes local memory + let mut lmr = rdma.alloc_local_mr(layout)?; + // write data into lmr + let _num = lmr.as_mut_slice().write(&[1_u8; 8])?; + // write the second half of the data in lmr to the rmr + rdma.write(&lmr.get(4..8).unwrap(), &mut rmr.get_mut(4..8).unwrap()) + .await?; + // send rmr's meta data to the remote end rdma.send_remote_mr(rmr).await?; Ok(()) } #[tokio::main] async fn server(addr: SocketAddrV4) -> io::Result<()> { - let rdma_listener = RdmaListener::bind(addr).await?; - let rdma = rdma_listener.accept(1, 1, 512).await?; + let rdma = RdmaBuilder::default().listen(addr).await?; + // receive mr's meta data from client let lmr = rdma.receive_local_mr().await?; - // print the content of lmr, which was `write` by client - unsafe { println!("{}", &*(*(*lmr.as_ptr() as *const Data)).0) }; + let data = *lmr.as_slice(); + println!("Data written by the client using RDMA WRITE: {:?}", data); + assert_eq!(data, [[0_u8; 4], [1_u8; 4]].concat()); Ok(()) } @@ -165,6 +169,7 @@ async fn main() { .map_err(|err| println!("{}", err)) .unwrap(); } + ``` ## Getting Help From 2133d580c2f02c8d230ae73da479dba4b1ed6ede Mon Sep 17 00:00:00 2001 From: why Date: Tue, 6 Sep 2022 19:38:41 +0800 Subject: [PATCH 14/14] update version to 0.4.0 --- Cargo.toml | 2 +- ChangeLog.md | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e1a0d82..c01ef9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ keywords = ["rdma", "network"] license-file = "LICENSE" name = "async-rdma" repository = "https://github.com/datenlord/async-rdma" -version = 
"0.3.0" +version = "0.4.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diff --git a/ChangeLog.md b/ChangeLog.md index 0e418e6..b319ae8 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -1,5 +1,44 @@ # ChangeLog +## 0.4.0 + + In this release, we add some APIs to make it easier to establish connections and + provide more configurable options. + + We provide support for sending/receiving raw data and allocating raw memory regions + to enable users to implement the upper layer protocol or memory pool by themselves. + + APIs for supporting multi-connection can be used to easily establish multiple connections + with more than one remote end and reuse resources. + +### New features + +* Adapt RDMA CM APIs. Add `cm_connect` to establish a connection with the CM server. + Add `{send/recv}_raw` APIs to send/recv raw data without the help of the agent. +* Add APIs to set attributes of queue pair like `set_max_{send/recv}_{sge/wr}`. +* Add APIs for `RdmaBuilder` to establish connections conveniently. +* Add APIs to alloc raw mr. Sometimes users want to set up a memory pool by themselves + instead of using `Jemalloc` to manage mrs. So we define two strategies. +* Add APIs to alloc mrs with different accesses and protection domains from `Jemalloc`. + `MrAllocator` will create a new arena when we alloc an mr with a non-default access or pd. +* Add access API for mrs to query access. +* Support multi-connection APIs. Add APIs for `Rdma` to create a new `Rdma` that has the + same `mr_allocator` and `event_listener` as its parent. +* Add `RemoteMr` access control. Add `set_max_rmr_access` API to set the maximum permission on + `RemoteMr` that the remote end can request. + +### Optimizations and refactors + +* Use a submodule to set up the CI environment. +* Reorganize some attributes to avoid too many arguments. +* Update examples. Replace unsafe blocks and add more comments. Show more APIs.
+ +### Bug fixes + +* Disable `tcache` of `Jemalloc` by default. `tcache` is a feature of `Jemalloc` to speed up + memory allocation. However, `Jemalloc` may alloc an `MR` with a wrong `arena_index` from `tcache` + when we create more than one `Jemalloc`-enabled `mr_allocator`. So we disable `tcache` by default. + ## 0.3.0 In this release, we adapted `Jemalloc` to manage RDMA memory region to improve memory