diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 37e8106d8..eeffe042e 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -13,11 +13,16 @@ pub struct UnixListener { } impl UnixListener { - /// Creates a new `UnixListener` bound to the specified socket. + /// Creates a new `UnixListener` bound to the specified socket `path`. pub fn bind>(path: P) -> io::Result { sys::uds::listener::bind(path.as_ref()).map(UnixListener::from_std) } + /// Creates a new `UnixListener` bound to the specified socket `address`. + pub fn bind_addr(address: &SocketAddr) -> io::Result { + sys::uds::listener::bind_addr(address).map(UnixListener::from_std) + } + /// Creates a new `UnixListener` from a standard `net::UnixListener`. /// /// This function is intended to be used to wrap a Unix listener from the diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index b38812e5d..1c17d84a1 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -1,4 +1,5 @@ use crate::io_source::IoSource; +use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; use std::fmt; @@ -22,6 +23,14 @@ impl UnixStream { sys::uds::stream::connect(path.as_ref()).map(UnixStream::from_std) } + /// Connects to the socket named by `address`. + /// + /// This may return a `WouldBlock` in which case the socket connection + /// cannot be completed immediately. Usually it means the backlog is full. + pub fn connect_addr(address: &SocketAddr) -> io::Result { + sys::uds::stream::connect_addr(address).map(UnixStream::from_std) + } + /// Creates a new `UnixStream` from a standard `net::UnixStream`. /// /// This function is intended to be used to wrap a Unix stream from the diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index c18aca042..bac547b03 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -42,6 +42,10 @@ pub(crate) mod listener { os_required!() } + pub(crate) fn bind_addr(_: &SocketAddr) -> io::Result { + os_required!() + } + pub(crate) fn accept(_: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { os_required!() } @@ -61,6 +65,10 @@ pub(crate) mod stream { os_required!() } + pub(crate) fn connect_addr(_: &SocketAddr) -> io::Result { + os_required!() + } + pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { os_required!() } diff --git a/src/sys/unix/uds/datagram.rs b/src/sys/unix/uds/datagram.rs index a5ada72ef..5853a4d50 100644 --- a/src/sys/unix/uds/datagram.rs +++ b/src/sys/unix/uds/datagram.rs @@ -2,12 +2,13 @@ use super::{socket_addr, SocketAddr}; use crate::sys::unix::net::new_socket; use std::io; +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; use std::path::Path; pub(crate) fn bind(path: &Path) -> io::Result { - let (sockaddr, socklen) = socket_addr(path)?; + let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?; let sockaddr = &sockaddr as *const libc::sockaddr_un as *const _; let socket = unbound()?; diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index adacf63d9..2a070773d 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,18 +1,27 @@ use super::socket_addr; use crate::net::{SocketAddr, UnixStream}; use crate::sys::unix::net::new_socket; +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; use std::path::Path; use std::{io, mem}; pub(crate) fn bind(path: &Path) -> io::Result { - let (sockaddr, socklen) = socket_addr(path)?; - let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr; + let socket_address = { + let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?; + SocketAddr::from_parts(sockaddr, socklen) + }; + + bind_addr(&socket_address) +} +pub(crate) fn bind_addr(address: &SocketAddr) -> io::Result { let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?; let socket = unsafe { net::UnixListener::from_raw_fd(fd) }; - syscall!(bind(fd, sockaddr, socklen))?; + let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr; + + syscall!(bind(fd, sockaddr, *address.raw_socklen()))?; syscall!(listen(fd, 1024))?; Ok(socket) diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 7ffe498c1..2f8a898fa 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -15,16 +15,14 @@ pub(in crate::sys) fn path_offset(sockaddr: &libc::sockaddr_un) -> usize { cfg_os_poll! { use std::cmp::Ordering; - use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{RawFd, FromRawFd}; - use std::path::Path; use std::{io, mem}; pub(crate) mod datagram; pub(crate) mod listener; pub(crate) mod stream; - pub(in crate::sys) fn socket_addr(path: &Path) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> { + pub(in crate::sys) fn socket_addr(bytes: &[u8]) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> { let sockaddr = mem::MaybeUninit::::zeroed(); // This is safe to assume because a `libc::sockaddr_un` filled with `0` @@ -39,7 +37,6 @@ cfg_os_poll! { sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t; - let bytes = path.as_os_str().as_bytes(); match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) { // Abstract paths don't need a null terminator (Some(&0), Ordering::Greater) => { @@ -128,6 +125,7 @@ cfg_os_poll! { #[cfg(test)] mod tests { use super::{path_offset, socket_addr}; + use std::os::unix::ffi::OsStrExt; use std::path::Path; use std::str; @@ -139,7 +137,7 @@ cfg_os_poll! { // Pathname addresses do have a null terminator, so `socklen` is // expected to be `PATH_LEN` + `offset` + 1. let path = Path::new(PATH); - let (sockaddr, actual) = socket_addr(path).unwrap(); + let (sockaddr, actual) = socket_addr(path.as_os_str().as_bytes()).unwrap(); let offset = path_offset(&sockaddr); let expected = PATH_LEN + offset + 1; assert_eq!(expected as libc::socklen_t, actual) @@ -152,9 +150,7 @@ cfg_os_poll! { // Abstract addresses do not have a null terminator, so `socklen` is // expected to be `PATH_LEN` + `offset`. - let abstract_path = str::from_utf8(PATH).unwrap(); - let path = Path::new(abstract_path); - let (sockaddr, actual) = socket_addr(path).unwrap(); + let (sockaddr, actual) = socket_addr(PATH).unwrap(); let offset = path_offset(&sockaddr); let expected = PATH_LEN + offset; assert_eq!(expected as libc::socklen_t, actual) diff --git a/src/sys/unix/uds/socketaddr.rs b/src/sys/unix/uds/socketaddr.rs index 4c7c41161..8e0ef53a4 100644 --- a/src/sys/unix/uds/socketaddr.rs +++ b/src/sys/unix/uds/socketaddr.rs @@ -73,6 +73,14 @@ cfg_os_poll! { SocketAddr { sockaddr, socklen } } + pub(crate) fn raw_sockaddr(&self) -> &libc::sockaddr_un { + &self.sockaddr + } + + pub(crate) fn raw_socklen(&self) -> &libc::socklen_t { + &self.socklen + } + /// Returns `true` if the address is unnamed. /// /// Documentation reflected in [`SocketAddr`] diff --git a/src/sys/unix/uds/stream.rs b/src/sys/unix/uds/stream.rs index 461917c12..261d27b0d 100644 --- a/src/sys/unix/uds/stream.rs +++ b/src/sys/unix/uds/stream.rs @@ -2,17 +2,26 @@ use super::{socket_addr, SocketAddr}; use crate::sys::unix::net::new_socket; use std::io; +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; use std::path::Path; pub(crate) fn connect(path: &Path) -> io::Result { - let (sockaddr, socklen) = socket_addr(path)?; - let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr; + let socket_address = { + let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?; + SocketAddr::from_parts(sockaddr, socklen) + }; + connect_addr(&socket_address) +} + +pub(crate) fn connect_addr(address: &SocketAddr) -> io::Result { let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?; let socket = unsafe { net::UnixStream::from_raw_fd(fd) }; - match syscall!(connect(fd, sockaddr, socklen)) { + let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr; + + match syscall!(connect(fd, sockaddr, *address.raw_socklen())) { Ok(_) => {} Err(ref err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {} Err(e) => return Err(e), diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index babf18df9..5270bdc91 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -77,6 +77,55 @@ fn unix_stream_connect() { handle.join().unwrap(); } +#[test] +fn unix_stream_connect_addr() { + let (mut poll, mut events) = init_with_poll(); + let barrier = Arc::new(Barrier::new(2)); + let local_addr = { + // Workaround through a temporary listener using the same address, + // as there is currently no way of directly building a `SocketAddr`. + let path = temp_file("unix_stream_connect_addr"); + let listener = net::UnixListener::bind(path.clone()).unwrap(); + let mio_listener = mio::net::UnixListener::from_std(listener); + let address = mio_listener.local_addr().unwrap(); + drop(mio_listener); + _ = std::fs::remove_file(&path); + address + }; + let path = temp_file("unix_stream_connect_addr"); + let listener = net::UnixListener::bind(path).unwrap(); + let mut stream = UnixStream::connect_addr(&local_addr).unwrap(); + + let barrier_clone = barrier.clone(); + let handle = thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + barrier_clone.wait(); + drop(stream); + }); + + poll.registry() + .register( + &mut stream, + TOKEN_1, + Interest::READABLE | Interest::WRITABLE, + ) + .unwrap(); + expect_events( + &mut poll, + &mut events, + vec![ExpectEvent::new(TOKEN_1, Interest::WRITABLE)], + ); + + barrier.wait(); + expect_events( + &mut poll, + &mut events, + vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], + ); + + handle.join().unwrap(); +} + #[test] fn unix_stream_from_std() { smoke_test(