From 7eff3a9a3a524f43a23f6d886af8637612e5b896 Mon Sep 17 00:00:00 2001 From: Mitchell Mosure Date: Sat, 4 May 2024 22:11:17 -0500 Subject: [PATCH] feat: flame (#11) * feat: flame inference * docs: todo * chore: remove print * fix: clippy --- .gitignore | 2 + Cargo.toml | 12 ++- README.md | 1 + benches/modnet.rs | 4 +- benches/yolo_v8.rs | 4 +- src/lib.rs | 2 +- src/models/flame.rs | 163 ++++++++++++++++++++++++++++++++++++++++ src/models/lightglue.rs | 6 +- src/models/mod.rs | 3 + src/models/modnet.rs | 2 +- src/models/yolo_v8.rs | 2 +- tools/flame.rs | 70 +++++++++++++++++ 12 files changed, 259 insertions(+), 12 deletions(-) create mode 100644 src/models/flame.rs create mode 100644 tools/flame.rs diff --git a/.gitignore b/.gitignore index 28fc69c..b7c9d4b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ www/assets/ mediamtx/ onnxruntime/ + +*.onnx diff --git a/Cargo.toml b/Cargo.toml index 6d53cdd..fd34eed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.9.0" +version = "0.10.0" edition = "2021" authors = ["mosure "] license = "MIT" @@ -29,11 +29,13 @@ default-run = "modnet" [features] default = [ + "flame", "lightglue", "modnet", "yolo_v8", ] +flame = [] lightglue = [] modnet = ["rayon"] yolo_v8 = [] @@ -60,11 +62,12 @@ features = [ "bevy_winit", "multi-threaded", "png", + "tonemapping_luts", ] [dependencies.ort] -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" default-features = false features = [ "cuda", @@ -93,6 +96,11 @@ opt-level = 3 path = "src/lib.rs" +[[bin]] +name = "flame" +path = "tools/flame.rs" +required-features = ["flame"] + [[bin]] name = "lightglue" path = "tools/lightglue.rs" diff --git a/README.md b/README.md index f78698a..d247aaa 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library - [X] lightglue (feature matching) - [X] modnet (photographic portrait matting) - [X] yolo_v8 (object detection) +- [X] flame (parametric head model) ## library usage diff --git a/benches/modnet.rs b/benches/modnet.rs index 853f23b..d5473c4 100644 --- a/benches/modnet.rs +++ b/benches/modnet.rs @@ -86,7 +86,7 @@ fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); let data = vec![0u8; (1920 * 1080 * 4) as usize]; let image: Image = Image::new( @@ -123,7 +123,7 @@ fn modnet_inference_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); MAX_RESOLUTIONS.iter().for_each(|(width, height)| { let data = vec![0u8; *width as usize * *height as usize * 4]; diff --git a/benches/yolo_v8.rs b/benches/yolo_v8.rs index 6e0439f..1d7ed29 100644 --- a/benches/yolo_v8.rs +++ b/benches/yolo_v8.rs @@ -80,7 +80,7 @@ fn process_output_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/yolov8n.onnx").unwrap(); + .commit_from_file("assets/yolov8n.onnx").unwrap(); RESOLUTIONS.iter() .for_each(|(width, height)| { @@ -117,7 +117,7 @@ fn inference_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/yolov8n.onnx").unwrap(); + .commit_from_file("assets/yolov8n.onnx").unwrap(); RESOLUTIONS.iter().for_each(|(width, height)| { let data = vec![0u8; *width as usize * *height as usize * 4]; diff --git a/src/lib.rs b/src/lib.rs index 62038ed..d8a19a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,7 +88,7 @@ impl AssetLoader for OnnxLoader { // TODO: add session configuration let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_model_from_memory(&bytes)?; + .commit_from_memory(&bytes)?; Ok(Onnx { session: Arc::new(Mutex::new(Some(session))), diff --git a/src/models/flame.rs b/src/models/flame.rs new file mode 100644 index 0000000..6da6e92 --- /dev/null +++ b/src/models/flame.rs @@ -0,0 +1,163 @@ +use bevy::prelude::*; +use ndarray::Array2; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs, + Onnx, +}; + + + +pub struct FlamePlugin; +impl Plugin for FlamePlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + } +} + +#[derive(Resource, Default)] +pub struct Flame { + pub onnx: Handle, +} + + +#[derive( + Debug, + Clone, +)] +pub struct FlameInput { + pub shape: [[f32; 100]; 8], + pub pose: [[f32; 6]; 8], + pub expression: [[f32; 50]; 8], + pub neck: [[f32; 3]; 8], + pub eye: [[f32; 6]; 8], +} + +impl Default for FlameInput { + fn default() -> Self { + let radian = std::f32::consts::PI / 180.0; + + Self { + shape: [[0.0; 100]; 8], + pose: [ + [0.0, 30.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, -30.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, 85.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, -48.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, 10.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, -15.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0 * radian, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0 * radian, 0.0, 0.0, 0.0, 0.0], + ], + expression: [[0.0; 50]; 8], + neck: [[0.0; 3]; 8], + eye: [[0.0; 6]; 8], + } + } +} + + +#[derive( + Debug, + Default, + Clone, + Deserialize, + Serialize, +)] +pub struct FlameOutput { + pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding + pub landmarks: Vec<[f32; 3]>, +} + + +pub fn flame_inference( + session: &ort::Session, + input: &FlameInput, +) -> FlameOutput { + let PreparedInput { + shape, + expression, + pose, + neck, + eye, + } = prepare_input(input); + + let input_values = inputs![ + "shape" => shape.view(), + "expression" => expression.view(), + "pose" => pose.view(), + "neck" => neck.view(), + "eye" => eye.view(), + ].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + + let vertices: &ort::Value = binding.get("vertices").unwrap(); + let landmarks: &ort::Value = binding.get("landmarks").unwrap(); + + post_process( + vertices, + landmarks, + ) +} + + +pub struct PreparedInput { + pub shape: Array2, + pub pose: Array2, + pub expression: Array2, + pub neck: Array2, + pub eye: Array2, +} + +pub fn prepare_input( + input: &FlameInput, +) -> PreparedInput { + let shape = Array2::from_shape_vec((8, 100), input.shape.concat()).unwrap(); + let pose = Array2::from_shape_vec((8, 6), input.pose.concat()).unwrap(); + let expression = Array2::from_shape_vec((8, 50), input.expression.concat()).unwrap(); + let neck = Array2::from_shape_vec((8, 3), input.neck.concat()).unwrap(); + let eye = Array2::from_shape_vec((8, 6), input.eye.concat()).unwrap(); + + PreparedInput { + shape, + expression, + pose, + neck, + eye, + } +} + + +pub fn post_process( + vertices: &ort::Value, + landmarks: &ort::Value, +) -> FlameOutput { + let vertices_tensor = vertices.try_extract_tensor::().unwrap(); + let vertices_view = vertices_tensor.view(); // [8, 5023, 3] + + let landmarks_tensor = landmarks.try_extract_tensor::().unwrap(); + let landmarks_view = landmarks_tensor.view(); // [8, 68, 3] + + let vertices = vertices_view.outer_iter() + .flat_map(|subtensor| { + subtensor.outer_iter().map(|row| { + [row[0], row[1], row[2]] + }).collect::>() + }) + .collect::>(); + + let landmarks = landmarks_view.outer_iter() + .flat_map(|subtensor| { + subtensor.outer_iter().map(|row| { + [row[0], row[1], row[2]] + }).collect::>() + }) + .collect::>(); + + FlameOutput { + vertices, + landmarks, + } +} diff --git a/src/models/lightglue.rs b/src/models/lightglue.rs index 911bc12..6937f83 100644 --- a/src/models/lightglue.rs +++ b/src/models/lightglue.rs @@ -100,13 +100,13 @@ pub fn post_process( kpts1: &ort::Value, matches: &ort::Value, ) -> Result, &'static str> { - let kpts0_tensor = kpts0.extract_tensor::().unwrap(); + let kpts0_tensor = kpts0.try_extract_tensor::().unwrap(); let kpts0_view = kpts0_tensor.view(); - let kpts1_tensor = kpts1.extract_tensor::().unwrap(); + let kpts1_tensor = kpts1.try_extract_tensor::().unwrap(); let kpts1_view = kpts1_tensor.view(); - let matches = matches.extract_tensor::().unwrap(); + let matches = matches.try_extract_tensor::().unwrap(); let matches_view = matches.view(); Ok( diff --git a/src/models/mod.rs b/src/models/mod.rs index 32e9875..6c07ed6 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "flame")] +pub mod flame; + #[cfg(feature = "lightglue")] pub mod lightglue; diff --git a/src/models/modnet.rs b/src/models/modnet.rs index d26b0ac..0f1daa6 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -52,7 +52,7 @@ pub fn modnet_inference( pub fn modnet_output_to_luma_images( output_value: &ort::Value, ) -> Vec { - let tensor = output_value.extract_tensor::().unwrap(); + let tensor = output_value.try_extract_tensor::().unwrap(); let data = tensor.view(); let shape = data.shape(); diff --git a/src/models/yolo_v8.rs b/src/models/yolo_v8.rs index eb24863..0c28fed 100644 --- a/src/models/yolo_v8.rs +++ b/src/models/yolo_v8.rs @@ -90,7 +90,7 @@ pub fn process_output( ) -> Vec { let mut boxes = Vec::new(); - let tensor = output.extract_tensor::().unwrap(); + let tensor = output.try_extract_tensor::().unwrap(); let data = tensor.view().t().into_owned(); for detection in data.axis_iter(Axis(0)) { diff --git a/tools/flame.rs b/tools/flame.rs new file mode 100644 index 0000000..c75a058 --- /dev/null +++ b/tools/flame.rs @@ -0,0 +1,70 @@ +use bevy::prelude::*; + +use bevy_ort::{ + BevyOrtPlugin, + models::flame::{ + FlameInput, + FlameOutput, + flame_inference, + Flame, + FlamePlugin, + }, + Onnx, +}; + + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + BevyOrtPlugin, + FlamePlugin, + )) + .add_systems(Startup, load_flame) + .add_systems(Update, inference) + .run(); +} + + +fn load_flame( + asset_server: Res, + mut flame: ResMut, +) { + let flame_handle: Handle = asset_server.load("models/flame.onnx"); + flame.onnx = flame_handle; +} + + +fn inference( + mut commands: Commands, + flame: Res, + onnx_assets: Res>, + mut complete: Local, +) { + if *complete { + return; + } + + let flame_output: Result = (|| { + let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; + + Ok(flame_inference( + session, + &FlameInput::default(), + )) + })(); + + match flame_output { + Ok(_flame_output) => { + // TODO: insert mesh + // TODO: insert pan orbit camera + commands.spawn(Camera3dBundle::default()); + *complete = true; + } + Err(e) => { + eprintln!("inference failed: {}", e); + } + } +}