diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 5a3e2466..d3a1a0dc 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -73,6 +73,7 @@ jobs: run: | cargo tarpaulin -p ort --features fetch-models --verbose --timeout 120 --out xml - name: Upload to codecov.io - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: + token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false diff --git a/.gitignore b/.gitignore index 6d829e87..a76dbd4f 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,6 @@ WixTools/ # IDEA .idea + +# Glassbench results +/glassbench*.db diff --git a/Cargo.toml b/Cargo.toml index 0711d72c..f0926671 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,19 +3,21 @@ members = [ 'ort-sys', 'examples/gpt2', 'examples/model-info', - 'examples/yolov8' + 'examples/yolov8', + 'examples/modnet' ] default-members = [ '.', 'examples/gpt2', 'examples/model-info', - 'examples/yolov8' + 'examples/yolov8', + 'examples/modnet' ] [package] name = "voicevox-ort" -description = "A safe Rust wrapper for ONNX Runtime 1.16 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-alpha.4" +description = "A safe Rust wrapper for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" +version = "2.0.0-rc.0" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" @@ -28,7 +30,7 @@ authors = [ "pyke.io ", "Nicolas Bigaouette " ] -include = [ "src/", "examples/", "tests/", "LICENSE-APACHE", "LICENSE-MIT", "README.md" ] +include = [ "src/", "LICENSE-APACHE", "LICENSE-MIT", "README.md" ] [profile.release] opt-level = 3 @@ -77,6 +79,7 @@ ndarray = { version = "0.15", optional = true } thiserror = "1.0" voicevox-ort-sys = { version = "2.0.0-alpha.4", path = "ort-sys" } libloading = { version = "0.8", optional = true } +compact_str = "0.7" ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } tracing = "0.1" @@ -95,3 +98,8 @@ ureq = "2.1" image = "0.24" test-log = { version = "0.2", default-features = false, features = [ "trace" ] } tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } +glassbench = "0.4" + +[[bench]] +name = "squeezenet" +harness = false diff --git a/README.md b/README.md index b512344a..83294348 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@
Coverage Results Crates.io Open Collective backers and sponsors
- Crates.io ONNX Runtime + Crates.io ONNX Runtime -`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.16 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. +`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.17 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. ## ๐Ÿ“– Documentation - [Guide](https://ort.pyke.io/) diff --git a/benches/squeezenet.rs b/benches/squeezenet.rs new file mode 100644 index 00000000..5da0239e --- /dev/null +++ b/benches/squeezenet.rs @@ -0,0 +1,60 @@ +use std::{path::Path, sync::Arc}; + +use glassbench::{pretend_used, Bench}; +use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb}; +use ndarray::{s, Array4}; +use ort::{GraphOptimizationLevel, Session}; + +fn load_squeezenet_data() -> ort::Result<(Session, Array4)> { + const IMAGE_TO_LOAD: &str = "mushroom.png"; + + ort::init().with_name("integration_test").commit()?; + + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level1)? + .with_intra_threads(1)? + .with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx") + .expect("Could not download model from file"); + + let input0_shape: &Vec = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"); + + let image_buffer: ImageBuffer, Vec> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD)) + .unwrap() + .resize(input0_shape[2] as u32, input0_shape[3] as u32, FilterType::Nearest) + .to_rgb8(); + + let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + (channels[c] as f32) / 255.0 + }); + + let mean = [0.485, 0.456, 0.406]; + let std = [0.229, 0.224, 0.225]; + for c in 0..3 { + let mut channel_array = array.slice_mut(s![0, c, .., ..]); + channel_array -= mean[c]; + channel_array /= std[c]; + } + + Ok((session, array)) +} + +fn bench_squeezenet(bench: &mut Bench) { + let (session, data) = load_squeezenet_data().unwrap(); + bench.task("ArrayView", |task| { + task.iter(|| { + pretend_used(session.run(ort::inputs![data.view()].unwrap()).unwrap()); + }) + }); + + let raw = Arc::new(data.as_standard_layout().as_slice().unwrap().to_owned().into_boxed_slice()); + let shape: Vec = data.shape().iter().map(|c| *c as _).collect(); + bench.task("Raw data", |task| { + task.iter(|| { + pretend_used(session.run(ort::inputs![(shape.clone(), Arc::clone(&raw))].unwrap()).unwrap()); + }) + }); +} + +glassbench::glassbench!("SqueezeNet", bench_squeezenet,); diff --git a/codecov.yml b/codecov.yml index d39c511f..49416a8b 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,2 @@ ignore: - - src/download/**/*.rs - - src/download.rs + - "src/execution_providers" diff --git a/docs/introduction.mdx b/docs/introduction.mdx index c94c808d..75d63cd0 100644 --- a/docs/introduction.mdx +++ b/docs/introduction.mdx @@ -3,49 +3,58 @@ title: Introduction ---

