Skip to content

Commit

Permalink
Refactor - Improve syscall errors (#536)
Browse files Browse the repository at this point in the history
* Implements StableResult::map() and StableResult::map_err().

* Moves StableResult and ProgramResult from vm.rs into error.rs

* Moves tests into the files where they belong.

* Makes rust interface of syscalls return Box<dyn std::error::Error>.
  • Loading branch information
Lichtso authored Oct 1, 2023
1 parent ae16614 commit 32b1a6f
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 152 deletions.
4 changes: 2 additions & 2 deletions src/debugger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ use gdbstub::target::ext::section_offsets::Offsets;

use crate::{
ebpf,
error::EbpfError,
error::{EbpfError, ProgramResult},
interpreter::{DebugState, Interpreter},
memory_region::AccessType,
vm::{ContextObject, ProgramResult},
vm::ContextObject,
};

type DynResult<T> = Result<T, Box<dyn std::error::Error>>;
Expand Down
3 changes: 2 additions & 1 deletion src/elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,10 +1189,11 @@ mod test {
consts::{ELFCLASS32, ELFDATA2MSB, ET_REL},
types::{Elf64Ehdr, Elf64Shdr},
},
error::ProgramResult,
fuzz::fuzz,
program::BuiltinFunction,
syscalls,
vm::{ProgramResult, TestContextObject},
vm::TestContextObject,
};
use rand::{distributions::Uniform, Rng};
use std::{fs::File, io::Read};
Expand Down
106 changes: 106 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,109 @@ pub enum EbpfError {
#[error("Syscall error: {0}")]
SyscallError(Box<dyn Error>),
}

/// Same as `Result` but provides a stable memory layout
#[derive(Debug)]
#[repr(C, u64)]
pub enum StableResult<T, E> {
/// Success
Ok(T),
/// Failure
Err(E),
}

impl<T: std::fmt::Debug, E: std::fmt::Debug> StableResult<T, E> {
/// `true` if `Ok`
pub fn is_ok(&self) -> bool {
match self {
Self::Ok(_) => true,
Self::Err(_) => false,
}
}

/// `true` if `Err`
pub fn is_err(&self) -> bool {
match self {
Self::Ok(_) => false,
Self::Err(_) => true,
}
}

/// Returns the inner value if `Ok`, panics otherwise
pub fn unwrap(self) -> T {
match self {
Self::Ok(value) => value,
Self::Err(error) => panic!("unwrap {:?}", error),
}
}

/// Returns the inner error if `Err`, panics otherwise
pub fn unwrap_err(self) -> E {
match self {
Self::Ok(value) => panic!("unwrap_err {:?}", value),
Self::Err(error) => error,
}
}

/// Maps ok values, leaving error values untouched
pub fn map<U, O: FnOnce(T) -> U>(self, op: O) -> StableResult<U, E> {
match self {
Self::Ok(value) => StableResult::<U, E>::Ok(op(value)),
Self::Err(error) => StableResult::<U, E>::Err(error),
}
}

/// Maps error values, leaving ok values untouched
pub fn map_err<F, O: FnOnce(E) -> F>(self, op: O) -> StableResult<T, F> {
match self {
Self::Ok(value) => StableResult::<T, F>::Ok(value),
Self::Err(error) => StableResult::<T, F>::Err(op(error)),
}
}

#[cfg_attr(
any(
not(feature = "jit"),
target_os = "windows",
not(target_arch = "x86_64")
),
allow(dead_code)
)]
pub(crate) fn discriminant(&self) -> u64 {
unsafe { *(self as *const _ as *const u64) }
}
}

impl<T, E> From<StableResult<T, E>> for Result<T, E> {
fn from(result: StableResult<T, E>) -> Self {
match result {
StableResult::Ok(value) => Ok(value),
StableResult::Err(value) => Err(value),
}
}
}

impl<T, E> From<Result<T, E>> for StableResult<T, E> {
fn from(result: Result<T, E>) -> Self {
match result {
Ok(value) => Self::Ok(value),
Err(value) => Self::Err(value),
}
}
}

/// Return value of programs and syscalls
pub type ProgramResult = StableResult<u64, EbpfError>;

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_program_result_is_stable() {
let ok = ProgramResult::Ok(42);
assert_eq!(ok.discriminant(), 0);
let err = ProgramResult::Err(EbpfError::JitNotCompiled);
assert_eq!(err.discriminant(), 1);
}
}
4 changes: 2 additions & 2 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
use crate::{
ebpf::{self, STACK_PTR_REG},
elf::Executable,
error::EbpfError,
vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm, ProgramResult},
error::{EbpfError, ProgramResult},
vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm},
};

/// Virtual memory operation helper.
Expand Down
4 changes: 2 additions & 2 deletions src/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ use std::{fmt::Debug, mem, ptr};
use crate::{
ebpf::{self, FIRST_SCRATCH_REG, FRAME_PTR_REG, INSN_SIZE, SCRATCH_REGS, STACK_PTR_REG},
elf::Executable,
error::EbpfError,
error::{EbpfError, ProgramResult},
memory_management::{
allocate_pages, free_pages, get_system_page_size, protect_pages, round_to_page_size,
},
memory_region::{AccessType, MemoryMapping},
vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm, ProgramResult},
vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm},
x86::*,
};

