diff --git a/src/api/src/types.rs b/src/api/src/types.rs index e6205da..684546e 100644 --- a/src/api/src/types.rs +++ b/src/api/src/types.rs @@ -84,7 +84,7 @@ pub struct BaoDMInfo { pub fd: i32, } -#[derive(Debug, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] /// Struct representing a Device configuration. /// /// # Attributes @@ -123,6 +123,8 @@ pub struct DeviceConfig { pub guest_cid: Option, // Vhost-user device specific fields pub socket_path: Option, + // Console device specific fields + pub pty_alias: Option, } #[derive(Debug, Deserialize, Serialize, PartialEq)] diff --git a/src/virtio/src/console/virtio/console_handler.rs b/src/virtio/src/console/virtio/console_handler.rs index 462f514..20f4466 100644 --- a/src/virtio/src/console/virtio/console_handler.rs +++ b/src/virtio/src/console/virtio/console_handler.rs @@ -2,6 +2,7 @@ use super::queue_handler::{INPUT_QUEUE_INDEX, OUTPUT_QUEUE_INDEX}; use crate::device::SignalUsedQueue; use std::io::Write; use std::result; +use std::sync::{Arc, Mutex}; use virtio_console::console::{Console, Error as ConsoleError}; use virtio_queue::{Queue, QueueOwnedT, QueueT}; use vm_memory::bitmap::AtomicBitmap; @@ -14,7 +15,7 @@ pub struct ConsoleQueueHandler { pub mem: GuestMemoryMmap, pub input_queue: Queue, pub output_queue: Queue, - pub console: Console, + pub console: Arc>>, } impl ConsoleQueueHandler @@ -32,17 +33,21 @@ where // To see why this is done in a loop, please look at the `Queue::enable_notification` // comments in `virtio_queue`. loop { - if self.console.is_input_buffer_empty() { + if self.console.lock().unwrap().is_input_buffer_empty() { break; } // Disable the notifications. self.input_queue.disable_notification(&self.mem)?; - while !self.console.is_input_buffer_empty() { + while !self.console.lock().unwrap().is_input_buffer_empty() { // Process the queue. if let Some(mut chain) = self.input_queue.iter(&self.mem.clone())?.next() { - let sent_bytes = self.console.process_receiveq_chain(&mut chain)?; + let sent_bytes = self + .console + .lock() + .unwrap() + .process_receiveq_chain(&mut chain)?; if sent_bytes > 0 { self.input_queue.add_used( @@ -85,7 +90,10 @@ where // Process the queue. while let Some(mut chain) = self.output_queue.iter(&self.mem.clone())?.next() { - self.console.process_transmitq_chain(&mut chain)?; + self.console + .lock() + .unwrap() + .process_transmitq_chain(&mut chain)?; self.output_queue .add_used(chain.memory(), chain.head_index(), 0)?; diff --git a/src/virtio/src/console/virtio/device.rs b/src/virtio/src/console/virtio/device.rs index b1e9aac..eb5e45c 100644 --- a/src/virtio/src/console/virtio/device.rs +++ b/src/virtio/src/console/virtio/device.rs @@ -1,4 +1,5 @@ use super::console_handler::ConsoleQueueHandler; +use super::pty_handler::PtyHandler; use super::queue_handler::QueueHandler; use crate::device::{SingleFdSignalQueue, Subscriber, VirtioDeviceT}; use crate::device::{VirtioDevType, VirtioDeviceCommon}; @@ -9,6 +10,7 @@ use event_manager::{ EventManager, MutEventSubscriber, RemoteEndpoint, Result as EvmgrResult, SubscriberId, }; use std::borrow::{Borrow, BorrowMut}; +use std::os::unix::net::UnixStream; use std::sync::{Arc, Mutex}; use virtio_bindings::virtio_config::VIRTIO_F_IN_ORDER; use virtio_console::console::Console; @@ -27,6 +29,7 @@ use vm_device::MutDeviceMmio; pub struct VirtioConsole { pub common: VirtioDeviceCommon, pub endpoint: RemoteEndpoint, + pub config: DeviceConfig, } impl VirtioDeviceT for VirtioConsole { @@ -58,6 +61,7 @@ impl VirtioDeviceT for VirtioConsole { let console = Arc::new(Mutex::new(VirtioConsole { common: common_device, endpoint: remote_endpoint, + config: config.clone(), })); // Register the MMIO device within the device manager with the specified range. @@ -114,8 +118,12 @@ impl VirtioDeviceActions for VirtioConsole { type E = Error; fn activate(&mut self) -> Result<()> { + // Create socket to act as console output and forward it to pty + let (socket_out, socket_in) = UnixStream::pair().unwrap(); + socket_in.set_nonblocking(true).unwrap(); + // Create the backend. - let console = Console::default(); + let console = Arc::new(Mutex::new(Console::new(socket_out))); // Create the driver notify object. let driver_notify = SingleFdSignalQueue { @@ -132,14 +140,16 @@ impl VirtioDeviceActions for VirtioConsole { mem: self.common.mem(), input_queue: self.common.config.queues.remove(0), output_queue: self.common.config.queues.remove(0), - console, + console: Arc::clone(&console), }; // Create the queue handler. + let input_ioeventfd = ioevents.remove(0); + let output_ioeventfd = ioevents.remove(0); let handler = Arc::new(Mutex::new(QueueHandler { inner, - input_ioeventfd: ioevents.remove(0), - output_ioeventfd: ioevents.remove(0), + input_ioeventfd: input_ioeventfd.try_clone().unwrap(), + output_ioeventfd, })); // Register the queue handler with the `EventManager`. We could record the `sub_id` @@ -152,6 +162,20 @@ impl VirtioDeviceActions for VirtioConsole { }) .unwrap(); + // Create pty handler and register it as a event subscriber + let pty_handler = Arc::new(Mutex::new(PtyHandler::new( + socket_in, + Arc::clone(&console), + input_ioeventfd, + &self.config, + ))); + + self.endpoint + .call_blocking(|mgr| -> EvmgrResult { + Ok(mgr.add_subscriber(pty_handler)) + }) + .unwrap(); + // Set the device as activated. self.common.config.device_activated = true; diff --git a/src/virtio/src/console/virtio/mod.rs b/src/virtio/src/console/virtio/mod.rs index 5b45c5a..0b237fe 100644 --- a/src/virtio/src/console/virtio/mod.rs +++ b/src/virtio/src/console/virtio/mod.rs @@ -1,3 +1,4 @@ pub mod console_handler; pub mod device; +mod pty_handler; pub mod queue_handler; diff --git a/src/virtio/src/console/virtio/pty_handler.rs b/src/virtio/src/console/virtio/pty_handler.rs new file mode 100644 index 0000000..52b0e15 --- /dev/null +++ b/src/virtio/src/console/virtio/pty_handler.rs @@ -0,0 +1,117 @@ +use std::fs::{File, OpenOptions}; +use std::io::{Read, Write}; +use std::os::fd::AsRawFd; +use std::os::unix::fs::OpenOptionsExt; +use std::sync::{Arc, Mutex}; + +use api::types::DeviceConfig; +use event_manager::{EventOps, Events, MutEventSubscriber}; +use libc::IN_NONBLOCK; +use std::os::unix::net::UnixStream; +use virtio_console::console::Console; +use vm_memory::WriteVolatile; +use vmm_sys_util::epoll::EventSet; +use vmm_sys_util::eventfd::EventFd; + +const SOURCE_PTY: u32 = 0; +const SOURCE_SOCKET: u32 = 1; + +const BUFFER_SIZE: usize = 128; + +pub(super) struct PtyHandler { + pub pty: File, + pub socket: UnixStream, + pub console: Arc>>, + pub input_ioeventfd: EventFd, +} + +impl PtyHandler +where + W: Write + WriteVolatile, +{ + pub fn new( + socket: UnixStream, + console: Arc>>, + input_ioeventfd: EventFd, + config: &DeviceConfig, + ) -> Self { + let pty = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .custom_flags(IN_NONBLOCK) + .open("/dev/ptmx") + .unwrap(); + + let pty_name = unsafe { + libc::grantpt(pty.as_raw_fd()); + libc::unlockpt(pty.as_raw_fd()); + std::ffi::CStr::from_ptr(libc::ptsname(pty.as_raw_fd())) + }; + + let pty_path = if let Some(pty_alias) = config.pty_alias.clone() { + std::os::unix::fs::symlink(pty_name.to_str().unwrap(), pty_alias.as_str()) + .expect(&format!("Failed to create pty handler alias {}", pty_alias)); + pty_alias + } else { + String::from(pty_name.to_str().unwrap()) + }; + + println!("virtio-console device id {} at {}", config.id, pty_path); + + Self { + pty, + socket, + console, + input_ioeventfd, + } + } +} + +impl MutEventSubscriber for PtyHandler +where + W: Write + WriteVolatile, +{ + fn init(&mut self, ops: &mut EventOps) { + ops.add(Events::with_data( + &self.pty, + SOURCE_PTY, + EventSet::IN | EventSet::EDGE_TRIGGERED, + )) + .expect("Failed to init pty event"); + + ops.add(Events::with_data( + &self.socket, + SOURCE_SOCKET, + EventSet::IN | EventSet::EDGE_TRIGGERED, + )) + .expect("Failed to init socket event"); + } + + fn process(&mut self, events: Events, ops: &mut EventOps) { + let mut buf = [0u8; BUFFER_SIZE]; + + match events.data() { + SOURCE_PTY => { + while let Ok(n) = self.pty.read(&mut buf) { + let mut v: Vec<_> = buf[..n].iter().cloned().collect(); + self.console.lock().unwrap().enqueue_data(&mut v).unwrap(); + self.input_ioeventfd.write(1).unwrap(); + } + } + SOURCE_SOCKET => { + while let Ok(n) = self.socket.read(&mut buf) { + let v: Vec<_> = buf[..n].iter().cloned().collect(); + self.pty.write(&v).unwrap(); + } + } + _ => { + log::error!( + "PtyHandler unexpected event data: {}. Removing event...", + events.data() + ); + ops.remove(events).expect("Failed to remove event"); + } + } + } +}