Skip to content

Commit

Permalink
net/uds: add methods to connect/bind with a socket address
Browse files Browse the repository at this point in the history
This enhances `connect` and `bind` logic for Unix Domain Sockets
(UDS), by adding methods which allow to directly use a socket
address. This mirrors similar features which already exist in
`mio` for TCP and UDP sockets.
  • Loading branch information
lucab committed Aug 5, 2023
1 parent 236fc31 commit f251907
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 16 deletions.
7 changes: 6 additions & 1 deletion src/net/uds/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ pub struct UnixListener {
}

impl UnixListener {
/// Creates a new `UnixListener` bound to the specified socket.
/// Creates a new `UnixListener` bound to the specified socket `path`.
pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<UnixListener> {
sys::uds::listener::bind(path.as_ref()).map(UnixListener::from_std)
}

/// Creates a new `UnixListener` bound to the specified socket `address`.
pub fn bind_addr(address: &SocketAddr) -> io::Result<UnixListener> {
sys::uds::listener::bind_addr(address).map(UnixListener::from_std)
}

/// Creates a new `UnixListener` from a standard `net::UnixListener`.
///
/// This function is intended to be used to wrap a Unix listener from the
Expand Down
9 changes: 9 additions & 0 deletions src/net/uds/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::io_source::IoSource;
use crate::net::SocketAddr;
use crate::{event, sys, Interest, Registry, Token};

use std::fmt;
Expand All @@ -22,6 +23,14 @@ impl UnixStream {
sys::uds::stream::connect(path.as_ref()).map(UnixStream::from_std)
}

/// Connects to the socket named by `address`.
///
/// This may return a `WouldBlock` in which case the socket connection
/// cannot be completed immediately. Usually it means the backlog is full.
pub fn connect_addr(address: &SocketAddr) -> io::Result<UnixStream> {
sys::uds::stream::connect_addr(address).map(UnixStream::from_std)
}

/// Creates a new `UnixStream` from a standard `net::UnixStream`.
///
/// This function is intended to be used to wrap a Unix stream from the
Expand Down
8 changes: 8 additions & 0 deletions src/sys/shell/uds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub(crate) mod listener {
os_required!()
}

pub(crate) fn bind_addr(_: &SocketAddr) -> io::Result<net::UnixListener> {
os_required!()
}

pub(crate) fn accept(_: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> {
os_required!()
}
Expand All @@ -61,6 +65,10 @@ pub(crate) mod stream {
os_required!()
}

pub(crate) fn connect_addr(_: &SocketAddr) -> io::Result<net::UnixStream> {
os_required!()
}

pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> {
os_required!()
}
Expand Down
3 changes: 2 additions & 1 deletion src/sys/unix/uds/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use super::{socket_addr, SocketAddr};
use crate::sys::unix::net::new_socket;

use std::io;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixDatagram> {
let (sockaddr, socklen) = socket_addr(path)?;
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const _;

let socket = unbound()?;
Expand Down
15 changes: 12 additions & 3 deletions src/sys/unix/uds/listener.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
use super::socket_addr;
use crate::net::{SocketAddr, UnixStream};
use crate::sys::unix::net::new_socket;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;
use std::{io, mem};

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixListener> {
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;
let socket_address = {
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
SocketAddr::from_parts(sockaddr, socklen)
};

bind_addr(&socket_address)
}

pub(crate) fn bind_addr(address: &SocketAddr) -> io::Result<net::UnixListener> {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixListener::from_raw_fd(fd) };
syscall!(bind(fd, sockaddr, socklen))?;
let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr;

syscall!(bind(fd, sockaddr, *address.raw_socklen()))?;
syscall!(listen(fd, 1024))?;

Ok(socket)
Expand Down
12 changes: 4 additions & 8 deletions src/sys/unix/uds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@ pub(in crate::sys) fn path_offset(sockaddr: &libc::sockaddr_un) -> usize {

cfg_os_poll! {
use std::cmp::Ordering;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{RawFd, FromRawFd};
use std::path::Path;
use std::{io, mem};

pub(crate) mod datagram;
pub(crate) mod listener;
pub(crate) mod stream;

pub(in crate::sys) fn socket_addr(path: &Path) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
pub(in crate::sys) fn socket_addr(bytes: &[u8]) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
let sockaddr = mem::MaybeUninit::<libc::sockaddr_un>::zeroed();

// This is safe to assume because a `libc::sockaddr_un` filled with `0`
Expand All @@ -39,7 +37,6 @@ cfg_os_poll! {

sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t;

let bytes = path.as_os_str().as_bytes();
match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) {
// Abstract paths don't need a null terminator
(Some(&0), Ordering::Greater) => {
Expand Down Expand Up @@ -128,6 +125,7 @@ cfg_os_poll! {
#[cfg(test)]
mod tests {
use super::{path_offset, socket_addr};
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use std::str;

Expand All @@ -139,7 +137,7 @@ cfg_os_poll! {
// Pathname addresses do have a null terminator, so `socklen` is
// expected to be `PATH_LEN` + `offset` + 1.
let path = Path::new(PATH);
let (sockaddr, actual) = socket_addr(path).unwrap();
let (sockaddr, actual) = socket_addr(path.as_os_str().as_bytes()).unwrap();
let offset = path_offset(&sockaddr);
let expected = PATH_LEN + offset + 1;
assert_eq!(expected as libc::socklen_t, actual)
Expand All @@ -152,9 +150,7 @@ cfg_os_poll! {

// Abstract addresses do not have a null terminator, so `socklen` is
// expected to be `PATH_LEN` + `offset`.
let abstract_path = str::from_utf8(PATH).unwrap();
let path = Path::new(abstract_path);
let (sockaddr, actual) = socket_addr(path).unwrap();
let (sockaddr, actual) = socket_addr(PATH).unwrap();
let offset = path_offset(&sockaddr);
let expected = PATH_LEN + offset;
assert_eq!(expected as libc::socklen_t, actual)
Expand Down
8 changes: 8 additions & 0 deletions src/sys/unix/uds/socketaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ cfg_os_poll! {
SocketAddr { sockaddr, socklen }
}

pub(crate) fn raw_sockaddr(&self) -> &libc::sockaddr_un {
&self.sockaddr
}

pub(crate) fn raw_socklen(&self) -> &libc::socklen_t {
&self.socklen
}

/// Returns `true` if the address is unnamed.
///
/// Documentation reflected in [`SocketAddr`]
Expand Down
15 changes: 12 additions & 3 deletions src/sys/unix/uds/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@ use super::{socket_addr, SocketAddr};
use crate::sys::unix::net::new_socket;

use std::io;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;

pub(crate) fn connect(path: &Path) -> io::Result<net::UnixStream> {
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;
let socket_address = {
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
SocketAddr::from_parts(sockaddr, socklen)
};

connect_addr(&socket_address)
}

pub(crate) fn connect_addr(address: &SocketAddr) -> io::Result<net::UnixStream> {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixStream::from_raw_fd(fd) };
match syscall!(connect(fd, sockaddr, socklen)) {
let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr;

match syscall!(connect(fd, sockaddr, *address.raw_socklen())) {
Ok(_) => {}
Err(ref err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e),
Expand Down
49 changes: 49 additions & 0 deletions tests/unix_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,55 @@ fn unix_stream_connect() {
handle.join().unwrap();
}

#[test]
fn unix_stream_connect_addr() {
let (mut poll, mut events) = init_with_poll();
let barrier = Arc::new(Barrier::new(2));
let local_addr = {
// Workaround through a temporary listener using the same address,
// as there is currently no way of directly building a `SocketAddr`.
let path = temp_file("unix_stream_connect_addr");
let listener = net::UnixListener::bind(path.clone()).unwrap();
let mio_listener = mio::net::UnixListener::from_std(listener);
let address = mio_listener.local_addr().unwrap();
drop(mio_listener);
_ = std::fs::remove_file(&path);
address
};
let path = temp_file("unix_stream_connect_addr");
let listener = net::UnixListener::bind(path).unwrap();
let mut stream = UnixStream::connect_addr(&local_addr).unwrap();

let barrier_clone = barrier.clone();
let handle = thread::spawn(move || {
let (stream, _) = listener.accept().unwrap();
barrier_clone.wait();
drop(stream);
});

poll.registry()
.register(
&mut stream,
TOKEN_1,
Interest::READABLE | Interest::WRITABLE,
)
.unwrap();
expect_events(
&mut poll,
&mut events,
vec![ExpectEvent::new(TOKEN_1, Interest::WRITABLE)],
);

barrier.wait();
expect_events(
&mut poll,
&mut events,
vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)],
);

handle.join().unwrap();
}

#[test]
fn unix_stream_from_std() {
smoke_test(
Expand Down

0 comments on commit f251907

Please sign in to comment.