From 634e68c00f189452dc38e78294e4e3ce01715da9 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 16 Dec 2024 10:33:47 -0500 Subject: [PATCH] Avoid manual polling Rather than manually implementing non-blocking reads using polling, simply configure the underlying file descriptor to be non-blocking where necessary. Replace libc calls with the normal abstractions from the standard library. This makes the code less error prone and properly encodes type system invariants (such as `io::Read::read` requiring its receiver to be mutable) which were previously not preserved because the implementation used raw file descriptors. --- src/unix_term.rs | 284 ++++++++++++++++++----------------------------- 1 file changed, 111 insertions(+), 173 deletions(-) diff --git a/src/unix_term.rs b/src/unix_term.rs index 27e3624d..87e5a073 100644 --- a/src/unix_term.rs +++ b/src/unix_term.rs @@ -1,8 +1,7 @@ use std::env; -use std::convert::TryFrom as _; use std::fmt::Display; use std::fs; -use std::io::{self, BufRead, BufReader}; +use std::io::{self, BufRead, BufReader, Read}; use std::mem; use std::os::fd::{AsRawFd, RawFd}; use std::str; @@ -95,6 +94,15 @@ impl Input { } } +impl Read for Input { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Self::Stdin(s) => s.read(buf), + Self::File(f) => f.read(buf), + } + } +} + // NB: this is not a full BufRead implementation because io::Stdin does not implement BufRead. impl Input { fn read_line(&mut self, buf: &mut String) -> io::Result { @@ -145,202 +153,132 @@ pub(crate) fn read_secure() -> io::Result { }) } -fn poll_fd(fd: RawFd, timeout: i32) -> io::Result { - let mut pollfd = libc::pollfd { - fd, - events: libc::POLLIN, - revents: 0, - }; - let ret = unsafe { libc::poll(&mut pollfd as *mut _, 1, timeout) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(pollfd.revents & libc::POLLIN != 0) - } -} - -#[cfg(target_os = "macos")] -fn select_fd(fd: RawFd, timeout: i32) -> io::Result { - unsafe { - let mut read_fd_set: libc::fd_set = mem::zeroed(); - - let mut timeout_val; - let timeout = if timeout < 0 { - std::ptr::null_mut() - } else { - timeout_val = libc::timeval { - tv_sec: (timeout / 1000) as _, - tv_usec: (timeout * 1000) as _, - }; - &mut timeout_val - }; - - libc::FD_ZERO(&mut read_fd_set); - libc::FD_SET(fd, &mut read_fd_set); - let ret = libc::select( - fd + 1, - &mut read_fd_set, - std::ptr::null_mut(), - std::ptr::null_mut(), - timeout, - ); - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(libc::FD_ISSET(fd, &read_fd_set)) +fn read_single_char(input: &mut T) -> io::Result> { + let original = unsafe { libc::fcntl(input.as_raw_fd(), libc::F_GETFL) }; + c_result(|| unsafe { + libc::fcntl( + input.as_raw_fd(), + libc::F_SETFL, + original | libc::O_NONBLOCK, + ) + })?; + let mut buf = [0u8; 1]; + let result = read_bytes(input, &mut buf); + c_result(|| unsafe { libc::fcntl(input.as_raw_fd(), libc::F_SETFL, original) })?; + match result { + Ok(()) => { + let [byte] = buf; + Ok(Some(byte as char)) } - } -} - -fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result { - // There is a bug on macos that ttys cannot be polled, only select() - // works. However given how problematic select is in general, we - // normally want to use poll there too. - #[cfg(target_os = "macos")] - { - if unsafe { libc::isatty(fd) == 1 } { - return select_fd(fd, timeout); + Err(err) => { + if err.kind() == io::ErrorKind::WouldBlock { + Ok(None) + } else { + Err(err) + } } } - poll_fd(fd, timeout) } -fn read_single_char(fd: RawFd) -> io::Result> { - // timeout of zero means that it will not block - let is_ready = select_or_poll_term_fd(fd, 0)?; - - if is_ready { - // if there is something to be read, take 1 byte from it - let mut buf: [u8; 1] = [0]; - - read_bytes(fd, &mut buf)?; - Ok(Some(buf[0] as char)) +fn read_bytes(input: &mut impl Read, buf: &mut [u8]) -> io::Result<()> { + input.read_exact(buf)?; + if buf.starts_with(b"\x03") { + Err(io::Error::new( + io::ErrorKind::Interrupted, + "read interrupted", + )) } else { - //there is nothing to be read - Ok(None) - } -} - -// Similar to libc::read. Read count bytes into slice buf from descriptor fd. -// If successful, return the number of bytes read. -// Will return an error if nothing was read, i.e when called at end of file. -fn read_bytes(fd: RawFd, buf: &mut [u8]) -> io::Result<()> { - let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, buf.len()) }; - match usize::try_from(read) { - Err(std::num::TryFromIntError { .. }) => Err(io::Error::last_os_error()), - Ok(read) => { - if read != buf.len() { - Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "Reached end of file", - )) - } else if buf.starts_with(b"\x03") { - Err(io::Error::new( - io::ErrorKind::Interrupted, - "read interrupted", - )) - } else { - Ok(()) - } - } + Ok(()) } } -fn read_single_key_impl(fd: RawFd) -> Result { - loop { - match read_single_char(fd)? { - Some('\x1b') => { - // Escape was read, keep reading in case we find a familiar key - break if let Some(c1) = read_single_char(fd)? { - if c1 == '[' { - if let Some(c2) = read_single_char(fd)? { - match c2 { - 'A' => Ok(Key::ArrowUp), - 'B' => Ok(Key::ArrowDown), - 'C' => Ok(Key::ArrowRight), - 'D' => Ok(Key::ArrowLeft), - 'H' => Ok(Key::Home), - 'F' => Ok(Key::End), - 'Z' => Ok(Key::BackTab), - _ => { - let c3 = read_single_char(fd)?; - if let Some(c3) = c3 { - if c3 == '~' { - match c2 { - '1' => Ok(Key::Home), // tmux - '2' => Ok(Key::Insert), - '3' => Ok(Key::Del), - '4' => Ok(Key::End), // tmux - '5' => Ok(Key::PageUp), - '6' => Ok(Key::PageDown), - '7' => Ok(Key::Home), // xrvt - '8' => Ok(Key::End), // xrvt - _ => Ok(Key::UnknownEscSeq(vec![c1, c2, c3])), - } - } else { - Ok(Key::UnknownEscSeq(vec![c1, c2, c3])) +fn read_single_key_impl(fd: &mut T) -> Result { + // NB: this doesn't use `read_single_char` because we want a blocking read here. + let mut buf = [0u8; 1]; + read_bytes(fd, &mut buf)?; + let [byte] = buf; + match byte { + b'\x1b' => { + // Escape was read, keep reading in case we find a familiar key + if let Some(c1) = read_single_char(fd)? { + if c1 == '[' { + if let Some(c2) = read_single_char(fd)? { + match c2 { + 'A' => Ok(Key::ArrowUp), + 'B' => Ok(Key::ArrowDown), + 'C' => Ok(Key::ArrowRight), + 'D' => Ok(Key::ArrowLeft), + 'H' => Ok(Key::Home), + 'F' => Ok(Key::End), + 'Z' => Ok(Key::BackTab), + _ => { + let c3 = read_single_char(fd)?; + if let Some(c3) = c3 { + if c3 == '~' { + match c2 { + '1' => Ok(Key::Home), // tmux + '2' => Ok(Key::Insert), + '3' => Ok(Key::Del), + '4' => Ok(Key::End), // tmux + '5' => Ok(Key::PageUp), + '6' => Ok(Key::PageDown), + '7' => Ok(Key::Home), // xrvt + '8' => Ok(Key::End), // xrvt + _ => Ok(Key::UnknownEscSeq(vec![c1, c2, c3])), } } else { - // \x1b[ and 1 more char - Ok(Key::UnknownEscSeq(vec![c1, c2])) + Ok(Key::UnknownEscSeq(vec![c1, c2, c3])) } + } else { + // \x1b[ and 1 more char + Ok(Key::UnknownEscSeq(vec![c1, c2])) } } - } else { - // \x1b[ and no more input - Ok(Key::UnknownEscSeq(vec![c1])) } } else { - // char after escape is not [ + // \x1b[ and no more input Ok(Key::UnknownEscSeq(vec![c1])) } } else { - //nothing after escape - Ok(Key::Escape) - }; - } - Some(c) => { - let byte = c as u8; - let mut buf: [u8; 4] = [byte, 0, 0, 0]; - - break if byte & 224u8 == 192u8 { - // a two byte unicode character - read_bytes(fd, &mut buf[1..][..1])?; - Ok(key_from_utf8(&buf[..2])) - } else if byte & 240u8 == 224u8 { - // a three byte unicode character - read_bytes(fd, &mut buf[1..][..2])?; - Ok(key_from_utf8(&buf[..3])) - } else if byte & 248u8 == 240u8 { - // a four byte unicode character - read_bytes(fd, &mut buf[1..][..3])?; - Ok(key_from_utf8(&buf[..4])) - } else { - Ok(match c { - '\n' | '\r' => Key::Enter, - '\x7f' => Key::Backspace, - '\t' => Key::Tab, - '\x01' => Key::Home, // Control-A (home) - '\x05' => Key::End, // Control-E (end) - '\x08' => Key::Backspace, // Control-H (8) (Identical to '\b') - _ => Key::Char(c), - }) - }; - } - None => { - // there is no subsequent byte ready to be read, block and wait for input - // negative timeout means that it will block indefinitely - match select_or_poll_term_fd(fd, -1) { - Ok(_) => continue, - Err(_) => break Err(io::Error::last_os_error()), + // char after escape is not [ + Ok(Key::UnknownEscSeq(vec![c1])) } + } else { + //nothing after escape + Ok(Key::Escape) + } + } + byte => { + let mut buf: [u8; 4] = [byte, 0, 0, 0]; + if byte & 224u8 == 192u8 { + // a two byte unicode character + read_bytes(fd, &mut buf[1..][..1])?; + Ok(key_from_utf8(&buf[..2])) + } else if byte & 240u8 == 224u8 { + // a three byte unicode character + read_bytes(fd, &mut buf[1..][..2])?; + Ok(key_from_utf8(&buf[..3])) + } else if byte & 248u8 == 240u8 { + // a four byte unicode character + read_bytes(fd, &mut buf[1..][..3])?; + Ok(key_from_utf8(&buf[..4])) + } else { + Ok(match byte as char { + '\n' | '\r' => Key::Enter, + '\x7f' => Key::Backspace, + '\t' => Key::Tab, + '\x01' => Key::Home, // Control-A (home) + '\x05' => Key::End, // Control-E (end) + '\x08' => Key::Backspace, // Control-H (8) (Identical to '\b') + c => Key::Char(c), + }) } } } } pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { - let input = Input::unbuffered()?; + let mut input = Input::unbuffered()?; let mut termios = core::mem::MaybeUninit::uninit(); c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; @@ -349,7 +287,7 @@ pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { unsafe { libc::cfmakeraw(&mut termios) }; termios.c_oflag = original.c_oflag; c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?; - let rv = read_single_key_impl(input.as_raw_fd()); + let rv = read_single_key_impl(&mut input); c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?; // if the user hit ^C we want to signal SIGINT to ourselves.