Skip to content

Commit

Permalink
wasi-nn: refactor to allow preview2 access
Browse files Browse the repository at this point in the history
This change refactors the `wasmtime-wasi-nn` crate to allow access from
both `preview1` and `preview2` ABIs. Though the `wasi-nn` specification
has included a WIT description for some time, here we use some in-tree
files until WebAssembly/wasi-nn#38 lands.
The `preview2` code is not exercised anywhere yet: ideally this would be
wired up once component model `resource`s are fully implemented in
Wasmtime.
  • Loading branch information
abrown committed Aug 8, 2023
1 parent de4ede0 commit 13dd460
Show file tree
Hide file tree
Showing 14 changed files with 624 additions and 186 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@ readme = "README.md"
edition.workspace = true

[dependencies]
# These dependencies are necessary for the witx-generation macros to work:
# These dependencies are necessary for the WIT-generation macros to work:
anyhow = { workspace = true }
wiggle = { workspace = true }

# This dependency is necessary for the WIT-generation macros to work:
wasmtime = { workspace = true, optional = true, features = ["component-model"] }

# These dependencies are necessary for the wasi-nn implementation:
openvino = { version = "0.5.0", features = ["runtime-linking"] }
thiserror = { workspace = true }

[build-dependencies]
walkdir = { workspace = true }

[features]
preview2 = ["wasmtime"]
31 changes: 26 additions & 5 deletions crates/wasi-nn/src/api.rs → crates/wasi-nn/src/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
//! Define the Rust interface a backend must implement in order to be used by
//! this crate. the `Box<dyn ...>` types returned by these interfaces allow
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
//! implementations to maintain backend-specific state between calls.

use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor};
mod openvino;

use self::openvino::OpenvinoBackend;
use crate::types::{ExecutionTarget, Tensor};
use thiserror::Error;
use wiggle::GuestError;

/// Identifies which machine-learning backend a graph should be loaded with.
///
/// Derives `Hash` and `Eq` so that it can serve as a map key, and `Debug` so
/// that it can appear in diagnostics.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub(crate) enum BackendKind {
    /// The OpenVINO backend; corresponds to the raw discriminant `0`.
    OpenVINO,
}

impl From<u8> for BackendKind {
    /// Converts a raw backend discriminant into a [`BackendKind`].
    ///
    /// # Panics
    ///
    /// Panics if `value` does not name a known backend (currently only `0`
    /// is recognized). The panic message includes the rejected value.
    fn from(value: u8) -> Self {
        match value {
            0 => BackendKind::OpenVINO,
            // Report the offending discriminant so the failure is diagnosable.
            other => panic!("invalid backend: {}", other),
        }
    }
}

/// Enumerate every backend this crate can construct, each paired with the
/// [BackendKind] that identifies it. Only the OpenVINO backend is listed.
pub(crate) fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
    let openvino: Box<dyn Backend> = Box::new(OpenvinoBackend::default());
    vec![(BackendKind::OpenVINO, openvino)]
}