Expand Down
4 changes: 2 additions & 2 deletions src/memory_region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
use crate::{
aligned_memory::Pod,
ebpf,
error::EbpfError,
error::{EbpfError, ProgramResult},
program::SBPFVersion,
vm::{Config, ProgramResult},
vm::Config,
};
use std::{
array,
Expand Down
44 changes: 40 additions & 4 deletions src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ macro_rules! declare_builtin_function {
$arg_d:ident : u64,
$arg_e:ident : u64,
$memory_mapping:ident : &mut $MemoryMapping:ty,
) -> Result<u64, EbpfError> $rust:tt) => {
) -> Result<u64, Box<dyn std::error::Error>> $rust:tt) => {
$(#[$attr])*
pub struct $name {}
impl $name {
Expand All @@ -309,7 +309,7 @@ macro_rules! declare_builtin_function {
$arg_d: u64,
$arg_e: u64,
$memory_mapping: &mut $MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
$rust
}
/// VM interface
Expand All @@ -330,9 +330,9 @@ macro_rules! declare_builtin_function {
if config.enable_instruction_meter {
vm.context_object_pointer.consume(vm.previous_instruction_meter - vm.due_insn_count);
}
let converted_result: $crate::vm::ProgramResult = Self::rust(
let converted_result: $crate::error::ProgramResult = Self::rust(
vm.context_object_pointer, $arg_a, $arg_b, $arg_c, $arg_d, $arg_e, &mut vm.memory_mapping,
).into();
).map_err(|err| $crate::error::EbpfError::SyscallError(err)).into();
vm.program_result = converted_result;
if config.enable_instruction_meter {
vm.previous_instruction_meter = vm.context_object_pointer.get_remaining();
Expand All @@ -341,3 +341,39 @@ macro_rules! declare_builtin_function {
}
};
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{program::BuiltinFunction, syscalls, vm::TestContextObject};

#[test]
fn test_builtin_program_eq() {
let mut function_registry_a =
FunctionRegistry::<BuiltinFunction<TestContextObject>>::default();
function_registry_a
.register_function_hashed(*b"log", syscalls::SyscallString::vm)
.unwrap();
function_registry_a
.register_function_hashed(*b"log_64", syscalls::SyscallU64::vm)
.unwrap();
let mut function_registry_b =
FunctionRegistry::<BuiltinFunction<TestContextObject>>::default();
function_registry_b
.register_function_hashed(*b"log_64", syscalls::SyscallU64::vm)
.unwrap();
function_registry_b
.register_function_hashed(*b"log", syscalls::SyscallString::vm)
.unwrap();
let mut function_registry_c =
FunctionRegistry::<BuiltinFunction<TestContextObject>>::default();
function_registry_c
.register_function_hashed(*b"log_64", syscalls::SyscallU64::vm)
.unwrap();
let builtin_program_a = BuiltinProgram::new_loader(Config::default(), function_registry_a);
let builtin_program_b = BuiltinProgram::new_loader(Config::default(), function_registry_b);
assert_eq!(builtin_program_a, builtin_program_b);
let builtin_program_c = BuiltinProgram::new_loader(Config::default(), function_registry_c);
assert_ne!(builtin_program_a, builtin_program_c);
}
}
12 changes: 6 additions & 6 deletions src/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ declare_builtin_function!(
arg4: u64,
arg5: u64,
_memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
println!("bpf_trace_printf: {arg3:#x}, {arg4:#x}, {arg5:#x}");
let size_arg = |x| {
if x == 0 {
Expand Down Expand Up @@ -69,7 +69,7 @@ declare_builtin_function!(
arg4: u64,
arg5: u64,
_memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
Ok(arg1.wrapping_shl(32)
| arg2.wrapping_shl(24)
| arg3.wrapping_shl(16)
Expand All @@ -91,7 +91,7 @@ declare_builtin_function!(
_arg4: u64,
_arg5: u64,
memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
let host_addr: Result<u64, EbpfError> =
memory_mapping.map(AccessType::Store, vm_addr, len).into();
let host_addr = host_addr?;
Expand All @@ -116,7 +116,7 @@ declare_builtin_function!(
_arg4: u64,
_arg5: u64,
memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
// C-like strcmp, maybe shorter than converting the bytes to string and comparing?
if arg1 == 0 || arg2 == 0 {
return Ok(u64::MAX);
Expand Down Expand Up @@ -154,7 +154,7 @@ declare_builtin_function!(
_arg4: u64,
_arg5: u64,
memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
let host_addr: Result<u64, EbpfError> =
memory_mapping.map(AccessType::Load, vm_addr, len).into();
let host_addr = host_addr?;
Expand Down Expand Up @@ -185,7 +185,7 @@ declare_builtin_function!(
arg4: u64,
arg5: u64,
memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
) -> Result<u64, Box<dyn std::error::Error>> {
println!(
"dump_64: {:#x}, {:#x}, {:#x}, {:#x}, {:#x}, {:?}",
arg1, arg2, arg3, arg4, arg5, memory_mapping as *const _
Expand Down
Loading

0 comments on commit 32b1a6f

Please sign in to comment.