diff --git a/src/jit.rs b/src/jit.rs index 8e24b877..7ccdaf61 100644 --- a/src/jit.rs +++ b/src/jit.rs @@ -41,14 +41,22 @@ pub struct JitProgramArgument<'a> { struct JitProgramSections { pc_section: &'static mut [u64], text_section: &'static mut [u8], + total_allocation_size: usize, } #[cfg(not(target_os = "windows"))] macro_rules! libc_error_guard { + (succeeded?, mmap, $addr:expr, $($arg:expr),*) => {{ + *$addr = libc::mmap(*$addr, $($arg),*); + *$addr != libc::MAP_FAILED + }}; + (succeeded?, $function:ident, $($arg:expr),*) => { + libc::$function($($arg),*) == 0 + }; ($function:ident, $($arg:expr),*) => {{ const RETRY_COUNT: usize = 3; for i in 0..RETRY_COUNT { - if libc::$function($($arg),*) == 0 { + if libc_error_guard!(succeeded?, $function, $($arg),*) { break; } else if i + 1 == RETRY_COUNT { let args = vec![$(format!("{:?}", $arg)),*]; @@ -59,36 +67,42 @@ macro_rules! libc_error_guard { return Err(EbpfError::LibcInvocationFailed(stringify!($function), args, errno)); } } - }} + }}; } impl JitProgramSections { - fn new(pc: usize, code_size: usize) -> Result> { - let _pc_loc_table_size = round_to_page_size(pc * 8); - let _code_size = round_to_page_size(code_size); + fn new(_pc: usize, _code_size: usize) -> Result> { #[cfg(target_os = "windows")] { Ok(Self { pc_section: &mut [], text_section: &mut [], + total_allocation_size: 0, }) } #[cfg(not(target_os = "windows"))] unsafe { + fn round_to_page_size(value: usize, page_size: usize) -> usize { + (value + page_size - 1) / page_size * page_size + } + let page_size = libc::sysconf(libc::_SC_PAGESIZE) as usize; + let pc_loc_table_size = round_to_page_size(_pc * 8, page_size); + let code_size = round_to_page_size(_code_size, page_size); let mut raw: *mut libc::c_void = std::ptr::null_mut(); - libc_error_guard!(posix_memalign, &mut raw, PAGE_SIZE, _pc_loc_table_size + _code_size); - std::ptr::write_bytes(raw, 0x00, _pc_loc_table_size); - std::ptr::write_bytes(raw.add(_pc_loc_table_size), 0xcc, _code_size); // Populate with debugger traps + libc_error_guard!(mmap, &mut raw, pc_loc_table_size + code_size, libc::PROT_READ | libc::PROT_WRITE, libc::MAP_ANONYMOUS | libc::MAP_PRIVATE, 0, 0); + std::ptr::write_bytes(raw, 0x00, pc_loc_table_size); + std::ptr::write_bytes(raw.add(pc_loc_table_size), 0xcc, code_size); // Populate with debugger traps Ok(Self { - pc_section: std::slice::from_raw_parts_mut(raw as *mut u64, pc), - text_section: std::slice::from_raw_parts_mut(raw.add(_pc_loc_table_size) as *mut u8, _code_size), + pc_section: std::slice::from_raw_parts_mut(raw as *mut u64, _pc), + text_section: std::slice::from_raw_parts_mut(raw.add(pc_loc_table_size) as *mut u8, code_size), + total_allocation_size: pc_loc_table_size + code_size, }) } } fn seal(&mut self) -> Result<(), EbpfError> { - #[cfg(not(target_os = "windows"))] - if !self.pc_section.is_empty() { + if self.total_allocation_size > 0 { + #[cfg(not(target_os = "windows"))] unsafe { libc_error_guard!(mprotect, self.pc_section.as_mut_ptr() as *mut _, self.pc_section.len(), libc::PROT_READ); libc_error_guard!(mprotect, self.text_section.as_mut_ptr() as *mut _, self.text_section.len(), libc::PROT_EXEC | libc::PROT_READ); @@ -100,12 +114,10 @@ impl JitProgramSections { impl Drop for JitProgramSections { fn drop(&mut self) { - #[cfg(not(target_os = "windows"))] - if !self.pc_section.is_empty() { + if self.total_allocation_size > 0 { + #[cfg(not(target_os = "windows"))] unsafe { - libc::mprotect(self.pc_section.as_mut_ptr() as *mut _, round_to_page_size(self.pc_section.len()), libc::PROT_READ | libc::PROT_WRITE); - libc::mprotect(self.text_section.as_mut_ptr() as *mut _, round_to_page_size(self.text_section.len()), libc::PROT_READ | libc::PROT_WRITE); - libc::free(self.pc_section.as_ptr() as *mut _); + libc::munmap(self.pc_section.as_ptr() as *mut _, self.total_allocation_size); } } } @@ -692,11 +704,6 @@ fn emit_set_exception_kind(jit: &mut JitCompiler, err: Ebpf X86Instruction::store_immediate(OperandSize::S64, R10, X86IndirectAccess::Offset(8), err_kind as i64).emit(jit) } -const PAGE_SIZE: usize = 4096; -fn round_to_page_size(value: usize) -> usize { - (value + PAGE_SIZE - 1) / PAGE_SIZE * PAGE_SIZE -} - #[derive(Debug)] struct Jump { location: usize,