diff --git a/src/file_operations.rs b/src/file_operations.rs index aca55946..4376873e 100644 --- a/src/file_operations.rs +++ b/src/file_operations.rs @@ -6,7 +6,7 @@ use alloc::boxed::Box; use crate::bindings; use crate::c_types; use crate::error::{Error, KernelResult}; -use crate::user_ptr::{UserSlicePtr, UserSlicePtrWriter}; +use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter}; pub struct File { ptr: *const bindings::file, @@ -70,6 +70,33 @@ unsafe extern "C" fn read_callback( } } +unsafe extern "C" fn write_callback( + file: *mut bindings::file, + buf: *const c_types::c_char, + len: c_types::c_size_t, + offset: *mut bindings::loff_t, +) -> c_types::c_ssize_t { + let mut data = match UserSlicePtr::new(buf as *mut c_types::c_void, len) { + Ok(ptr) => ptr.reader(), + Err(e) => return e.to_kernel_errno().try_into().unwrap(), + }; + let f = &*((*file).private_data as *const T); + // No FMODE_UNSIGNED_OFFSET support, so offset must be in [0, 2^63). + // See discussion in #113 + let positive_offset = match (*offset).try_into() { + Ok(v) => v, + Err(_) => return Error::EINVAL.to_kernel_errno().try_into().unwrap(), + }; + match f.write(&mut data, positive_offset) { + Ok(()) => { + let read = len - data.len(); + (*offset) += bindings::loff_t::try_from(read).unwrap(); + read.try_into().unwrap() + } + Err(e) => e.to_kernel_errno().try_into().unwrap(), + } +} + unsafe extern "C" fn release_callback( _inode: *mut bindings::inode, file: *mut bindings::file, @@ -168,6 +195,13 @@ impl FileOperationsVtableBuilder { } } +impl FileOperationsVtableBuilder { + pub const fn write(mut self) -> FileOperationsVtableBuilder { + self.0.write = Some(write_callback::); + self + } +} + impl FileOperationsVtableBuilder { pub const fn seek(mut self) -> FileOperationsVtableBuilder { self.0.llseek = Some(llseek_callback::); @@ -199,6 +233,12 @@ pub trait Read { fn read(&self, buf: &mut UserSlicePtrWriter, offset: u64) -> KernelResult<()>; } +pub trait Write { + /// Writes data from userspace o this file. Corresponds to the `write` + /// function pointer in `struct file_operations`. + fn write(&self, buf: &mut UserSlicePtrReader, offset: u64) -> KernelResult<()>; +} + pub trait Seek { /// Changes the position of the file. Corresponds to the `llseek` function /// pointer in `struct file_operations`. diff --git a/src/user_ptr.rs b/src/user_ptr.rs index b9d333cf..eefdbe28 100644 --- a/src/user_ptr.rs +++ b/src/user_ptr.rs @@ -66,9 +66,7 @@ impl UserSlicePtr { /// Returns EFAULT if the address does not currently point to /// mapped, readable memory. pub fn read_all(self) -> error::KernelResult> { - let mut data = vec![0; self.1]; - self.reader().read(&mut data)?; - Ok(data) + self.reader().read_all() } /// Construct a `UserSlicePtrReader` that can incrementally read @@ -97,6 +95,27 @@ impl UserSlicePtr { pub struct UserSlicePtrReader(*mut c_types::c_void, usize); impl UserSlicePtrReader { + /// Returns the number of bytes left to be read from this. Note that even + /// reading less than this number of bytes may return an Error(). + pub fn len(&self) -> usize { + self.1 + } + + /// Returns `true` if `self.len()` is 0. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Read all data remaining in the user slice and return it in a `Vec`. + /// + /// Returns EFAULT if the address does not currently point to + /// mapped, readable memory. + pub fn read_all(&mut self) -> error::KernelResult> { + let mut data = vec![0; self.1]; + self.read(&mut data)?; + Ok(data) + } + pub fn read(&mut self, data: &mut [u8]) -> error::KernelResult<()> { if data.len() > self.1 || data.len() > u32::MAX as usize { return Err(error::Error::EFAULT); diff --git a/tests/chrdev/src/lib.rs b/tests/chrdev/src/lib.rs index 553a69e6..223c0d48 100644 --- a/tests/chrdev/src/lib.rs +++ b/tests/chrdev/src/lib.rs @@ -1,5 +1,8 @@ #![no_std] +use alloc::string::ToString; +use core::sync::atomic::{AtomicUsize, Ordering}; + use linux_kernel_module::{self, cstr}; struct CycleFile; @@ -55,6 +58,48 @@ impl linux_kernel_module::file_operations::Seek for SeekFile { } } +struct WriteFile { + written: AtomicUsize, +} + +impl linux_kernel_module::file_operations::FileOperations for WriteFile { + const VTABLE: linux_kernel_module::file_operations::FileOperationsVtable = + linux_kernel_module::file_operations::FileOperationsVtable::builder::() + .read() + .write() + .build(); + + fn open() -> linux_kernel_module::KernelResult { + return Ok(WriteFile { + written: AtomicUsize::new(0), + }); + } +} + +impl linux_kernel_module::file_operations::Read for WriteFile { + fn read( + &self, + buf: &mut linux_kernel_module::user_ptr::UserSlicePtrWriter, + _offset: u64, + ) -> linux_kernel_module::KernelResult<()> { + let val = self.written.load(Ordering::SeqCst).to_string(); + buf.write(val.as_bytes())?; + return Ok(()); + } +} + +impl linux_kernel_module::file_operations::Write for WriteFile { + fn write( + &self, + buf: &mut linux_kernel_module::user_ptr::UserSlicePtrReader, + _offset: u64, + ) -> linux_kernel_module::KernelResult<()> { + let data = buf.read_all()?; + self.written.fetch_add(data.len(), Ordering::SeqCst); + return Ok(()); + } +} + struct ChrdevTestModule { _chrdev_registration: linux_kernel_module::chrdev::Registration, } @@ -62,9 +107,10 @@ struct ChrdevTestModule { impl linux_kernel_module::KernelModule for ChrdevTestModule { fn init() -> linux_kernel_module::KernelResult { let chrdev_registration = - linux_kernel_module::chrdev::builder(cstr!("chrdev-tests"), 0..2)? + linux_kernel_module::chrdev::builder(cstr!("chrdev-tests"), 0..3)? .register_device::() .register_device::() + .register_device::() .build()?; Ok(ChrdevTestModule { _chrdev_registration: chrdev_registration, diff --git a/tests/chrdev/tests/tests.rs b/tests/chrdev/tests/tests.rs index 16a0a4f4..e1523b71 100644 --- a/tests/chrdev/tests/tests.rs +++ b/tests/chrdev/tests/tests.rs @@ -1,5 +1,5 @@ use std::fs; -use std::io::{Read, Seek, SeekFrom}; +use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::prelude::FileExt; use std::path::PathBuf; use std::process::Command; @@ -45,6 +45,7 @@ impl Drop for UnlinkOnDrop<'_> { fn mknod(path: &PathBuf, major: libc::dev_t, minor: libc::dev_t) -> UnlinkOnDrop { Command::new("sudo") .arg("mknod") + .arg("--mode=a=rw") .arg(path.to_str().unwrap()) .arg("c") .arg(major.to_string()) @@ -56,6 +57,7 @@ fn mknod(path: &PathBuf, major: libc::dev_t, minor: libc::dev_t) -> UnlinkOnDrop const READ_FILE_MINOR: libc::dev_t = 0; const SEEK_FILE_MINOR: libc::dev_t = 1; +const WRITE_FILE_MINOR: libc::dev_t = 2; #[test] fn test_mknod() { @@ -178,3 +180,44 @@ fn test_lseek() { ); }); } + +#[test] +fn test_write_unimplemented() { + with_kernel_module(|| { + let device_number = get_device_major_number(); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, READ_FILE_MINOR); + + let mut f = fs::OpenOptions::new().write(true).open(&p).unwrap(); + assert_eq!( + f.write(&[1, 2, 3]).unwrap_err().raw_os_error().unwrap(), + libc::EBADF + ); + }) +} + +#[test] +fn test_write() { + with_kernel_module(|| { + let device_number = get_device_major_number(); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, WRITE_FILE_MINOR); + + let mut f = fs::OpenOptions::new() + .read(true) + .write(true) + .open(&p) + .unwrap(); + assert_eq!(f.write(&[1, 2, 3]).unwrap(), 3); + + let mut buf = [0; 1]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"3"); + + assert_eq!(f.write(&[1, 2, 3, 4, 5]).unwrap(), 5); + + let mut buf = [0; 1]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"8"); + }) +}