- `ort` is an open-source Rust binding for [ONNX Runtime](https://github.com/microsoft/onnxruntime). + `ort` is an open-source Rust binding for [ONNX Runtime](https://onnxruntime.ai/).

-`ort` makes it easy to deploy your machine learning models to production via ONNX Runtime, a [hardware-accelerated](/perf/execution-providers) inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA), often much faster than PyTorch, and with the added bonus of Rust's efficiency. +`ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. + + + These docs are for the latest alpha version of `ort`, `2.0.0-rc.0`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. + # Why `ort`? There are a few other ONNX Runtime crates out there, so why use `ort`? For one, `ort` simply supports more features: -| Feature comparison | **๐Ÿ“• ort** | **๐Ÿ“— [ors](https://github.com/HaoboGu/ors)** | **๐ŸชŸ [onnxruntime-rs](https://github.com/microsoft/onnxruntime/tree/main/rust)** | -|------------------------|-----------|-----------|----------------------| -| Upstream version | **v1.16.3** | v1.12.0 | v1.8 | -| `dlopen()`? | โœ… | โœ… | โŒ | -| Execution providers? | โœ… | โŒ | โŒ | -| IOBinding? | โœ… | โŒ | โŒ | -| String tensors? | โœ… | โŒ | โš ๏ธ input only | -| Multiple output types? | โœ… | โœ… | โŒ | -| Multiple input types? | โœ… | โœ… | โŒ | -| In-memory session? | โœ… | โœ… | โœ… | -| WebAssembly? | โœ… | โŒ | โŒ | +| Feature comparison | **๐Ÿ“• ort** | **๐Ÿ“— [ors](https://github.com/HaoboGu/ors)** | **๐ŸชŸ [onnxruntime-rs](https://github.com/microsoft/onnxruntime/tree/main/rust)** | +|---------------------------|-----------|-----------|----------------------| +| Upstream version | **v1.17.0** | v1.12.0 | v1.8 | +| `dlopen()`? | โœ… | โœ… | โŒ | +| Execution providers? | โœ… | โŒ | โŒ | +| I/O Binding? | โœ… | โŒ | โŒ | +| String tensors? | โœ… | โŒ | โš ๏ธ input only | +| Multiple output types? | โœ… | โœ… | โŒ | +| Multiple input types? | โœ… | โœ… | โŒ | +| In-memory session? | โœ… | โœ… | โœ… | +| WebAssembly? | โœ… | โŒ | โŒ | +| Provides static binaries? | โœ… | โŒ | โŒ | +| Sequence & map types? | โœ… | โŒ | โŒ | Users of `ort` appreciate its ease of use and ergonomic API. `ort` is also battle tested in some pretty serious production scenarios. - [**Twitter**](https://twitter.com/) uses `ort` in part of their recommendations system, serving hundreds of millions of requests a day. - [**Bloop**](https://bloop.ai/)'s semantic code search feature is powered by `ort`. +- [**SurrealDB**](https://surrealdb.com/) uses `ort` in their [`surrealml`](https://github.com/surrealdb/surrealml) package. - [**Numerical Elixir**](https://github.com/elixir-nx) uses `ort` to create ONNX Runtime bindings for the Elixir language. -- [**`rust-bert`**](https://github.com/guillaume-be/rust-bert) implements many ready-to-use NLP pipelines in Rust a la Hugging Face Transformers, with an optional `ort` backend. +- [**`rust-bert`**](https://github.com/guillaume-be/rust-bert) implements many ready-to-use NLP pipelines in Rust ร  la Hugging Face Transformers with both [`tch`](https://crates.io/crates/tch) & `ort` backends. - [**`edge-transformers`**](https://github.com/npc-engine/edge-transformers) also implements Hugging Face Transformers pipelines in Rust using `ort`. -- We use `ort` in nearly all of our ML projects, including [VITRI](https://vitri.pyke.io/) ๐Ÿ˜Š # Getting started - If you have a [supported platform](/setup/platforms), installing `ort` couldn't be any simpler! + If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: ```toml [dependencies] - ort = "2.0" + ort = "2.0.0-alpha.4" ``` - Your model will need to be converted to the [ONNX](https://onnx.ai/) format before you can use it; here's how to do that: - - The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export Transformers models to ONNX with ๐Ÿค— Optimum. - - For other PyTorch models, see the [`torch.onnx` module docs](https://pytorch.org/docs/stable/onnx.html). + Your model will need to be converted to the [ONNX](https://onnx.ai/) format before you can use it. + - The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export ๐Ÿค— Transformers models to ONNX with ๐Ÿค— Optimum. + - For other PyTorch models: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) + - For `scikit-learn`: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) + - For TensorFlow, Keras, TFlite, TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) + - For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): @@ -64,6 +73,8 @@ Users of `ort` appreciate its ease of use and ergonomic API. `ort` is also battl ```rust let outputs = model.run(ort::inputs!["image" => image]?)?; + + // Postprocessing let output = outputs["output0"] .extract_tensor::() .unwrap() @@ -74,7 +85,7 @@ Users of `ort` appreciate its ease of use and ergonomic API. `ort` is also battl ... ``` - There are some [more useful examples](https://github.com/pykeio/ort/tree/main/examples) in our repo! + There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! diff --git a/docs/migrating/v2.mdx b/docs/migrating/v2.mdx index 2c22484f..03bdada4 100644 --- a/docs/migrating/v2.mdx +++ b/docs/migrating/v2.mdx @@ -14,7 +14,7 @@ ort::init() .commit()?; ``` -`commit()` must be called before any sessions are created to take effect. Otherwise, the environment will be set to default and cannot be modified afterwards. +`commit()` must be called before any sessions are created to take effect. Otherwise, a default environment will be created. The global environment can be updated afterward by calling `commit()` on another `EnvironmentBuilder`, however you'll need to recreate sessions after comitting the new environment in order for them to use it. ## Session creation `SessionBuilder::new(&environment)` has been soft-replaced with `Session::builder()`: @@ -144,6 +144,14 @@ The `ort::sys` module has been split out into [its own `ort-sys` crate](https:// ### `ndarray` is now optional The dependency on `ndarray` is now declared optional. If you use `ort` with `default-features = false`, you'll need to add the `ndarray` feature. +## Model Zoo structs have been removed +ONNX pushed a new Model Zoo structure that adds hundreds of different models. This is impractical to maintain, so the built-in structs have been removed. + +You can still use `Session::with_model_downloaded`, it just now takes a URL string instead of a struct. + +## Changes to logging +Environment-level logging configuration (i.e. `EnvironmentBuilder::with_log_level`) has been removed because it could cause unnecessary confusion with our `tracing` integration. + ## The Flattening All modules except `download` are no longer public. Exports have been flattened to the crate root, so i.e. `ort::session::Session` becomes `ort::Session`. @@ -153,3 +161,4 @@ The following types have been renamed with no other changes. - `OrtOwnedTensor` -> `Tensor` - `OrtResult`, `OrtError` -> `ort::Result`, `ort::Error` - `TensorDataToType` -> `ExtractTensorData` +- `TensorElementDataType`, `IntoTensorElementDataType` -> `TensorElementType`, `IntoTensorElementType` diff --git a/docs/migrating/version-mapping.mdx b/docs/migrating/version-mapping.mdx index ad5abae6..a48a6306 100644 --- a/docs/migrating/version-mapping.mdx +++ b/docs/migrating/version-mapping.mdx @@ -6,7 +6,7 @@ description: Information about `ort`'s versioning and relation to ONNX Runtime v ## A note on SemVer `ort` versions pre-2.0 were not SemVer compatible. From v2.0 onwards, breaking API changes are accompanied by a **major version update**. -Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.16.2, but 2.1 may ship with 1.17.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): +Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.17.0, but 2.1 may ship with 1.18.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): ```toml [dependencies] ort = { version = "~2.0", ... } @@ -16,12 +16,10 @@ ort = { version = "~2.0", ... } | **ort** | **ONNX Runtime** | | -------- | ----------------:| -| v2.0.0+ | v1.16.2 | +| v2.0.0+ | v1.17.0 | | v1.16.0-v1.16.2 | v1.16.0 | | v1.15.0-v1.15.5 | v1.15.1 | | v1.14.2-v1.14.8 | v1.14.1 | | v1.14.0-v1.14.1 | v1.14.0 | | v1.13.1-v1.13.3 | v1.13.1 | | v1.13.0 | v1.12.1 | - -If you need support for an old (<1.15) version of `ort`, or need an even older version of ONNX Runtime, [contact us](mailto:contact@pyke.io). diff --git a/docs/setup/linking.mdx b/docs/setup/linking.mdx index a39da4e0..15a95267 100644 --- a/docs/setup/linking.mdx +++ b/docs/setup/linking.mdx @@ -32,7 +32,7 @@ To use `load-dynamic`: ort = { version = "2", features = [ "load-dynamic" ] } ``` - + ```shell @@ -41,22 +41,15 @@ To use `load-dynamic`: ```rust - use std::env; - fn main() -> anyhow::Result<()> { - // Find our downloaded ONNX Runtime dylibs and set the environment variable for ort - env::set_var("ORT_DYLIB_PATH", crate::internal::find_onnxruntime_dylib()?); + // Find our custom ONNX Runtime dylibs and initialize `ort` with it. + let dylib_path = crate::internal::find_onnxruntime_dylib()?; // /etc/.../libonnxruntime.so - // IMPORTANT: You must set the environment variable **before** you use `ort`!!! - ort::init().commit()?; + ort::init_from(dylib_path).commit()?; Ok(()) } ``` - - - You MUST set the environment variable **before** you use any `ort` APIs! - diff --git a/docs/setup/platforms.mdx b/docs/setup/platforms.mdx index 6384b8ce..00ece644 100644 --- a/docs/setup/platforms.mdx +++ b/docs/setup/platforms.mdx @@ -12,13 +12,18 @@ Here are the supported platforms and binary availability status, as of v2.0. | Platform | x86 | x86-64 | ARMv7 | ARM64 | WASM32 | |:-------- |:------- |:------ |:------ |:------ |:------ | -| **Windows** | โญ• | ๐ŸŸข | โญ• | ๐Ÿ”ท | โŒ | -| **Linux** | โญ• | ๐ŸŸข | โญ• | ๐Ÿ”ท | โŒ | -| **macOS** | โŒ | ๐Ÿ”ท | โŒ | ๐Ÿ”ท | โŒ | +| **Windows** | โญ• | ๐ŸŸข\* | โญ• | ๐Ÿ”ท\* | โŒ | +| **Linux** | โญ• | ๐ŸŸขโ€  | โญ• | ๐Ÿ”ทโ€ก | โŒ | +| **macOS** | โŒ | ๐Ÿ”ทยง | โŒ | ๐Ÿ”ทยง | โŒ | | **iOS** | โŒ | โŒ | โŒ | โญ• | โŒ | | **Android** | โŒ | โŒ | โญ• | โญ• | โŒ | | **Web** | โŒ | โŒ | โŒ | โŒ | ๐Ÿ”ท | +\* Recent version of Windows 10/11 required for pyke binaries.
+โ€  glibc โ‰ฅ 2.31 (Ubuntu โ‰ฅ 20.04) required for pyke binaries.
+โ€ก glibc โ‰ฅ 2.35 (Ubuntu โ‰ฅ 22.04) required for pyke binaries.
+ยง macOS โ‰ฅ 10.15 required for pyke binaries. + If your platform is marked as ๐ŸŸข or ๐Ÿ”ท, you're in luck! Almost no setup will be required to get `ort` up and running. For platforms marked as โญ•, you'll need to [compile ONNX Runtime from source](https://onnxruntime.ai/docs/build/). Certain execution providers may not have binaries available. You can check EP binary support in the [execution providers](/perf/execution-providers) documentation. diff --git a/docs/troubleshooting/performance.mdx b/docs/troubleshooting/performance.mdx index 71be35ff..9351076e 100644 --- a/docs/troubleshooting/performance.mdx +++ b/docs/troubleshooting/performance.mdx @@ -52,7 +52,7 @@ title: 'Troubleshoot: Performance' ## Inference is slow, even with an EP! There are a few things you could try to improve performance: -- **Run `onnxsim` on the model.** Direct exports from PyTorch can leave a lot of junk nodes in the graph, which could hinder performance. [`onnxsim`](https://github.com/daquexian/onnx-simplifier) is a neat tool that can be used to simplify the ONNX graph and remove junk. +- **Run `onnxsim` on the model.** Direct exports from PyTorch can leave a lot of junk nodes in the graph, which could hinder performance. [`onnxsim`](https://github.com/daquexian/onnx-simplifier) is a neat tool that can be used to simplify the ONNX graph and potentially improve performance. - **Export with an older opset.** Some EPs might not support newer, more complex nodes. Try targeting an older ONNX opset when exporting your model to force it to export with simpler operations. - **Use the [transformer optimization tool](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers).** This is another neat tool that converts certain transformer-based models to far more optimized graphs. - **Try other EPs.** There may be multiple EPs for your hardware that have a more performant implementation. @@ -60,5 +60,5 @@ There are a few things you could try to improve performance: - For AMD, you can try ROCm, MIGraphX, or DirectML. - For ARM, you can try ArmNN, ACL, or XNNPACK. - See [Execution providers](/perf/execution-providers) for more information on supported EPs. -- **Use [`I/O binding`](/perf/io-binding).** This can reduce the latency between copying the session inputs/outputs to/from devices. +- **Use [`I/O binding`](/perf/io-binding).** This can reduce latency caused by copying the session inputs/outputs to/from devices. - **[Quantize your model.](https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html)** You can try quantizing your model to 8-bit precision. This comes with a small accuracy loss, but can sometimes provide a large performance boost. If the accuracy loss is too high, you can also use [float16/mixed precision](https://onnxruntime.ai/docs/performance/model-optimizations/float16.html). diff --git a/examples/modnet/Cargo.toml b/examples/modnet/Cargo.toml new file mode 100644 index 00000000..8ee083be --- /dev/null +++ b/examples/modnet/Cargo.toml @@ -0,0 +1,17 @@ +[package] +publish = false +name = "example-modnet" +version = "0.0.0" +edition = "2021" + +[dependencies] +ort = { path = "../../" } +ndarray = "0.15" +tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } +image = "0.24" +tracing = "0.1" +show-image = { version = "0.13", features = [ "image", "raqote" ] } + +[features] +load-dynamic = [ "ort/load-dynamic" ] +cuda = [ "ort/cuda" ] diff --git a/examples/modnet/data/photo.jpg b/examples/modnet/data/photo.jpg new file mode 100644 index 00000000..0ccd6869 Binary files /dev/null and b/examples/modnet/data/photo.jpg differ diff --git a/examples/modnet/examples/modnet.rs b/examples/modnet/examples/modnet.rs new file mode 100644 index 00000000..42e1a3d8 --- /dev/null +++ b/examples/modnet/examples/modnet.rs @@ -0,0 +1,83 @@ +#![allow(clippy::manual_retain)] + +use std::{ops::Mul, path::Path}; + +use image::{imageops::FilterType, GenericImageView, ImageBuffer, Rgba}; +use ndarray::Array; +use ort::{inputs, CUDAExecutionProvider, Session}; +use show_image::{event, AsImageView, WindowOptions}; + +#[show_image::main] +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init() + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .commit()?; + + let model = + Session::builder()?.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?; + + let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap(); + let (img_width, img_height) = (original_img.width(), original_img.height()); + let img = original_img.resize_exact(512, 512, FilterType::Triangle); + let mut input = Array::zeros((1, 3, 512, 512)); + for pixel in img.pixels() { + let x = pixel.0 as _; + let y = pixel.1 as _; + let [r, g, b, _] = pixel.2.0; + input[[0, 0, y, x]] = (r as f32 - 127.5) / 127.5; + input[[0, 1, y, x]] = (g as f32 - 127.5) / 127.5; + input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5; + } + + let outputs = model.run(inputs!["input" => input.view()]?)?; + + let binding = outputs["output"].extract_tensor::().unwrap(); + let output = binding.view(); + + // convert to 8-bit + let output = output.mul(255.0).map(|x| *x as u8); + let output = output.into_raw_vec(); + + // change rgb to rgba + let output_img = ImageBuffer::from_fn(512, 512, |x, y| { + let i = (x + y * 512) as usize; + Rgba([output[i], output[i], output[i], 255]) + }); + + let mut output = image::imageops::resize(&output_img, img_width, img_height, FilterType::Triangle); + output.enumerate_pixels_mut().for_each(|(x, y, pixel)| { + let origin = original_img.get_pixel(x, y); + pixel.0[3] = pixel.0[0]; + pixel.0[0] = origin.0[0]; + pixel.0[1] = origin.0[1]; + pixel.0[2] = origin.0[2]; + }); + + let window = show_image::context() + .run_function_wait(move |context| -> Result<_, String> { + let mut window = context + .create_window( + "ort + modnet", + WindowOptions { + size: Some([img_width, img_height]), + ..WindowOptions::default() + } + ) + .map_err(|e| e.to_string())?; + window.set_image("photo", &output.as_image_view().map_err(|e| e.to_string())?); + Ok(window.proxy()) + }) + .unwrap(); + + for event in window.event_channel().unwrap() { + if let event::WindowEvent::KeyboardInput(event) = event { + if event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() { + break; + } + } + } + + Ok(()) +} diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 5fd5758e..53f0ab4d 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "voicevox-ort-sys" -description = "Unsafe Rust bindings for ONNX Runtime 1.16 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-alpha.4" +description = "Unsafe Rust bindings for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" +version = "2.0.0-rc.0" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 90b556ea..8ad84d32 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -65,6 +65,8 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) { for out_dir in [out_dir.to_path_buf(), out_dir.join("examples"), out_dir.join("deps")] { #[cfg(windows)] let mut copy_fallback = false; + #[cfg(not(windows))] + let copy_fallback = false; let lib_files = std::fs::read_dir(lib_dir).unwrap_or_else(|_| panic!("Failed to read contents of `{}` (does it exist?)", lib_dir.display())); for lib_file in lib_files.filter(|e| { @@ -88,11 +90,12 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) { #[cfg(unix)] std::os::unix::fs::symlink(&lib_path, &out_path).unwrap(); } - println!("cargo:rerun-if-changed={}", out_path.to_str().unwrap()); + if !copy_fallback { + println!("cargo:rerun-if-changed={}", out_path.to_str().unwrap()); + } } // If we had to fallback to copying files on Windows, break early to avoid copying to 3 different directories - #[cfg(windows)] if copy_fallback { break; } @@ -119,6 +122,7 @@ fn static_link_prerequisites(using_pyke_libs: bool) { println!("cargo:rustc-link-lib=stdc++"); } else if target_os == "windows" && (using_pyke_libs || cfg!(feature = "directml")) { println!("cargo:rustc-link-lib=dxguid"); + println!("cargo:rustc-link-lib=DXCORE"); println!("cargo:rustc-link-lib=DXGI"); println!("cargo:rustc-link-lib=D3D12"); println!("cargo:rustc-link-lib=DirectML"); @@ -155,11 +159,12 @@ fn prepare_libort_dir() -> (PathBuf, bool) { #[allow(clippy::type_complexity)] let static_configs: Vec<(PathBuf, PathBuf, PathBuf, Box PathBuf>)> = vec![ (lib_dir.join(&profile), lib_dir.join("lib"), lib_dir.join("_deps"), Box::new(|p: PathBuf, profile| p.join(profile))), + (lib_dir.join(&profile), lib_dir.join("lib"), lib_dir.join(&profile).join("_deps"), Box::new(|p: PathBuf, _| p)), (lib_dir.clone(), lib_dir.join("lib"), lib_dir.parent().unwrap().join("_deps"), Box::new(|p: PathBuf, _| p)), (lib_dir.join("onnxruntime"), lib_dir.join("onnxruntime").join("lib"), lib_dir.join("_deps"), Box::new(|p: PathBuf, _| p)), ]; for (lib_dir, extension_lib_dir, external_lib_dir, transform_dep) in static_configs { - if lib_dir.join(platform_format_lib("onnxruntime_common")).exists() { + if lib_dir.join(platform_format_lib("onnxruntime_common")).exists() && external_lib_dir.exists() { add_search_dir(&lib_dir); for lib in &["common", "flatbuffers", "framework", "graph", "mlas", "optimizer", "providers", "session", "util"] { @@ -276,16 +281,16 @@ fn prepare_libort_dir() -> (PathBuf, bool) { "E23AB2606B4529B655C8BD08EB3E94C5E5AD4338460BC3A83F8015A838BB1AEF" ), //"aarch64-pc-windows-msvc" => ( - // "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-aarch64-pc-windows-msvc.tgz", - // "B35F6526EAF61527531D6F73EBA19EF09D6B0886FB66C14E1B594EE70F447817" + // "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.0/ortrs-msort_static-v1.17.0-aarch64-pc-windows-msvc.tgz", + // "27DDC61E1416E3F1BC6137C8365B563F73BA5A6CE8D7008E5CD4E36B4F037FDA" //), "aarch64-unknown-linux-gnu" => ( "https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.16.3/onnxruntime-linux-arm64-1.16.3.tgz", "B0E8016897310F3DBBCD57C8E27C79EFE61B8FEA70BA7C73AD5825E0438046BF" ), //"wasm32-unknown-emscripten" => ( - // "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-wasm32-unknown-emscripten.tgz", - // "468F74FB4C7451DC94EBABC080779CDFF0C7DA0617D85ADF21D5435A96F9D470" + // "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.0/ortrs-msort_static-v1.17.0-wasm32-unknown-emscripten.tgz", + // "E1ADBF06922649A59AB9D0459E9D5985B002C3AE830B512B7AED030BDA859C55" //), "x86_64-apple-darwin" => ( "https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.16.3/onnxruntime-osx-x86_64-1.16.3.tgz", @@ -357,7 +362,7 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let lib_dir = cache_dir.join(ort_extract_dir); if !lib_dir.exists() { let downloaded_file = fetch_file(prebuilt_url); - assert!(verify_file(&downloaded_file, prebuilt_hash)); + assert!(verify_file(&downloaded_file, prebuilt_hash), "hash does not match!"); extract_tgz(&downloaded_file, &cache_dir); } diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index 33c97fcb..4edc4c1f 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -76,7 +76,11 @@ pub enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 = 13, ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 = 14, ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 = 15, - ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 = 16 + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 = 16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN = 17, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ = 18, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 = 19, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ = 20 } #[repr(i32)] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] @@ -277,6 +281,11 @@ pub struct OrtOpAttr { pub struct OrtLogger { _unused: [u8; 0] } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtShapeInferContext { + _unused: [u8; 0] +} pub type OrtStatusPtr = *mut OrtStatus; #[doc = " \\brief Memory allocation interface\n\n Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators.\n\n When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed."] #[repr(C)] @@ -406,6 +415,7 @@ pub enum OrtCudnnConvAlgoSearch { } #[doc = " \\brief CUDA Provider Options\n\n \\see OrtApi::SessionOptionsAppendExecutionProvider_CUDA"] #[repr(C)] +#[derive(Debug, Copy, Clone)] pub struct OrtCUDAProviderOptions { #[doc = " \\brief CUDA device Id\n Defaults to 0."] pub device_id: ::std::os::raw::c_int, @@ -494,6 +504,7 @@ fn bindgen_test_layout_OrtCUDAProviderOptions() { } #[doc = " \\brief ROCM Provider Options\n\n \\see OrtApi::SessionOptionsAppendExecutionProvider_ROCM"] #[repr(C)] +#[derive(Debug, Copy, Clone)] pub struct OrtROCMProviderOptions { #[doc = " \\brief ROCM device Id\n Defaults to 0."] pub device_id: ::std::os::raw::c_int, @@ -582,6 +593,7 @@ fn bindgen_test_layout_OrtROCMProviderOptions() { } #[doc = " \\brief TensorRT Provider Options\n\n \\see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT"] #[repr(C)] +#[derive(Debug, Copy, Clone)] pub struct OrtTensorRTProviderOptions { #[doc = "< CUDA device id (0 = default device)"] pub device_id: ::std::os::raw::c_int, @@ -706,14 +718,16 @@ fn bindgen_test_layout_OrtTensorRTProviderOptions() { pub struct OrtMIGraphXProviderOptions { pub device_id: ::std::os::raw::c_int, pub migraphx_fp16_enable: ::std::os::raw::c_int, - pub migraphx_int8_enable: ::std::os::raw::c_int + pub migraphx_int8_enable: ::std::os::raw::c_int, + pub migraphx_use_native_calibration_table: ::std::os::raw::c_int, + pub migraphx_int8_calibration_table_name: *const ::std::os::raw::c_char } #[test] fn bindgen_test_layout_OrtMIGraphXProviderOptions() { const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); let ptr = UNINIT.as_ptr(); - assert_eq!(::std::mem::size_of::(), 12usize, concat!("Size of: ", stringify!(OrtMIGraphXProviderOptions))); - assert_eq!(::std::mem::align_of::(), 4usize, concat!("Alignment of ", stringify!(OrtMIGraphXProviderOptions))); + assert_eq!(::std::mem::size_of::(), 24usize, concat!("Size of: ", stringify!(OrtMIGraphXProviderOptions))); + assert_eq!(::std::mem::align_of::(), 8usize, concat!("Alignment of ", stringify!(OrtMIGraphXProviderOptions))); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).device_id) as usize - ptr as usize }, 0usize, @@ -729,14 +743,25 @@ fn bindgen_test_layout_OrtMIGraphXProviderOptions() { 8usize, concat!("Offset of field: ", stringify!(OrtMIGraphXProviderOptions), "::", stringify!(migraphx_int8_enable)) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).migraphx_use_native_calibration_table) as usize - ptr as usize }, + 12usize, + concat!("Offset of field: ", stringify!(OrtMIGraphXProviderOptions), "::", stringify!(migraphx_use_native_calibration_table)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).migraphx_int8_calibration_table_name) as usize - ptr as usize }, + 16usize, + concat!("Offset of field: ", stringify!(OrtMIGraphXProviderOptions), "::", stringify!(migraphx_int8_calibration_table_name)) + ); } #[doc = " \\brief OpenVINO Provider Options\n\n \\see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO"] #[repr(C)] +#[derive(Debug, Copy, Clone)] pub struct OrtOpenVINOProviderOptions { #[doc = " \\brief Device type string\n\n Valid settings are one of: \"CPU_FP32\", \"CPU_FP16\", \"GPU_FP32\", \"GPU_FP16\""] pub device_type: *const ::std::os::raw::c_char, #[doc = "< 0 = disabled, nonzero = enabled"] - pub enable_vpu_fast_compile: ::std::os::raw::c_uchar, + pub enable_npu_fast_compile: ::std::os::raw::c_uchar, pub device_id: *const ::std::os::raw::c_char, #[doc = "< 0 = Use default number of threads"] pub num_of_threads: size_t, @@ -759,9 +784,9 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() { concat!("Offset of field: ", stringify!(OrtOpenVINOProviderOptions), "::", stringify!(device_type)) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).enable_vpu_fast_compile) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).enable_npu_fast_compile) as usize - ptr as usize }, 8usize, - concat!("Offset of field: ", stringify!(OrtOpenVINOProviderOptions), "::", stringify!(enable_vpu_fast_compile)) + concat!("Offset of field: ", stringify!(OrtOpenVINOProviderOptions), "::", stringify!(enable_npu_fast_compile)) ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).device_id) as usize - ptr as usize }, @@ -1760,13 +1785,59 @@ pub struct OrtApi { resource: *mut *mut ::std::os::raw::c_void ) -> OrtStatusPtr ) + >, + pub SetUserLoggingFunction: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + user_logging_function: OrtLoggingFunction, + user_logging_param: *mut ::std::os::raw::c_void + ) -> OrtStatusPtr + ) + >, + pub ShapeInferContext_GetInputCount: ::std::option::Option<_system!(unsafe fn(context: *const OrtShapeInferContext, out: *mut size_t) -> OrtStatusPtr)>, + pub ShapeInferContext_GetInputTypeShape: ::std::option::Option< + _system!(unsafe fn(context: *const OrtShapeInferContext, index: size_t, info: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr) + >, + pub ShapeInferContext_GetAttribute: ::std::option::Option< + _system!(unsafe fn(context: *const OrtShapeInferContext, attr_name: *const ::std::os::raw::c_char, attr: *mut *const OrtOpAttr) -> OrtStatusPtr) + >, + pub ShapeInferContext_SetOutputTypeShape: + ::std::option::Option<_system!(unsafe fn(context: *const OrtShapeInferContext, index: size_t, info: *const OrtTensorTypeAndShapeInfo) -> OrtStatusPtr)>, + pub SetSymbolicDimensions: ::std::option::Option< + _system!(unsafe fn(info: *mut OrtTensorTypeAndShapeInfo, dim_params: *mut *const ::std::os::raw::c_char, dim_params_length: size_t) -> OrtStatusPtr) + >, + pub ReadOpAttr: ::std::option::Option< + _system!(unsafe fn(op_attr: *const OrtOpAttr, type_: OrtOpAttrType, data: *mut ::std::os::raw::c_void, len: size_t, out: *mut size_t) -> OrtStatusPtr) + >, + pub SetDeterministicCompute: ::std::option::Option<_system!(unsafe fn(options: *mut OrtSessionOptions, value: bool) -> OrtStatusPtr)>, + pub KernelContext_ParallelFor: ::std::option::Option< + _system!( + unsafe fn( + context: *const OrtKernelContext, + fn_: ::std::option::Option<_system!(unsafe fn(arg1: *mut ::std::os::raw::c_void, arg2: size_t))>, + total: size_t, + num_batch: size_t, + usr_data: *mut ::std::os::raw::c_void + ) -> OrtStatusPtr + ) + >, + pub SessionOptionsAppendExecutionProvider_OpenVINO_V2: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + provider_options_keys: *const *const ::std::os::raw::c_char, + provider_options_values: *const *const ::std::os::raw::c_char, + num_keys: size_t + ) -> OrtStatusPtr + ) > } #[test] fn bindgen_test_layout_OrtApi() { const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); let ptr = UNINIT.as_ptr(); - assert_eq!(::std::mem::size_of::(), 2128usize, concat!("Size of: ", stringify!(OrtApi))); + assert_eq!(::std::mem::size_of::(), 2208usize, concat!("Size of: ", stringify!(OrtApi))); assert_eq!(::std::mem::align_of::(), 8usize, concat!("Alignment of ", stringify!(OrtApi))); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).CreateStatus) as usize - ptr as usize }, @@ -3098,6 +3169,56 @@ fn bindgen_test_layout_OrtApi() { 2120usize, concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(KernelContext_GetResource)) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).SetUserLoggingFunction) as usize - ptr as usize }, + 2128usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(SetUserLoggingFunction)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ShapeInferContext_GetInputCount) as usize - ptr as usize }, + 2136usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(ShapeInferContext_GetInputCount)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ShapeInferContext_GetInputTypeShape) as usize - ptr as usize }, + 2144usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(ShapeInferContext_GetInputTypeShape)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ShapeInferContext_GetAttribute) as usize - ptr as usize }, + 2152usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(ShapeInferContext_GetAttribute)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ShapeInferContext_SetOutputTypeShape) as usize - ptr as usize }, + 2160usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(ShapeInferContext_SetOutputTypeShape)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).SetSymbolicDimensions) as usize - ptr as usize }, + 2168usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(SetSymbolicDimensions)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ReadOpAttr) as usize - ptr as usize }, + 2176usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(ReadOpAttr)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).SetDeterministicCompute) as usize - ptr as usize }, + 2184usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(SetDeterministicCompute)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).KernelContext_ParallelFor) as usize - ptr as usize }, + 2192usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(KernelContext_ParallelFor)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).SessionOptionsAppendExecutionProvider_OpenVINO_V2) as usize - ptr as usize }, + 2200usize, + concat!("Offset of field: ", stringify!(OrtApi), "::", stringify!(SessionOptionsAppendExecutionProvider_OpenVINO_V2)) + ); } #[repr(i32)] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] @@ -3130,13 +3251,16 @@ pub struct OrtCustomOp { pub CreateKernelV2: ::std::option::Option< _system!(unsafe fn(op: *const OrtCustomOp, api: *const OrtApi, info: *const OrtKernelInfo, kernel: *mut *mut ::std::os::raw::c_void) -> OrtStatusPtr) >, - pub KernelComputeV2: ::std::option::Option<_system!(unsafe fn(op_kernel: *mut ::std::os::raw::c_void, context: *mut OrtKernelContext) -> OrtStatusPtr)> + pub KernelComputeV2: ::std::option::Option<_system!(unsafe fn(op_kernel: *mut ::std::os::raw::c_void, context: *mut OrtKernelContext) -> OrtStatusPtr)>, + pub InferOutputShapeFn: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp, arg1: *mut OrtShapeInferContext) -> OrtStatusPtr)>, + pub GetStartVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)>, + pub GetEndVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)> } #[test] fn bindgen_test_layout_OrtCustomOp() { const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); let ptr = UNINIT.as_ptr(); - assert_eq!(::std::mem::size_of::(), 152usize, concat!("Size of: ", stringify!(OrtCustomOp))); + assert_eq!(::std::mem::size_of::(), 176usize, concat!("Size of: ", stringify!(OrtCustomOp))); assert_eq!(::std::mem::align_of::(), 8usize, concat!("Alignment of ", stringify!(OrtCustomOp))); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).version) as usize - ptr as usize }, @@ -3233,6 +3357,21 @@ fn bindgen_test_layout_OrtCustomOp() { 144usize, concat!("Offset of field: ", stringify!(OrtCustomOp), "::", stringify!(KernelComputeV2)) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).InferOutputShapeFn) as usize - ptr as usize }, + 152usize, + concat!("Offset of field: ", stringify!(OrtCustomOp), "::", stringify!(InferOutputShapeFn)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).GetStartVersion) as usize - ptr as usize }, + 160usize, + concat!("Offset of field: ", stringify!(OrtCustomOp), "::", stringify!(GetStartVersion)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).GetEndVersion) as usize - ptr as usize }, + 168usize, + concat!("Offset of field: ", stringify!(OrtCustomOp), "::", stringify!(GetEndVersion)) + ); } _system_block! { pub fn OrtSessionOptionsAppendExecutionProvider_CUDA(options: *mut OrtSessionOptions, device_id: ::std::os::raw::c_int) -> OrtStatusPtr; @@ -3246,3 +3385,6 @@ _system_block! { _system_block! { pub fn OrtSessionOptionsAppendExecutionProvider_Dnnl(options: *mut OrtSessionOptions, use_arena: ::std::os::raw::c_int) -> OrtStatusPtr; } +_system_block! { + pub fn OrtSessionOptionsAppendExecutionProvider_Tensorrt(options: *mut OrtSessionOptions, device_id: ::std::os::raw::c_int) -> OrtStatusPtr; +} diff --git a/src/environment.rs b/src/environment.rs index 1373db0d..69a7f8e2 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,9 +1,4 @@ -#[cfg(feature = "load-dynamic")] -use std::sync::Arc; -use std::{ - ffi::CString, - sync::{atomic::AtomicPtr, OnceLock} -}; +use std::{cell::UnsafeCell, ffi::CString, sync::atomic::AtomicPtr, sync::Arc}; use tracing::debug; @@ -15,20 +10,40 @@ use super::{ #[cfg(feature = "load-dynamic")] use crate::G_ORT_DYLIB_PATH; -static G_ENV: OnceLock = OnceLock::new(); +struct EnvironmentSingleton { + cell: UnsafeCell>> +} + +unsafe impl Sync for EnvironmentSingleton {} + +static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) }; #[derive(Debug)] -pub(crate) struct EnvironmentSingleton { +pub(crate) struct Environment { pub(crate) execution_providers: Vec, pub(crate) env_ptr: AtomicPtr } -pub(crate) fn get_environment() -> Result<&'static EnvironmentSingleton> { - if G_ENV.get().is_none() { - EnvironmentBuilder::default().commit()?; - Ok(G_ENV.get().unwrap()) +impl Drop for Environment { + #[tracing::instrument] + fn drop(&mut self) { + let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut(); + + debug!("Releasing environment"); + + assert_ne!(env_ptr, std::ptr::null_mut()); + ortsys![unsafe ReleaseEnv(env_ptr)]; + } +} + +pub(crate) fn get_environment() -> Result<&'static Arc> { + if let Some(c) = unsafe { &*G_ENV.cell.get() } { + Ok(c) } else { - Ok(unsafe { G_ENV.get().unwrap_unchecked() }) + debug!("Environment not yet initialized, creating a new one"); + EnvironmentBuilder::default().commit()?; + + Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() }) } } @@ -131,32 +146,32 @@ impl EnvironmentBuilder { /// Commit the configuration to a new [`Environment`]. pub fn commit(self) -> Result<()> { - if G_ENV.get().is_none() { - debug!("Environment not yet initialized, creating a new one"); - - let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options { - let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); - let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); - let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); - let cname = CString::new(self.name.clone()).unwrap(); - - let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut(); - ortsys![unsafe CreateThreadingOptions(&mut thread_options) -> Error::CreateEnvironment; nonNull(thread_options)]; - if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism { - ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism) -> Error::CreateEnvironment]; - } - if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism { - ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment]; - } - if let Some(spin_control) = global_thread_pool.spin_control { - ortsys![unsafe SetGlobalSpinControl(thread_options, if spin_control { 1 } else { 0 }) -> Error::CreateEnvironment]; - } - if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity { - let cstr = CString::new(intra_op_thread_affinity).unwrap(); - ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment]; - } - - ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( + // drop global reference to previous environment + drop(unsafe { (*G_ENV.cell.get()).take() }); + + let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options { + let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); + let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); + let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); + let cname = CString::new(self.name.clone()).unwrap(); + + let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut(); + ortsys![unsafe CreateThreadingOptions(&mut thread_options) -> Error::CreateEnvironment; nonNull(thread_options)]; + if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism { + ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism) -> Error::CreateEnvironment]; + } + if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism { + ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment]; + } + if let Some(spin_control) = global_thread_pool.spin_control { + ortsys![unsafe SetGlobalSpinControl(thread_options, if spin_control { 1 } else { 0 }) -> Error::CreateEnvironment]; + } + if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity { + let cstr = CString::new(intra_op_thread_affinity).unwrap(); + ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment]; + } + + ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, @@ -164,30 +179,32 @@ impl EnvironmentBuilder { thread_options, &mut env_ptr ) -> Error::CreateEnvironment; nonNull(env_ptr)]; - ortsys![unsafe ReleaseThreadingOptions(thread_options)]; - env_ptr - } else { - let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); - let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); - // FIXME: What should go here? - let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); - let cname = CString::new(self.name.clone()).unwrap(); - ortsys![unsafe CreateEnvWithCustomLogger( + ortsys![unsafe ReleaseThreadingOptions(thread_options)]; + env_ptr + } else { + let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); + let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); + // FIXME: What should go here? + let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); + let cname = CString::new(self.name.clone()).unwrap(); + ortsys![unsafe CreateEnvWithCustomLogger( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), &mut env_ptr ) -> Error::CreateEnvironment; nonNull(env_ptr)]; - env_ptr - }; - debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created"); + env_ptr + }; + debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created"); - let _ = G_ENV.set(EnvironmentSingleton { + unsafe { + *G_ENV.cell.get() = Some(Arc::new(Environment { execution_providers: self.execution_providers, env_ptr: AtomicPtr::new(env_ptr) - }); - } + })); + }; + Ok(()) } } @@ -223,11 +240,11 @@ mod tests { use super::*; fn is_env_initialized() -> bool { - G_ENV.get().is_some() && !G_ENV.get().unwrap().env_ptr.load(Ordering::Relaxed).is_null() + unsafe { (*G_ENV.cell.get()).as_ref() }.is_some() && !unsafe { (*G_ENV.cell.get()).as_ref() }.unwrap().env_ptr.load(Ordering::Relaxed).is_null() } fn env_ptr() -> Option<*mut ort_sys::OrtEnv> { - G_ENV.get().map(|f| f.env_ptr.load(Ordering::Relaxed)) + unsafe { (*G_ENV.cell.get()).as_ref() }.map(|f| f.env_ptr.load(Ordering::Relaxed)) } struct ConcurrentTestRun { diff --git a/src/error.rs b/src/error.rs index ba270dfa..aa6ba509 100644 --- a/src/error.rs +++ b/src/error.rs @@ -105,6 +105,15 @@ pub enum Error { /// Error occurred when extracting string data from an ONNX tensor #[error("Failed to get tensor string data: {0}")] GetStringTensorContent(ErrorInternal), + /// Error occurred when creating run options. + #[error("Failed to create run options: {0}")] + CreateRunOptions(ErrorInternal), + /// Error occurred when terminating run options. + #[error("Failed to terminate run options: {0}")] + RunOptionsSetTerminate(ErrorInternal), + /// Error occurred when unterminating run options. + #[error("Failed to unterminate run options: {0}")] + RunOptionsUnsetTerminate(ErrorInternal), /// Error occurred when converting data to a String #[error("Data was not UTF-8: {0}")] StringFromUtf8Error(#[from] string::FromUtf8Error), diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index 47d2d20e..4133480c 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -12,7 +12,7 @@ pub struct OpenVINOExecutionProvider { context: *mut c_void, enable_opencl_throttling: bool, enable_dynamic_shapes: bool, - enable_vpu_fast_compile: bool + enable_npu_fast_compile: bool } unsafe impl Send for OpenVINOExecutionProvider {} @@ -28,7 +28,7 @@ impl Default for OpenVINOExecutionProvider { context: std::ptr::null_mut(), enable_opencl_throttling: false, enable_dynamic_shapes: false, - enable_vpu_fast_compile: false + enable_npu_fast_compile: false } } } @@ -82,8 +82,8 @@ impl OpenVINOExecutionProvider { self } - pub fn with_vpu_fast_compile(mut self) -> Self { - self.enable_vpu_fast_compile = true; + pub fn with_npu_fast_compile(mut self) -> Self { + self.enable_npu_fast_compile = true; self } @@ -127,7 +127,7 @@ impl ExecutionProvider for OpenVINOExecutionProvider { context: self.context, enable_opencl_throttling: self.enable_opencl_throttling.into(), enable_dynamic_shapes: self.enable_dynamic_shapes.into(), - enable_vpu_fast_compile: self.enable_vpu_fast_compile.into() + enable_npu_fast_compile: self.enable_npu_fast_compile.into() }; return crate::error::status_to_result( crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_builder.session_options_ptr, &openvino_options as *const _)] diff --git a/src/io_binding.rs b/src/io_binding.rs index 2d61256f..f2051edf 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -1,6 +1,12 @@ use std::{ffi::CString, fmt::Debug, ptr, sync::Arc}; -use crate::{memory::MemoryInfo, ortsys, session::output::SessionOutputs, value::Value, Error, Result, Session}; +use crate::{ + memory::MemoryInfo, + ortsys, + session::{output::SessionOutputs, RunOptions}, + value::Value, + Error, Result, Session +}; /// Enables binding of session inputs and/or outputs to pre-allocated memory. /// @@ -59,7 +65,19 @@ impl<'s> IoBinding<'s> { } pub fn run<'i: 's>(&'i self) -> Result> { - let run_options_ptr: *const ort_sys::OrtRunOptions = std::ptr::null(); + self.run_inner(None) + } + + pub fn run_with_options<'i: 's>(&'i self, run_options: Arc) -> Result> { + self.run_inner(Some(run_options)) + } + + fn run_inner<'i: 's>(&'i self, run_options: Option>) -> Result> { + let run_options_ptr = if let Some(run_options) = run_options { + run_options.run_options_ptr + } else { + std::ptr::null_mut() + }; ortsys![unsafe RunWithBinding(self.session.inner.session_ptr, run_options_ptr, self.ptr) -> Error::SessionRunWithIoBinding]; let mut count = self.output_names.len() as ort_sys::size_t; diff --git a/src/lib.rs b/src/lib.rs index 53019a62..0dab149d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,7 +33,7 @@ use tracing::Level; #[cfg(feature = "load-dynamic")] pub use self::environment::init_from; -pub use self::environment::{init, EnvironmentBuilder}; +pub use self::environment::{init, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions}; #[cfg(feature = "fetch-models")] #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] pub use self::error::FetchModelError; @@ -42,7 +42,7 @@ pub use self::execution_providers::*; pub use self::io_binding::IoBinding; pub use self::memory::{AllocationDevice, Allocator, MemoryInfo}; pub use self::metadata::ModelMetadata; -pub use self::session::{InMemorySession, Session, SessionBuilder, SessionInputs, SessionOutputs, SharedSessionInner}; +pub use self::session::{InMemorySession, RunOptions, Session, SessionBuilder, SessionInputs, SessionOutputs, SharedSessionInner}; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::{ArrayExtensions, ArrayViewHolder, Tensor, TensorData}; @@ -139,13 +139,13 @@ pub fn api() -> ort_sys::OrtApi { tracing::info!("Using ONNX Runtime version '{version_string}'"); let lib_minor_version = version_string.split('.').nth(1).map(|x| x.parse::().unwrap_or(0)).unwrap_or(0); - match lib_minor_version.cmp(&16) { + match lib_minor_version.cmp(&17) { std::cmp::Ordering::Less => panic!( - "ort 2.0 is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'", + "ort 2.0 is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.17.x', but got '{version_string}'", dylib_path() ), std::cmp::Ordering::Greater => tracing::warn!( - "ort 2.0 may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'", + "ort 2.0 may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.17.x', but got '{version_string}'", dylib_path() ), std::cmp::Ordering::Equal => {} diff --git a/src/session/input.rs b/src/session/input.rs index 15b91a2e..b3f3b904 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,16 +1,18 @@ use std::collections::HashMap; +use compact_str::CompactString; + use crate::Value; pub enum SessionInputs<'i, const N: usize = 0> { - ValueMap(HashMap<&'static str, Value>), + ValueMap(HashMap), ValueSlice(&'i [Value]), ValueArray([Value; N]) } -impl<'i> From> for SessionInputs<'i> { - fn from(val: HashMap<&'static str, Value>) -> Self { - SessionInputs::ValueMap(val) +impl<'i, K: Into> From> for SessionInputs<'i> { + fn from(val: HashMap) -> Self { + SessionInputs::ValueMap(val.into_iter().map(|c| (c.0.into(), c.1)).collect()) } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 1ac4fae8..e0b7e579 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -17,6 +17,8 @@ use std::{ #[cfg(feature = "fetch-models")] use std::{path::PathBuf, time::Duration}; +use compact_str::CompactString; + #[cfg(feature = "fetch-models")] use super::error::FetchModelError; use super::{ @@ -32,6 +34,7 @@ use super::{ value::{Value, ValueType}, AllocatorType, GraphOptimizationLevel, MemType }; +use crate::environment::Environment; pub(crate) mod input; pub(crate) mod output; @@ -432,7 +435,11 @@ impl SessionBuilder { .collect::>>()?; Ok(Session { - inner: Arc::new(SharedSessionInner { session_ptr, allocator }), + inner: Arc::new(SharedSessionInner { + session_ptr, + allocator, + _environment: Arc::clone(env) + }), inputs, outputs }) @@ -490,7 +497,11 @@ impl SessionBuilder { .collect::>>()?; let session = Session { - inner: Arc::new(SharedSessionInner { session_ptr, allocator }), + inner: Arc::new(SharedSessionInner { + session_ptr, + allocator, + _environment: Arc::clone(env) + }), inputs, outputs }; @@ -503,21 +514,19 @@ impl SessionBuilder { #[derive(Debug)] pub struct SharedSessionInner { pub(crate) session_ptr: *mut ort_sys::OrtSession, - allocator: Allocator + allocator: Allocator, + _environment: Arc } unsafe impl Send for SharedSessionInner {} unsafe impl Sync for SharedSessionInner {} impl Drop for SharedSessionInner { - #[tracing::instrument(skip_all)] + #[tracing::instrument] fn drop(&mut self) { - // FIXME: ใชใ‚“ใง`warn!`ใงๅ‡บใ—ใฆใ„ใŸใฎใ‹ใŒไธๆ˜Žใ€‚ortใฎDiscordใจใ‹ใง่žใ„ใฆใฟใ‚‹ - // tracing::warn!("dropping SharedSessionInner"); - tracing::info!("dropping SharedSessionInner"); + tracing::debug!("dropping SharedSessionInner"); if !self.session_ptr.is_null() { - // tracing::warn!("dropping session ptr"); - tracing::info!("dropping session ptr"); + tracing::debug!("dropping session ptr"); ortsys![unsafe ReleaseSession(self.session_ptr)]; } self.session_ptr = std::ptr::null_mut(); @@ -565,6 +574,46 @@ pub struct Output { pub output_type: ValueType } +/// ONNX Run Options which is used to terminate/unterminate run(s) in a session +#[derive(Debug)] +pub struct RunOptions { + pub(crate) run_options_ptr: *mut ort_sys::OrtRunOptions +} + +// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 +unsafe impl Send for RunOptions {} +unsafe impl Sync for RunOptions {} + +impl RunOptions { + /// Creates a new [`RunOptions`]. + pub fn new() -> Result { + let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); + ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; + Ok(Self { run_options_ptr }) + } + + /// Terminates the runs associated with [`RunOptions`]. + pub fn set_terminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr) -> Error::RunOptionsSetTerminate]; + Ok(()) + } + + /// Unterminates the runs associated with [`RunOptions`]. + pub fn set_unterminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr) -> Error::RunOptionsUnsetTerminate]; + Ok(()) + } +} + +impl Drop for RunOptions { + fn drop(&mut self) { + if !self.run_options_ptr.is_null() { + ortsys![unsafe ReleaseRunOptions(self.run_options_ptr)]; + } + self.run_options_ptr = std::ptr::null_mut(); + } +} + impl Session { pub fn builder() -> Result { SessionBuilder::new() @@ -589,24 +638,78 @@ impl Session { pub fn run<'s, 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { - let outputs = self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values)?; + let outputs = self.run_inner( + &self + .inputs + .iter() + .map(|input| CompactString::new(input.name.as_str())) + .collect::>(), + input_values, + None + )?; Ok(outputs) } SessionInputs::ValueArray(input_values) => { - let outputs = self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), &input_values)?; + let outputs = self.run_inner( + &self + .inputs + .iter() + .map(|input| CompactString::new(input.name.as_str())) + .collect::>(), + &input_values, + None + )?; Ok(outputs) } SessionInputs::ValueMap(input_values) => { - let (input_names, values): (Vec<&'static str>, Vec) = input_values.into_iter().unzip(); - self.run_inner(&input_names, &values) + let (input_names, values): (Vec, Vec) = input_values.into_iter().unzip(); + self.run_inner(&input_names, &values, None) } } } - fn run_inner(&self, input_names: &[&str], input_values: &[Value]) -> Result> { + /// Run the input data through the ONNX graph, performing inference. + pub fn run_with_options<'s, 'i, const N: usize>( + &'s self, + input_values: impl Into>, + run_options: Arc + ) -> Result> { + match input_values.into() { + SessionInputs::ValueSlice(input_values) => { + let outputs = self.run_inner( + &self + .inputs + .iter() + .map(|input| CompactString::new(input.name.as_str())) + .collect::>(), + input_values, + Some(run_options) + )?; + Ok(outputs) + } + SessionInputs::ValueArray(input_values) => { + let outputs = self.run_inner( + &self + .inputs + .iter() + .map(|input| CompactString::new(input.name.as_str())) + .collect::>(), + &input_values, + Some(run_options) + )?; + Ok(outputs) + } + SessionInputs::ValueMap(input_values) => { + let (input_names, values): (Vec, Vec) = input_values.into_iter().unzip(); + self.run_inner(&input_names, &values, Some(run_options)) + } + } + } + + fn run_inner(&self, input_names: &[CompactString], input_values: &[Value], run_options: Option>) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() - .map(|n| CString::new(*n).unwrap()) + .map(|n| CString::new(n.as_bytes()).unwrap()) .map(|n| n.into_raw() as *const c_char) .collect(); let output_names_ptr: Vec<*const c_char> = self @@ -621,10 +724,16 @@ impl Session { // The C API expects pointers for the arrays (pointers to C-arrays) let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr() as *const _).collect(); + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr + } else { + std::ptr::null_mut() + }; + ortsys![ unsafe Run( self.inner.session_ptr, - ptr::null(), + run_options_ptr, input_names_ptr.as_ptr(), input_ort_values.as_ptr(), input_ort_values.len() as _, diff --git a/src/value.rs b/src/value.rs index 51feb82b..7bd737c0 100644 --- a/src/value.rs +++ b/src/value.rs @@ -764,3 +764,69 @@ pub(crate) unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::O value: value_type_sys.into() }) } + +#[cfg(test)] +mod tests { + use ndarray::{ArcArray1, Array1, CowArray}; + + use crate::*; + + #[test] + #[cfg(feature = "ndarray")] + fn test_tensor_value() -> crate::Result<()> { + let v: Vec = vec![1., 2., 3., 4., 5.]; + let value = Value::from_array(Array1::from_vec(v.clone()))?; + assert!(value.is_tensor()?); + assert_eq!(value.tensor_element_type()?, TensorElementType::Float32); + assert_eq!( + value.dtype()?, + ValueType::Tensor { + ty: TensorElementType::Float32, + dimensions: vec![v.len() as i64] + } + ); + + let (shape, data) = value.extract_raw_tensor::()?; + assert_eq!(shape, vec![v.len() as i64]); + assert_eq!(data, &v); + + Ok(()) + } + + #[test] + #[cfg(feature = "ndarray")] + fn test_tensor_lifetimes() -> crate::Result<()> { + let v: Vec = vec![1., 2., 3., 4., 5.]; + + let arc1 = ArcArray1::from_vec(v.clone()); + let mut arc2 = ArcArray1::clone(&arc1); + let value = Value::from_array(&mut arc2)?; + drop((arc1, arc2)); + + assert_eq!(value.extract_raw_tensor::()?.1, &v); + + let cow = CowArray::from(Array1::from_vec(v.clone())); + let value = Value::from_array(&cow)?; + assert_eq!(value.extract_raw_tensor::()?.1, &v); + + let owned = Array1::from_vec(v.clone()); + let value = Value::from_array(owned.view())?; + drop(owned); + assert_eq!(value.extract_raw_tensor::()?.1, &v); + + Ok(()) + } + + #[test] + fn test_tensor_raw_lifetimes() -> crate::Result<()> { + let v: Vec = vec![1., 2., 3., 4., 5.]; + + let arc = Arc::new(v.clone().into_boxed_slice()); + let shape = vec![v.len() as i64]; + let value = Value::from_array((shape, Arc::clone(&arc)))?; + drop(arc); + assert_eq!(value.extract_raw_tensor::()?.1, &v); + + Ok(()) + } +}