diff --git a/src/chrdev.rs b/src/chrdev.rs index 9398b0e3..0a68f738 100644 --- a/src/chrdev.rs +++ b/src/chrdev.rs @@ -10,7 +10,7 @@ use crate::bindings; use crate::c_types; use crate::error::{Error, KernelResult}; use crate::types::CStr; -use crate::user_ptr::{UserSlicePtr, UserSlicePtrWriter}; +use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter}; pub fn builder(name: &'static CStr, minors: Range) -> KernelResult { Ok(Builder { @@ -158,6 +158,32 @@ unsafe extern "C" fn read_callback( } } +unsafe extern "C" fn write_callback( + file: *mut bindings::file, + buf: *const c_types::c_char, + len: c_types::c_size_t, + offset: *mut bindings::loff_t, +) -> c_types::c_ssize_t { + let mut data = match UserSlicePtr::new(buf as *mut c_types::c_void, len) { + Ok(ptr) => ptr.reader(), + Err(e) => return e.to_kernel_errno().try_into().unwrap(), + }; + let f = &*((*file).private_data as *const T); + // No FMODE_UNSIGNED_OFFSET support, so offset must be in [0, 2^63). + // See discussion in #113 + let positive_offset = match (*offset).try_into() { + Ok(v) => v, + Err(_) => return Error::EINVAL.to_kernel_errno().try_into().unwrap(), + }; + match f.write(&mut data, positive_offset) { + Ok(()) => { + let read = len - data.len(); + read.try_into().unwrap() + } + Err(e) => e.to_kernel_errno().try_into().unwrap(), + } +} + unsafe extern "C" fn release_callback( _inode: *mut bindings::inode, file: *mut bindings::file, @@ -193,6 +219,7 @@ impl FileOperationsVtable { FileOperationsVtable(bindings::file_operations { open: Some(open_callback::), read: Some(read_callback::), + write: Some(write_callback::), release: Some(release_callback::), llseek: Some(llseek_callback::), @@ -232,7 +259,6 @@ impl FileOperationsVtable { splice_read: None, splice_write: None, unlocked_ioctl: None, - write: None, write_iter: None, }) } @@ -260,6 +286,12 @@ pub trait FileOperations: Sync + Sized { Err(Error::EINVAL) } + /// Writes data from userspace o this file. Corresponds to the `write` + /// function pointer in `struct file_operations`. + fn write(&self, _buf: &mut UserSlicePtrReader, _offset: u64) -> KernelResult<()> { + Err(Error::EINVAL) + } + /// Changes the position of the file. Corresponds to the `llseek` function /// pointer in `struct file_operations`. fn seek(&self, _file: &File, _offset: SeekFrom) -> KernelResult { diff --git a/src/user_ptr.rs b/src/user_ptr.rs index b9d333cf..eefdbe28 100644 --- a/src/user_ptr.rs +++ b/src/user_ptr.rs @@ -66,9 +66,7 @@ impl UserSlicePtr { /// Returns EFAULT if the address does not currently point to /// mapped, readable memory. pub fn read_all(self) -> error::KernelResult> { - let mut data = vec![0; self.1]; - self.reader().read(&mut data)?; - Ok(data) + self.reader().read_all() } /// Construct a `UserSlicePtrReader` that can incrementally read @@ -97,6 +95,27 @@ impl UserSlicePtr { pub struct UserSlicePtrReader(*mut c_types::c_void, usize); impl UserSlicePtrReader { + /// Returns the number of bytes left to be read from this. Note that even + /// reading less than this number of bytes may return an Error(). + pub fn len(&self) -> usize { + self.1 + } + + /// Returns `true` if `self.len()` is 0. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Read all data remaining in the user slice and return it in a `Vec`. + /// + /// Returns EFAULT if the address does not currently point to + /// mapped, readable memory. + pub fn read_all(&mut self) -> error::KernelResult> { + let mut data = vec![0; self.1]; + self.read(&mut data)?; + Ok(data) + } + pub fn read(&mut self, data: &mut [u8]) -> error::KernelResult<()> { if data.len() > self.1 || data.len() > u32::MAX as usize { return Err(error::Error::EFAULT); diff --git a/tests/chrdev/src/lib.rs b/tests/chrdev/src/lib.rs index 7c7b7059..e932709d 100644 --- a/tests/chrdev/src/lib.rs +++ b/tests/chrdev/src/lib.rs @@ -1,5 +1,8 @@ #![no_std] +use alloc::string::ToString; +use core::sync::atomic::{AtomicUsize, Ordering}; + use linux_kernel_module::{self, cstr}; struct CycleFile; @@ -48,6 +51,41 @@ impl linux_kernel_module::chrdev::FileOperations for SeekFile { } } +struct WriteFile { + written: AtomicUsize, +} + +impl linux_kernel_module::chrdev::FileOperations for WriteFile { + const VTABLE: linux_kernel_module::chrdev::FileOperationsVtable = + linux_kernel_module::chrdev::FileOperationsVtable::new::(); + + fn open() -> linux_kernel_module::KernelResult { + return Ok(WriteFile { + written: AtomicUsize::new(0), + }); + } + + fn write( + &self, + buf: &mut linux_kernel_module::user_ptr::UserSlicePtrReader, + _offset: u64, + ) -> linux_kernel_module::KernelResult<()> { + let data = buf.read_all()?; + self.written.fetch_add(data.len(), Ordering::SeqCst); + return Ok(()); + } + + fn read( + &self, + buf: &mut linux_kernel_module::user_ptr::UserSlicePtrWriter, + _offset: u64, + ) -> linux_kernel_module::KernelResult<()> { + let val = self.written.load(Ordering::SeqCst); + buf.write(&val.to_string().as_bytes()[offset..])?; + return Ok(()); + } +} + struct ChrdevTestModule { _chrdev_registration: linux_kernel_module::chrdev::Registration, } @@ -58,6 +96,7 @@ impl linux_kernel_module::KernelModule for ChrdevTestModule { linux_kernel_module::chrdev::builder(cstr!("chrdev-tests"), 0..2)? .register_device::() .register_device::() + .register_device::() .build()?; Ok(ChrdevTestModule { _chrdev_registration: chrdev_registration, diff --git a/tests/chrdev/tests/tests.rs b/tests/chrdev/tests/tests.rs index 16a0a4f4..efa7370f 100644 --- a/tests/chrdev/tests/tests.rs +++ b/tests/chrdev/tests/tests.rs @@ -1,5 +1,5 @@ use std::fs; -use std::io::{Read, Seek, SeekFrom}; +use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::prelude::FileExt; use std::path::PathBuf; use std::process::Command; @@ -56,6 +56,7 @@ fn mknod(path: &PathBuf, major: libc::dev_t, minor: libc::dev_t) -> UnlinkOnDrop const READ_FILE_MINOR: libc::dev_t = 0; const SEEK_FILE_MINOR: libc::dev_t = 1; +const WRITE_FILE_MINOR: libc::dev_t = 2; #[test] fn test_mknod() { @@ -178,3 +179,40 @@ fn test_lseek() { ); }); } + +#[test] +fn test_write_unimplemented() { + with_kernel_module(|| { + let device_number = get_device_major_number(); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, READ_FILE_MINOR); + + let mut f = fs::File::open(&p).unwrap(); + assert_eq!( + f.write(&[1, 2, 3]).unwrap_err().raw_os_error().unwrap(), + libc::EINVAL + ); + }) +} + +#[test] +fn test_write() { + with_kernel_module(|| { + let device_number = get_device_major_number(); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, WRITE_FILE_MINOR); + + let mut f = fs::File::open(&p).unwrap(); + assert_eq!(f.write(&[1, 2, 3]).unwrap(), 3); + + let mut buf = vec![]; + f.read_to_end(&mut buf).unwrap(); + assert_eq!(&buf, b"3"); + + assert_eq!(f.write(&[1, 2, 3, 4, 5]).unwrap(), 5); + + let mut buf = vec![]; + f.read_to_end(&mut buf).unwrap(); + assert_eq!(&buf, b"8"); + }) +}