From 3dc92f2d83f86df0e433bb500262c1392154295c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 2 Jul 2024 20:59:15 +0200 Subject: [PATCH] Add more instructions to JIT2 --- src/machine/jit2.rs | 243 ++++++++++++++++++++++++++++++++++-- src/machine/mod.rs | 4 +- src/machine/system_calls.rs | 4 +- 3 files changed, 239 insertions(+), 12 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 83fd918d7..018d6c00c 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -11,6 +11,9 @@ use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module}; use cranelift_codegen::Context; use cranelift::prelude::codegen::ir::immediates::Offset32; +use cranelift::prelude::codegen::ir::entities::Value; + +use std::ops::Index; #[derive(Debug, PartialEq)] pub enum JitCompileError { @@ -24,6 +27,12 @@ pub struct JitMachine { module: JITModule, ctx: Context, func_ctx: FunctionBuilderContext, + heap_as_ptr: *const u8, + heap_as_ptr_sig: Signature, + heap_push: *const u8, + heap_push_sig: Signature, + heap_len: *const u8, + heap_len_sig: Signature, } impl std::fmt::Debug for JitMachine { @@ -85,27 +94,247 @@ impl JitMachine { let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; trampolines.push(code_ptr); } - - + let heap_as_ptr = Vec::::as_ptr as *const u8; + let mut heap_as_ptr_sig = module.make_signature(); + heap_as_ptr_sig.params.push(AbiParam::new(pointer_type)); + heap_as_ptr_sig.returns.push(AbiParam::new(pointer_type)); + let heap_push = Vec::::push as *const u8; + let mut heap_push_sig = module.make_signature(); + heap_push_sig.params.push(AbiParam::new(pointer_type)); + heap_push_sig.params.push(AbiParam::new(types::I64)); + let heap_len = Vec::::len as *const u8; + let mut heap_len_sig = module.make_signature(); + heap_len_sig.params.push(AbiParam::new(pointer_type)); + heap_len_sig.returns.push(AbiParam::new(types::I64)); JitMachine { trampolines, module, ctx, func_ctx, + heap_as_ptr, + heap_as_ptr_sig, + heap_push, + heap_push_sig, + heap_len, + heap_len_sig, } } - // TODO: Compile taking into account arity - pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError> { + pub fn compile(&mut self, name: &str, arity: usize, code: Code) -> Result<(), JitCompileError> { + let mut sig = self.module.make_signature(); + sig.params.push(AbiParam::new(types::I64)); + for _ in 1..=arity { + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + } + sig.call_conv = isa::CallConv::Tail; + self.ctx.func.signature = sig.clone(); + let mut fn_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); let block = fn_builder.create_block(); fn_builder.append_block_params_for_function_params(block); fn_builder.switch_to_block(block); + fn_builder.seal_block(block); + + let heap = fn_builder.block_params(block)[0]; + let mode = Variable::new(0); + fn_builder.declare_var(mode, types::I8); + let s = Variable::new(1); + fn_builder.declare_var(s, types::I64); + let fail = Variable::new(2); + fn_builder.declare_var(fail, types::I8); + + let mut registers = vec![]; + for i in 1..=arity { + let reg = fn_builder.block_params(block)[i]; + registers.push(reg); + } + + macro_rules! heap_len { + () => { + {let sig_ref = fn_builder.import_signature(self.heap_len_sig.clone()); + let heap_len_fn = fn_builder.ins().iconst(types::I64, self.heap_len as i64); + let call_heap_len = fn_builder.ins().call_indirect(sig_ref, heap_len_fn, &[heap]); + let heap_len = fn_builder.inst_results(call_heap_len)[0]; + heap_len} + } + } + + macro_rules! heap_as_ptr { + () => { + { + let sig_ref = fn_builder.import_signature(self.heap_as_ptr_sig.clone()); + let heap_as_ptr_fn = fn_builder.ins().iconst(types::I64, self.heap_as_ptr as i64); + let call_heap_as_ptr = fn_builder.ins().call_indirect(sig_ref, heap_as_ptr_fn, &[heap]); + let heap_ptr = fn_builder.inst_results(call_heap_as_ptr)[0]; + heap_ptr + } + } + } + + macro_rules! store { + ($x:expr) => { + { + let merge_block = fn_builder.create_block(); + fn_builder.append_block_param(merge_block, types::I64); + let is_var_block = fn_builder.create_block(); + fn_builder.append_block_param(is_var_block, types::I64); + let is_not_var_block = fn_builder.create_block(); + fn_builder.append_block_param(is_not_var_block, types::I64); + let tag = fn_builder.ins().band_imm($x, 64); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + fn_builder.ins().brif(is_var, is_var_block, &[$x], is_not_var_block, &[$x]); + // is_var + fn_builder.switch_to_block(is_var_block); + fn_builder.seal_block(is_var_block); + let param = fn_builder.block_params(is_var_block)[0]; + let idx = fn_builder.ins().ushr_imm(param, 8); + let heap_ptr = heap_as_ptr!(); + let idx_ptr = fn_builder.ins().imul_imm(idx, 8); + let idx_ptr = fn_builder.ins().iadd(heap_ptr, idx_ptr); + let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx_ptr, Offset32::new(0)); + fn_builder.ins().jump(merge_block, &[heap_value]); + // is_not_var + fn_builder.switch_to_block(is_not_var_block); + fn_builder.seal_block(is_not_var_block); + let param = fn_builder.block_params(is_not_var_block)[0]; + fn_builder.ins().jump(merge_block, &[param]); + // merge + fn_builder.switch_to_block(merge_block); + fn_builder.seal_block(merge_block); + fn_builder.block_params(merge_block)[0] + } + } + } + + macro_rules! deref { + ($x:expr) => { + { + let exit_block = fn_builder.create_block(); + fn_builder.append_block_param(exit_block, types::I64); + let loop_block = fn_builder.create_block(); + fn_builder.append_block_param(loop_block, types::I64); + fn_builder.ins().jump(loop_block, &[$x]); + fn_builder.switch_to_block(loop_block); + let addr = fn_builder.block_params(loop_block)[0]; + let value = store!(addr); + // check if is var + let tag = fn_builder.ins().band_imm(value, 64); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + let not_equal = fn_builder.ins().icmp(IntCC::NotEqual, value, addr); + let check = fn_builder.ins().band(is_var, not_equal); + fn_builder.ins().brif(check, loop_block, &[value], exit_block, &[value]); + // exit + fn_builder.seal_block(loop_block); + fn_builder.seal_block(exit_block); + fn_builder.switch_to_block(exit_block); + fn_builder.block_params(exit_block)[0] + + } + } + } for wam_instr in code { match wam_instr { + // TODO Missing RegType Perm + Instruction::PutStructure(name, arity, reg) => { + let atom_cell = atom_as_cell!(name, arity); + let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); + let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let str_cell = fn_builder.ins().bor(heap_len_shift, str_cell); + match reg { + RegType::Temp(x) => { + registers[x] = str_cell; + } + _ => unimplemented!() + } + } + // TODO Missing RegType Perm + Instruction::SetVariable(reg) => { + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + match reg { + RegType::Temp(x) => { + registers[x] = heap_loc_cell; + } + _ => unimplemented!() + } + } + // TODO: Missing RegType Perm + Instruction::SetValue(reg) => { + let value = match reg { + RegType::Temp(x) => { + registers[x] + }, + _ => unimplemented!() + }; + let value = store!(value); + + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + } + // TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate + // TODO: Missing support for PStr and CStr + Instruction::UnifyVariable(reg) => { + let read_block = fn_builder.create_block(); + let write_block = fn_builder.create_block(); + let exit_block = fn_builder.create_block(); + let mode_value = fn_builder.use_var(mode); + fn_builder.ins().brif(mode_value, write_block, &[], read_block, &[]); + fn_builder.seal_block(read_block); + fn_builder.seal_block(write_block); + // read + fn_builder.switch_to_block(read_block); + let heap_ptr = heap_as_ptr!(); + let s_value = fn_builder.use_var(s); + let idx = fn_builder.ins().iadd(heap_ptr, s_value); + let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + let value = deref!(value); + match reg { + RegType::Temp(x) => { + registers[x] = value; + }, + _ => unimplemented!() + } + let sum_s = fn_builder.ins().iadd_imm(s_value, 8); + fn_builder.def_var(s, sum_s); + fn_builder.ins().jump(exit_block, &[]); + // write (equal to SetVariable) + fn_builder.switch_to_block(write_block); + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + match reg { + RegType::Temp(x) => { + registers[x] = heap_loc_cell; + } + _ => unimplemented!() + } + fn_builder.ins().jump(exit_block, &[]); + // exit + fn_builder.switch_to_block(exit_block); + fn_builder.seal_block(exit_block); + + } Instruction::Proceed => { - fn_builder.ins().return_(&[]); + fn_builder.ins().return_(®isters); break; }, _ => { @@ -118,8 +347,6 @@ impl JitMachine { fn_builder.seal_all_blocks(); fn_builder.finalize(); - let mut sig = self.module.make_signature(); - sig.call_conv = isa::CallConv::Tail; let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); self.module.define_function(func, &mut self.ctx).unwrap(); @@ -129,6 +356,6 @@ impl JitMachine { } pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { - Ok(()) + Err(()) } } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index 2a7e343b6..78e83e1d9 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -13,7 +13,7 @@ pub mod dispatch; pub mod gc; pub mod heap; #[cfg(feature = "jit")] -pub mod jit; +pub mod jit2; pub mod lib_machine; pub mod load_state; pub mod machine_errors; @@ -42,7 +42,7 @@ use crate::machine::compile::*; use crate::machine::copier::*; use crate::machine::heap::*; #[cfg(feature = "jit")] -use crate::machine::jit::*; +use crate::machine::jit2::*; use crate::machine::loader::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 7ff5353af..518aaa73a 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -20,7 +20,7 @@ use crate::machine::code_walker::*; use crate::machine::copier::*; use crate::machine::heap::*; #[cfg(feature = "jit")] -use crate::machine::jit::*; +use crate::machine::jit2::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; use crate::machine::machine_state::*; @@ -5038,7 +5038,7 @@ impl Machine { let mut code = vec![]; walk_code(&self.code, first_idx, |instr| code.push(instr.clone())); - match self.jit_machine.compile(&format!("{}/{}", name.as_str(), arity), code) { + match self.jit_machine.compile(&name.as_str(), arity, code) { Err(JitCompileError::UndefinedPredicate) => { eprintln!("jit_compiler: undefined_predicate"); self.machine_st.fail = true;