diff --git a/src/module/exports.rs b/src/module/exports.rs index 66e4bc10..682f170c 100644 --- a/src/module/exports.rs +++ b/src/module/exports.rs @@ -1,5 +1,7 @@ //! Exported items in a wasm module. +use anyhow::Context; + use crate::emit::{Emit, EmitContext}; use crate::parse::IndicesToIds; use crate::tombstone_arena::{Id, Tombstone, TombstoneArena}; @@ -100,6 +102,16 @@ impl ModuleExports { }) } + /// Retrieve an exported function by name + pub fn get_func_by_name(&self, name: impl AsRef) -> Result { + self.iter() + .find_map(|expt| match expt.item { + ExportItem::Function(fid) if expt.name == name.as_ref() => Some(fid), + _ => None, + }) + .with_context(|| format!("unable to find function export '{}'", name.as_ref())) + } + /// Get a reference to a table export given its table id. pub fn get_exported_table(&self, t: TableId) -> Option<&Export> { self.iter().find(|e| match e.item { @@ -354,6 +366,16 @@ mod tests { assert!(module.exports.get_exported_func(fn_id).is_none()); } + #[test] + fn get_func_by_name() { + let mut module = Module::default(); + let fn_id: FunctionId = always_the_same_id(); + let export_id: ExportId = module.exports.add("dummy", fn_id); + assert!(module.exports.get_func_by_name("dummy").is_ok()); + module.exports.delete(export_id); + assert!(module.exports.get_func_by_name("dummy").is_err()); + } + #[test] fn iter_mut_can_update_export_item() { let mut module = Module::default(); diff --git a/src/module/functions/mod.rs b/src/module/functions/mod.rs index 78dd0703..dca3bb49 100644 --- a/src/module/functions/mod.rs +++ b/src/module/functions/mod.rs @@ -1,5 +1,15 @@ //! Functions within a wasm module. +use std::cmp; +use std::collections::BTreeMap; + +use anyhow::{bail, Context}; +use wasm_encoder::Encode; +use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + mod local_function; use crate::emit::{Emit, EmitContext}; @@ -11,19 +21,19 @@ use crate::parse::IndicesToIds; use crate::tombstone_arena::{Id, Tombstone, TombstoneArena}; use crate::ty::TypeId; use crate::ty::ValType; -use std::cmp; -use std::collections::BTreeMap; -use wasm_encoder::Encode; -use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources}; - -#[cfg(feature = "parallel")] -use rayon::prelude::*; +use crate::{ExportItem, Memory, MemoryId}; pub use self::local_function::LocalFunction; /// A function identifier. pub type FunctionId = Id; +/// Parameter(s) to a function +pub type FuncParams = Vec; + +/// Result(s) of a given function +pub type FuncResults = Vec; + /// A wasm function. /// /// Either defined locally or externally and then imported; see `FunctionKind`. @@ -418,6 +428,119 @@ impl Module { Ok(()) } + + /// Retrieve the ID for the first exported memory. + /// + /// This method does not work in contexts with [multi-memory enabled](https://github.com/WebAssembly/multi-memory), + /// and will error if more than one memory is present. + pub fn get_memory_id(&self) -> Result { + if self.memories.len() > 1 { + bail!("multiple memories unsupported") + } + + self.memories + .iter() + .next() + .map(Memory::id) + .context("module does not export a memory") + } + + /// Replace a single exported function with the result of the provided builder function. + /// + /// For example, if you wanted to replace an exported function with a no-op, + /// + /// ```ignore + /// // Since `FunctionBuilder` requires a mutable pointer to the module's types, + /// // we must build it *outside* the closure and `move` it in + /// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + /// + /// module.replace_exported_func(fid, move || { + /// builder.func_body().unreachable(); + /// builder.local_func(vec![]) + /// }); + /// ``` + /// + /// This function returns the function ID of the *new* function, + /// after it has been inserted into the module as an export. + pub fn replace_exported_func(&mut self, fid: FunctionId, fn_builder: F) -> Result + where + F: FnOnce((&FuncParams, &FuncResults)) -> Result, + { + match (self.exports.get_exported_func(fid), self.funcs.get(fid)) { + ( + Some(exported_fn), + Function { + kind: FunctionKind::Local(lf), + .. + }, + ) => { + // Retrieve the params & result types for the exported (local) function + let ty = self.types.get(lf.ty()); + let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); + + // Add the function produced by `fn_builder` as a local function, + let new_fid = self.funcs.add_local( + fn_builder((¶ms, &results)).context("export fn builder failed")?, + ); + + // Mutate the existing export to use the new local function + let export = self.exports.get_mut(exported_fn.id()); + export.item = ExportItem::Function(new_fid); + + Ok(new_fid) + } + // The export didn't exist, or the function isn't the kind we expect + _ => bail!("cannot replace function [{fid:?}], it is not an exported function"), + } + } + + /// Replace a single imported function with the result of the provided builder function. + /// + /// ```ignore + /// // Since `FunctionBuilder` requires a mutable pointer to the module's types, + /// // we must build it *outside* the closure and `move` it in + /// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + /// + /// module.replace_imported_func(fid, move || { + /// builder.func_body().unreachable(); + /// builder.local_func(vec![]) + /// }); + /// ``` + /// + /// This function returns the function ID of the *new* function, and + /// removes the existing import that has been replaced (the function will become local). + pub fn replace_imported_func(&mut self, fid: FunctionId, fn_builder: F) -> Result + where + F: FnOnce((&FuncParams, &FuncResults)) -> Result, + { + // If the function is in the imports, replace it + match (self.imports.get_imported_func(fid), self.funcs.get(fid)) { + ( + Some(original_imported_fn), + Function { + kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }), + .. + }, + ) => { + // Retrieve the params & result types for the imported function + let ty = self.types.get(*tid); + let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); + + // Mutate the existing function, changing it from a FunctionKind::ImportedFunction + // to the local function produced by running the provided `fn_builder` + let func = self.funcs.get_mut(fid); + func.kind = FunctionKind::Local( + fn_builder((¶ms, &results)).context("import fn builder failed")?, + ); + + self.imports.delete(original_imported_fn.id()); + + Ok(fid) + } + // The export didn't exist, or the function isn't the kind we expect + _ => bail!("cannot replace function [{fid:?}], it is not an imported function"), + } + } } fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> { @@ -535,3 +658,192 @@ impl Emit for ModuleFunctions { cx.code_transform.instruction_map = instruction_map.into_iter().collect(); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Export, FunctionBuilder, Module}; + + #[test] + fn get_memory_id() { + let mut module = Module::default(); + let expected_id = module.memories.add_local(false, 0, None); + assert!(module.get_memory_id().is_ok_and(|id| id == expected_id)); + } + + /// Running `replace_exported_func` with a closure that builds + /// a function should replace the existing function with the new one + #[test] + fn replace_exported_func() { + let mut module = Module::default(); + + // Create original function + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + builder.func_body().i32_const(1234).drop(); + let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); + let original_export_id = module.exports.add("dummy", original_fn_id); + + // Create builder to use inside closure + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + + // Replace the existing function with a new one with a reversed const value + let new_fn_id = module + .replace_exported_func(original_fn_id, move |_| { + builder.func_body().i32_const(4321).drop(); + Ok(builder.local_func(vec![])) + }) + .expect("function replacement worked"); + + assert!( + module.exports.get_exported_func(original_fn_id).is_none(), + "replaced function cannot be gotten by ID" + ); + + // Ensure the function was replaced + match module + .exports + .get_exported_func(new_fn_id) + .expect("failed to unwrap exported func") + { + exp @ Export { + item: ExportItem::Function(fid), + .. + } => { + assert_eq!(*fid, new_fn_id, "retrieved function ID matches"); + assert_eq!(exp.id(), original_export_id, "export ID is unchanged"); + } + _ => panic!("expected an Export with a Function inside"), + } + } + + /// Running `replace_exported_func` with a closure that returns None + /// should replace the function with a generated no-op function + #[test] + fn replace_exported_func_generated_no_op() { + let mut module = Module::default(); + + // Create original function + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + builder.func_body().i32_const(1234).drop(); + let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); + let original_export_id = module.exports.add("dummy", original_fn_id); + + // Create builder to use inside closure + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + + // Replace the existing function with a new one with a reversed const value + let new_fn_id = module + .replace_exported_func(original_fn_id, move |_| { + builder.func_body().unreachable(); + Ok(builder.local_func(vec![])) + }) + .expect("export function replacement worked"); + + assert!( + module.exports.get_exported_func(original_fn_id).is_none(), + "replaced export function cannot be gotten by ID" + ); + + // Ensure the function was replaced + match module + .exports + .get_exported_func(new_fn_id) + .expect("failed to unwrap exported func") + { + exp @ Export { + item: ExportItem::Function(fid), + name, + .. + } => { + assert_eq!(name, "dummy", "function name on export is unchanged"); + assert_eq!(*fid, new_fn_id, "retrieved function ID matches"); + assert_eq!(exp.id(), original_export_id, "export ID is unchanged"); + } + _ => panic!("expected an Export with a Function inside"), + } + } + + /// Running `replace_imported_func` with a closure that builds + /// a function should replace the existing function with the new one + #[test] + fn replace_imported_func() { + let mut module = Module::default(); + + // Create original import function + let types = module.types.add(&[], &[]); + let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types); + + // Create builder to use inside closure + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + + // Replace the existing function with a new one with a reversed const value + let new_fn_id = module + .replace_imported_func(original_fn_id, |_| { + builder.func_body().i32_const(4321).drop(); + Ok(builder.local_func(vec![])) + }) + .expect("import fn replacement worked"); + + assert!( + !module.imports.iter().any(|i| i.id() == original_import_id), + "original import is missing", + ); + + assert!( + module.imports.get_imported_func(original_fn_id).is_none(), + "replaced import function cannot be gotten by ID" + ); + + assert!( + module.imports.get_imported_func(new_fn_id).is_none(), + "new import function cannot be gotten by ID (it is now local)" + ); + + assert!( + matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)), + "new local function has the right kind" + ); + } + + /// Running `replace_imported_func` with a closure that returns None + /// should replace the function with a generated no-op function + #[test] + fn replace_imported_func_generated_no_op() { + let mut module = Module::default(); + + // Create original import function + let types = module.types.add(&[], &[]); + let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types); + + // Create builder to use inside closure + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + + // Replace the existing function with a new one with a reversed const value + let new_fn_id = module + .replace_imported_func(original_fn_id, |_| { + builder.func_body().unreachable(); + Ok(builder.local_func(vec![])) + }) + .expect("import fn replacement worked"); + + assert!( + !module.imports.iter().any(|i| i.id() == original_import_id), + "original import is missing", + ); + + assert!( + module.imports.get_imported_func(original_fn_id).is_none(), + "replaced import function cannot be gotten by ID" + ); + + assert!( + module.imports.get_imported_func(new_fn_id).is_none(), + "new import function cannot be gotten by ID (it is now local)" + ); + + assert!( + matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)), + "new local function has the right kind" + ); + } +} diff --git a/src/module/imports.rs b/src/module/imports.rs index de2f5bb4..7e89821e 100644 --- a/src/module/imports.rs +++ b/src/module/imports.rs @@ -1,11 +1,12 @@ //! A wasm module's imports. +use anyhow::{bail, Context}; + use crate::emit::{Emit, EmitContext}; use crate::parse::IndicesToIds; use crate::tombstone_arena::{Id, Tombstone, TombstoneArena}; use crate::{FunctionId, GlobalId, MemoryId, Result, TableId}; use crate::{Module, TypeId, ValType}; -use anyhow::bail; /// The id of an import. pub type ImportId = Id; @@ -103,6 +104,32 @@ impl ModuleImports { Some(import?.0) } + + /// Retrieve an imported function by name, including the module in which it resides + pub fn get_func_by_name( + &self, + module: impl AsRef, + name: impl AsRef, + ) -> Result { + self.iter() + .find_map(|impt| match impt.kind { + ImportKind::Function(fid) + if impt.module == module.as_ref() && impt.name == name.as_ref() => + { + Some(fid) + } + _ => None, + }) + .with_context(|| format!("unable to find function export '{}'", name.as_ref())) + } + + /// Retrieve an imported function by ID + pub fn get_imported_func(&self, fid: FunctionId) -> Option<&Import> { + self.arena.iter().find_map(|(_, import)| match import.kind { + ImportKind::Function(id) if fid == id => Some(import), + _ => None, + }) + } } impl Module { @@ -313,3 +340,36 @@ impl From for ImportKind { ImportKind::Table(id) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FunctionBuilder, Module}; + + #[test] + fn get_imported_func() { + let mut module = Module::default(); + + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + builder.func_body().i32_const(1234).drop(); + let new_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); + module.imports.add("mod", "dummy", new_fn_id); + + assert!(module.imports.get_imported_func(new_fn_id).is_some()); + } + + #[test] + fn get_func_by_name() { + let mut module = Module::default(); + + let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + builder.func_body().i32_const(1234).drop(); + let new_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); + module.imports.add("mod", "dummy", new_fn_id); + + assert!(module + .imports + .get_func_by_name("mod", "dummy") + .is_ok_and(|fid| fid == new_fn_id)); + } +} diff --git a/src/module/memories.rs b/src/module/memories.rs index 55342528..48acc339 100644 --- a/src/module/memories.rs +++ b/src/module/memories.rs @@ -115,6 +115,11 @@ impl ModuleMemories { pub fn iter_mut(&mut self) -> impl Iterator { self.arena.iter_mut().map(|(_, f)| f) } + + /// Get the number of memories in this module + pub fn len(&self) -> usize { + self.arena.len() + } } impl Module { @@ -165,3 +170,20 @@ impl Emit for ModuleMemories { cx.wasm_module.section(&wasm_memory_section); } } + +#[cfg(test)] +mod tests { + use crate::Module; + + #[test] + fn memories_len() { + let mut module = Module::default(); + assert_eq!(module.memories.len(), 0); + + module.memories.add_local(false, 0, Some(1024)); + assert_eq!(module.memories.len(), 1); + + module.memories.add_local(true, 1024, Some(2048)); + assert_eq!(module.memories.len(), 2); + } +}