/// A [Backend] contains the necessary state to load [BackendGraph]s.
pub(crate) trait Backend: Send + Sync {
/// Return the name of this backend.
fn name(&self) -> &str;
/// Build a [BackendGraph] from the `builders` buffers for the requested
/// execution `target`, reporting a [BackendError] on failure.
fn load(
&mut self,
builders: &GraphBuilderArray<'_>,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError>;
}
Expand All @@ -25,7 +46,7 @@ pub(crate) trait BackendGraph: Send + Sync {
/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
pub(crate) trait BackendExecutionContext: Send + Sync {
/// Supply `tensor` as the input at position `index`.
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
fn set_input<'a>(&mut self, index: u32, tensor: &Tensor<'a>) -> Result<(), BackendError>;
/// Run the computation using the inputs that have been set.
fn compute(&mut self) -> Result<(), BackendError>;
/// Write the output at position `index` into `destination`.
/// NOTE(review): the returned `u32` is presumably the number of bytes
/// written -- confirm against the implementations.
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
}
Expand All @@ -39,7 +60,7 @@ pub enum BackendError {
#[error("Failed while accessing guest module")]
GuestAccess(#[from] GuestError),
#[error("The backend expects {0} buffers, passed {1}")]
InvalidNumberOfBuilders(u32, u32),
InvalidNumberOfBuilders(usize, usize),
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(usize),
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
//! Implements the wasi-nn API.
//! Implements a `wasi-nn` [`Backend`] using OpenVINO.

use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType};
use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::types::{ExecutionTarget, Tensor, TensorType};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::sync::Arc;

/// The OpenVINO backend state: an optional `openvino::Core`, which the derived
/// `Default` leaves as `None`.
#[derive(Default)]
pub(crate) struct OpenvinoBackend(Option<openvino::Core>);

// NOTE(review): these impls assert thread-safety by hand, but nothing visible
// here shows why `openvino::Core` may be sent to or shared between threads.
// A `// SAFETY:` justification should be recorded once that is confirmed.
unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}

Expand All @@ -18,7 +17,7 @@ impl Backend for OpenvinoBackend {

fn load(
&mut self,
builders: &GraphBuilderArray<'_>,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError> {
if builders.len() != 2 {
Expand All @@ -34,16 +33,8 @@ impl Backend for OpenvinoBackend {
}

// Read the guest array.
let builders = builders.as_ptr();
let xml = builders
.read()?
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let weights = builders
.add(1)?
.read()?
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let xml = &builders[0];
let weights = &builders[1];

// Construct OpenVINO graph structures: `cnn_network` contains the graph
// structure, `exec_network` can perform inference.
Expand All @@ -53,8 +44,9 @@ impl Backend for OpenvinoBackend {
.expect("openvino::Core was previously constructed");
let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?;

// TODO this is a temporary workaround. We need a more eligant way to specify the layout in the long run.
// However, without this newer versions of OpenVINO will fail due to parameter mismatch.
// TODO: this is a temporary workaround. We need a more elegant way to
// specify the layout in the long run. However, without this newer
// versions of OpenVINO will fail due to parameter mismatch.
for i in 0..cnn_network.get_inputs_len()? {
let name = cnn_network.get_input_name(i)?;
cnn_network.set_input_layout(&name, Layout::NHWC)?;
Expand Down Expand Up @@ -85,27 +77,14 @@ impl BackendGraph for OpenvinoGraph {
/// Execution state for the OpenVINO backend: field 0 is the shared
/// `openvino::CNNNetwork` used to look up input names, field 1 is the
/// `openvino::InferRequest` that input blobs are assigned to.
struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest);

impl BackendExecutionContext for OpenvinoExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> {
fn set_input<'a>(&mut self, index: u32, tensor: &Tensor<'a>) -> Result<(), BackendError> {
let input_name = self.0.get_input_name(index as usize)?;

// Construct the blob structure.
let dimensions = tensor
.dimensions
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)")
.iter()
.map(|d| *d as usize)
.collect::<Vec<_>>();
let precision = map_tensor_type_to_precision(tensor.type_);

// TODO There must be some good way to discover the layout here; this
// should not have to default to NHWC.
let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision);
let data = tensor
.data
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let blob = openvino::Blob::new(&desc, &data)?;
// Construct the blob structure. TODO: there must be some good way to
// discover the layout here; `desc` should not have to default to NHWC.
let precision = map_tensor_type_to_precision(tensor.ty);
let desc = TensorDesc::new(Layout::NHWC, tensor.dims, precision);
let blob = openvino::Blob::new(&desc, tensor.data)?;

// Actually assign the blob to the request.
self.1.set_blob(&input_name, &blob)?;
Expand Down Expand Up @@ -147,9 +126,9 @@ impl From<SetupError> for BackendError {
/// `ExecutionTarget` enum provided by wasi-nn.
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
// Only the CPU and GPU targets have a string form here; a TPU target reaches
// `unimplemented!` and therefore panics instead of returning an error.
match target {
ExecutionTarget::Cpu => "CPU",
ExecutionTarget::Gpu => "GPU",
ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"),
ExecutionTarget::CPU => "CPU",
ExecutionTarget::GPU => "GPU",
ExecutionTarget::TPU => unimplemented!("OpenVINO does not support TPU execution targets"),
}
}

Expand Down
70 changes: 57 additions & 13 deletions crates/wasi-nn/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,59 @@
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the
//! implementation of the wasi-nn API.
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::openvino::OpenvinoBackend;
use crate::r#impl::UsageError;
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
//!

use crate::backend::{
self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind,
};
use crate::types::GraphEncoding;
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
use wiggle::GuestError;

// #[derive(Eq, Hash, PartialEq, Clone, Copy)]
// pub(crate) struct GraphId(u32);
// impl From<u32> for GraphId {
// fn from(value: u32) -> Self {
// Self(value)
// }
// }
// impl Into<u32> for GraphId {
// fn into(self) -> u32 {
// self.0
// }
// }

// #[derive(Eq, Hash, PartialEq, Clone, Copy)]
// pub(crate) struct GraphExecutionContextId(u32);
// impl From<u32> for GraphExecutionContextId {
// fn from(value: u32) -> Self {
// Self(value)
// }
// }
// impl Into<u32> for GraphExecutionContextId {
// fn into(self) -> u32 {
// self.0
// }
// }

/// Handle used as the key for loaded graphs; currently a plain `u32`.
type GraphId = u32;
/// Handle used as the key for execution contexts; currently a plain `u32`.
type GraphExecutionContextId = u32;

/// Capture the state necessary for calling into the backend ML libraries.
///
/// This holds the available backends together with the tables of loaded
/// graphs and of execution contexts.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
pub(crate) backends: HashMap<BackendKind, Box<dyn Backend>>,
pub(crate) graphs: Table<GraphId, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContextId, Box<dyn BackendExecutionContext>>,
}

impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new() -> WasiNnResult<Self> {
let mut backends = HashMap::new();
backends.insert(
// This is necessary because Wiggle's variant types do not derive
// `Hash` and `Eq`.
GraphEncoding::Openvino.into(),
Box::new(OpenvinoBackend::default()) as Box<dyn Backend>,
);
for (kind, backend) in backend::list() {
backends.insert(kind, backend);
}
Ok(Self {
backends,
graphs: Table::default(),
Expand All @@ -45,6 +73,22 @@ pub enum WasiNnError {
UsageError(#[from] UsageError),
}

/// Errors describing incorrect use of the API, such as invalid handles or an
/// unsupported graph encoding. Display text comes from the `#[error]`
/// attribute on each variant.
#[derive(Debug, Error)]
pub enum UsageError {
/// The context is not valid.
#[error("Invalid context; has the load function been called?")]
InvalidContext,
/// The given graph encoding is not supported; carries the encoding passed.
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
/// The wrong number of builder buffers was passed; carries that number.
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
InvalidNumberOfBuilders(u32),
/// The graph handle is not valid.
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
/// The execution context handle is not valid.
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
/// Tensor data could not be copied for lack of memory; carries the size.
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
}

pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;

/// Record handle entries in a table.
Expand Down
93 changes: 0 additions & 93 deletions crates/wasi-nn/src/impl.rs

This file was deleted.

10 changes: 5 additions & 5 deletions crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod api;
mod backend;
mod ctx;
mod r#impl;
mod openvino;
mod witx;

pub use ctx::WasiNnCtx;
pub use witx::wasi_ephemeral_nn::add_to_linker;
pub mod preview1;
#[cfg(feature = "preview2")]
pub mod preview2;
pub mod types;
Loading

0 comments on commit 13dd460

Please sign in to comment.