diff --git a/src/main.rs b/src/main.rs index 592b992..84e1e8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,9 @@ use walrus::{ir::Instr, FunctionId}; use std::collections::HashMap; use log; -fn get_replacement_module_id(module: &walrus::Module, import_item: &walrus::Import, fn_id: FunctionId) -> Option { +fn get_replacement_module_id(module: &walrus::Module, module_name: &str, import_name: &str, fn_id: FunctionId) -> Option { - let module_name = import_item.module.as_str(); - let import_name = import_item.name.as_str(); - - // for now we only support wasi_unstable and wasi_snapshot_preview1 + // for now we only support wasi_unstable and wasi_snapshot_preview1 modules if module_name != "wasi_unstable" && module_name != "wasi_snapshot_preview1" { return None; } @@ -19,9 +16,7 @@ fn get_replacement_module_id(module: &walrus::Module, import_item: &walrus::Impo if let Some(name) = &fun.name { if *name == searched_function_name { - log::debug!("Function replacement found: {:?} -> {:?}.", - module.funcs.get(fn_id).name, - module.funcs.get(fun.id()).name); + log::debug!("Function replacement found: {:?} -> {:?}.", module.funcs.get(fn_id).name, module.funcs.get(fun.id()).name); assert_eq!(module.funcs.get(fn_id).ty(), module.funcs.get(fun.id()).ty()); @@ -39,9 +34,8 @@ fn get_replacement_module_id(module: &walrus::Module, import_item: &walrus::Impo match export.item { walrus::ExportItem::Function(exported_function) => { - log::debug!("Function replacement found in exports: {:?} -> {:?}.", - module.funcs.get(fn_id).name, - module.funcs.get(exported_function).name); + log::debug!("Function replacement found in exports: {:?} -> {:?}.", module.funcs.get(fn_id).name, module.funcs.get(exported_function).name); + assert_eq!(module.funcs.get(fn_id).ty(), module.funcs.get(exported_function).ty()); return Some(exported_function); @@ -49,7 +43,6 @@ fn get_replacement_module_id(module: &walrus::Module, import_item: &walrus::Impo }, walrus::ExportItem::Table(_) | walrus::ExportItem::Memory(_) | walrus::ExportItem::Global(_) => {}, } - } log::warn!("Could not find the replacement for the WASI function: {}::{}", module_name, import_name); @@ -67,9 +60,9 @@ fn gather_replacement_ids(m: &walrus::Module) -> HashMap match imp.kind { walrus::ImportKind::Function(fn_id) => { - + let replace_id = get_replacement_module_id( - m, imp, fn_id); + m, imp.module.as_str(), imp.name.as_str(), fn_id); if let Some(rep_id) = replace_id { fn_replacement_ids.insert(fn_id, rep_id); @@ -90,6 +83,7 @@ fn gather_replacement_ids(m: &walrus::Module) -> HashMap fn replace_calls(m: &mut walrus::Module, fn_replacement_ids: &HashMap) { + // replace dependent calls for fun in m.funcs.iter_mut() { @@ -102,11 +96,8 @@ fn replace_calls(m: &mut walrus::Module, fn_replacement_ids: &HashMap { - let block_id: walrus::ir::InstrSeqId = local_fun.entry_block(); - replace_calls_in_instructions(block_id, fn_replacement_ids, local_fun); - }, walrus::FunctionKind::Uninitialized(_) => {}, diff --git a/src/tests.rs b/src/tests.rs index 1093428..d5ee7a2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -124,12 +124,13 @@ fn test_gather_replacement_ids() { (type (;3;) (func (param i32 i32) (result i32))) (type (;4;) (func (param i32 i32 i32))) (type (;5;) (func (param i32 i32 i32 i32) (result i32))) + (import "ic0" "debug_print" (func $_dprint (;0;) (type 2))) (import "ic0" "msg_reply" (func $_msg_reply (;1;) (type 0))) - (import "wasi_snapshot_preview1" "fd_write" (func $_wasi_snapshot_preview_fd_write (;2;) (type 5))) - (import "wasi_snapshot_preview1" "random_get" (func $_wasi_snapshot_preview_random_get (;3;) (type 3))) - (import "wasi_snapshot_preview1" "environ_get" (func $__imported_wasi_snapshot_preview1_environ_get (;4;) (type 3))) - (import "wasi_snapshot_preview1" "proc_exit" (func $__imported_wasi_snapshot_preview1_proc_exit (;5;) (type 1))) + (import "wasi_unstable" "fd_write" (func $_wasi_unstable_fd_write (;2;) (type 5))) + (import "wasi_unstable" "random_get" (func $_wasi_unstable_random_get (;3;) (type 3))) + (import "wasi_unstable" "environ_get" (func $__imported_wasi_unstable_environ_get (;4;) (type 3))) + (import "wasi_unstable" "proc_exit" (func $__imported_wasi_unstable_proc_exit (;5;) (type 1))) (func $_start (;6;) (type 0) i32.const 1 @@ -137,31 +138,32 @@ fn test_gather_replacement_ids() { call $__ic_custom_random_get i32.const 1 i32.const 2 - call $_wasi_snapshot_preview_random_get + call $_wasi_unstable_random_get i32.const 4 i32.const 5 - call $__ic_custom_fd_write + call $_wasi_unstable_fd_write drop ) - (func $__ic_custom_random_get (;7;) (type 3) (param i32 i32) (result i32) + (func $__ic_custom_random_get (;8;) (type 3) (param i32 i32) (result i32) call $_msg_reply i32.const 421 ) - (func $__ic_custom_fd_write (;8;) (type 5) (param i32 i32 i32 i32) (result i32) + (func $ic_dummy_fd_write (;7;) (type 5) (param i32 i32 i32 i32) (result i32) i32.const 0 i32.const 0 call $_dprint i32.const 42 ) + (export "__ic_custom_fd_write" (func $ic_dummy_fd_write)) (export "_start" (func $_start)) ) "#; let binary = wat::parse_str(wat).unwrap(); - let module = walrus::Module::from_buffer(&binary).unwrap(); + let mut module = walrus::Module::from_buffer(&binary).unwrap(); let id_reps: HashMap = gather_replacement_ids(&module).iter().map(|(x, y)| (x.index(), y.index())).collect();