Skip to content

Commit

Permalink
feat: flame (#11)
Browse files Browse the repository at this point in the history
* feat: flame inference

* docs: todo

* chore: remove print

* fix: clippy
  • Loading branch information
mosure authored May 5, 2024
1 parent 02feb3b commit 7eff3a9
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ www/assets/

mediamtx/
onnxruntime/

*.onnx
12 changes: 10 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
authors = ["mosure <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -29,11 +29,13 @@ default-run = "modnet"

[features]
default = [
"flame",
"lightglue",
"modnet",
"yolo_v8",
]

flame = []
lightglue = []
modnet = ["rayon"]
yolo_v8 = []
Expand All @@ -60,11 +62,12 @@ features = [
"bevy_winit",
"multi-threaded",
"png",
"tonemapping_luts",
]


[dependencies.ort]
version = "2.0.0-rc.0"
version = "2.0.0-rc.2"
default-features = false
features = [
"cuda",
Expand Down Expand Up @@ -93,6 +96,11 @@ opt-level = 3
path = "src/lib.rs"


[[bin]]
name = "flame"
path = "tools/flame.rs"
required-features = ["flame"]

[[bin]]
name = "lightglue"
path = "tools/lightglue.rs"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library
- [X] lightglue (feature matching)
- [X] modnet (photographic portrait matting)
- [X] yolo_v8 (object detection)
- [X] flame (parametric head model)


## library usage
Expand Down
4 changes: 2 additions & 2 deletions benches/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();
.commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();

let data = vec![0u8; (1920 * 1080 * 4) as usize];
let image: Image = Image::new(
Expand Down Expand Up @@ -123,7 +123,7 @@ fn modnet_inference_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();
.commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();

MAX_RESOLUTIONS.iter().for_each(|(width, height)| {
let data = vec![0u8; *width as usize * *height as usize * 4];
Expand Down
4 changes: 2 additions & 2 deletions benches/yolo_v8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn process_output_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/yolov8n.onnx").unwrap();
.commit_from_file("assets/yolov8n.onnx").unwrap();

RESOLUTIONS.iter()
.for_each(|(width, height)| {
Expand Down Expand Up @@ -117,7 +117,7 @@ fn inference_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/yolov8n.onnx").unwrap();
.commit_from_file("assets/yolov8n.onnx").unwrap();

RESOLUTIONS.iter().for_each(|(width, height)| {
let data = vec![0u8; *width as usize * *height as usize * 4];
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl AssetLoader for OnnxLoader {
// TODO: add session configuration
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_model_from_memory(&bytes)?;
.commit_from_memory(&bytes)?;

Ok(Onnx {
session: Arc::new(Mutex::new(Some(session))),
Expand Down
163 changes: 163 additions & 0 deletions src/models/flame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use bevy::prelude::*;
use ndarray::Array2;
use serde::{Deserialize, Serialize};

use crate::{
inputs,
Onnx,
};



/// Bevy plugin for the FLAME model: registers the [`Flame`] resource on the app.
pub struct FlamePlugin;
impl Plugin for FlamePlugin {
/// Ensures a [`Flame`] resource exists on `app` (created via `init_resource`).
fn build(&self, app: &mut App) {
app.init_resource::<Flame>();
}
}

/// Resource holding the handle to the FLAME ONNX asset.
///
/// Derives `Default`, so the handle starts out as the default handle until a
/// system assigns the one returned by the asset server.
#[derive(Resource, Default)]
pub struct Flame {
/// Handle to the `Onnx` asset that wraps the inference session.
pub onnx: Handle<Onnx>,
}


/// Batched parameter set fed to the FLAME model.
///
/// Every field holds exactly 8 rows; the rows are passed to the model as the
/// leading axis of each input tensor.
// NOTE(review): the per-field meanings below follow the usual FLAME
// parameterisation (shape / expression coefficients, pose and joint
// rotations) — confirm against the exported ONNX model.
#[derive(Debug, Clone)]
pub struct FlameInput {
    /// 100 values per row, bound to the `shape` model input.
    pub shape: [[f32; 100]; 8],
    /// 6 values per row, bound to the `pose` model input.
    pub pose: [[f32; 6]; 8],
    /// 50 values per row, bound to the `expression` model input.
    pub expression: [[f32; 50]; 8],
    /// 3 values per row, bound to the `neck` model input.
    pub neck: [[f32; 3]; 8],
    /// 6 values per row, bound to the `eye` model input.
    pub eye: [[f32; 6]; 8],
}

impl Default for FlameInput {
    /// Builds a demo batch: every parameter is zero except component 1 of
    /// each pose row, which is set to a fixed per-row angle (in radians).
    fn default() -> Self {
        // Conversion factor from degrees to radians.
        let deg_to_rad = std::f32::consts::PI / 180.0;

        // Angle, in degrees, written to pose component 1 of each of the 8 rows.
        let pose_degrees: [f32; 8] = [30.0, -30.0, 85.0, -48.0, 10.0, -15.0, 0.0, 0.0];

        FlameInput {
            shape: [[0.0; 100]; 8],
            pose: pose_degrees.map(|deg| [0.0, deg * deg_to_rad, 0.0, 0.0, 0.0, 0.0]),
            expression: [[0.0; 50]; 8],
            neck: [[0.0; 3]; 8],
            eye: [[0.0; 6]; 8],
        }
    }
}


/// Result of a FLAME inference pass: point lists extracted from the model's
/// `vertices` and `landmarks` outputs.
///
/// Both lists are flat: when built by `post_process`, the points of all batch
/// entries are concatenated in batch order.
#[derive(
Debug,
Default,
Clone,
Deserialize,
Serialize,
)]
pub struct FlameOutput {
/// Points (three `f32` components each) taken from the `vertices` output.
pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding
/// Points (three `f32` components each) taken from the `landmarks` output.
pub landmarks: Vec<[f32; 3]>,
}


/// Runs the FLAME model on `input` and returns the extracted vertices and
/// landmarks.
///
/// The five parameter arrays are converted to 2-D tensors with
/// `prepare_input`, bound to the model inputs `shape`, `expression`, `pose`,
/// `neck` and `eye`, and the outputs named `vertices` and `landmarks` are
/// handed to `post_process`.
///
/// # Panics
///
/// Panics if the input tensors cannot be built, if the session fails to run,
/// or if the model does not expose `vertices` / `landmarks` outputs. Each
/// panic message names the failing step and includes the underlying ort error
/// where one exists.
pub fn flame_inference(
    session: &ort::Session,
    input: &FlameInput,
) -> FlameOutput {
    let PreparedInput {
        shape,
        expression,
        pose,
        neck,
        eye,
    } = prepare_input(input);

    let input_values = inputs![
        "shape" => shape.view(),
        "expression" => expression.view(),
        "pose" => pose.view(),
        "neck" => neck.view(),
        "eye" => eye.view(),
    ].expect("failed to build FLAME input tensors");

    // Keep the ort error in the panic message; discarding it with `.ok()`
    // would leave only an uninformative "unwrap on None" panic.
    let outputs = session
        .run(input_values)
        .expect("FLAME session run failed");

    let vertices: &ort::Value = outputs
        .get("vertices")
        .expect("FLAME model output is missing `vertices`");
    let landmarks: &ort::Value = outputs
        .get("landmarks")
        .expect("FLAME model output is missing `landmarks`");

    post_process(
        vertices,
        landmarks,
    )
}


/// The five FLAME parameter sets as owned 2-D `f32` arrays, ready to be
/// viewed as session inputs.
///
/// When built by `prepare_input`, each array has 8 rows and as many columns
/// as the matching `FlameInput` field has values per row.
pub struct PreparedInput {
/// Built from `FlameInput::shape` as an (8, 100) array.
pub shape: Array2<f32>,
/// Built from `FlameInput::pose` as an (8, 6) array.
pub pose: Array2<f32>,
/// Built from `FlameInput::expression` as an (8, 50) array.
pub expression: Array2<f32>,
/// Built from `FlameInput::neck` as an (8, 3) array.
pub neck: Array2<f32>,
/// Built from `FlameInput::eye` as an (8, 6) array.
pub eye: Array2<f32>,
}

pub fn prepare_input(
input: &FlameInput,
) -> PreparedInput {
let shape = Array2::from_shape_vec((8, 100), input.shape.concat()).unwrap();
let pose = Array2::from_shape_vec((8, 6), input.pose.concat()).unwrap();
let expression = Array2::from_shape_vec((8, 50), input.expression.concat()).unwrap();
let neck = Array2::from_shape_vec((8, 3), input.neck.concat()).unwrap();
let eye = Array2::from_shape_vec((8, 6), input.eye.concat()).unwrap();

PreparedInput {
shape,
expression,
pose,
neck,
eye,
}
}


/// Extracts the `f32` tensors behind `vertices` and `landmarks` and flattens
/// each into a list of 3-component points.
///
/// Both tensors are walked along their two leading axes; the first three
/// values of every innermost row form one point. Points of all batch entries
/// are concatenated in order.
///
/// # Panics
///
/// Panics if either value cannot be extracted as an `f32` tensor, or if an
/// innermost row has fewer than three values.
pub fn post_process(
    vertices: &ort::Value,
    landmarks: &ort::Value,
) -> FlameOutput {
    // Extract both tensors up front so extraction failures surface before
    // any point conversion starts.
    let vertex_tensor = vertices.try_extract_tensor::<f32>().unwrap();
    let vertex_view = vertex_tensor.view(); // [8, 5023, 3]

    let landmark_tensor = landmarks.try_extract_tensor::<f32>().unwrap();
    let landmark_view = landmark_tensor.view(); // [8, 68, 3]

    let mut vertex_points: Vec<[f32; 3]> = Vec::new();
    for batch in vertex_view.outer_iter() {
        for row in batch.outer_iter() {
            vertex_points.push([row[0], row[1], row[2]]);
        }
    }

    let mut landmark_points: Vec<[f32; 3]> = Vec::new();
    for batch in landmark_view.outer_iter() {
        for row in batch.outer_iter() {
            landmark_points.push([row[0], row[1], row[2]]);
        }
    }

    FlameOutput {
        vertices: vertex_points,
        landmarks: landmark_points,
    }
}
6 changes: 3 additions & 3 deletions src/models/lightglue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ pub fn post_process(
kpts1: &ort::Value,
matches: &ort::Value,
) -> Result<Vec<GluedPair>, &'static str> {
let kpts0_tensor = kpts0.extract_tensor::<i64>().unwrap();
let kpts0_tensor = kpts0.try_extract_tensor::<i64>().unwrap();
let kpts0_view = kpts0_tensor.view();

let kpts1_tensor = kpts1.extract_tensor::<i64>().unwrap();
let kpts1_tensor = kpts1.try_extract_tensor::<i64>().unwrap();
let kpts1_view = kpts1_tensor.view();

let matches = matches.extract_tensor::<i64>().unwrap();
let matches = matches.try_extract_tensor::<i64>().unwrap();
let matches_view = matches.view();

Ok(
Expand Down
3 changes: 3 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(feature = "flame")]
pub mod flame;

#[cfg(feature = "lightglue")]
pub mod lightglue;

Expand Down
2 changes: 1 addition & 1 deletion src/models/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub fn modnet_inference(
pub fn modnet_output_to_luma_images(
output_value: &ort::Value,
) -> Vec<Image> {
let tensor = output_value.extract_tensor::<f32>().unwrap();
let tensor = output_value.try_extract_tensor::<f32>().unwrap();
let data = tensor.view();

let shape = data.shape();
Expand Down
2 changes: 1 addition & 1 deletion src/models/yolo_v8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub fn process_output(
) -> Vec<BoundingBox> {
let mut boxes = Vec::new();

let tensor = output.extract_tensor::<f32>().unwrap();
let tensor = output.try_extract_tensor::<f32>().unwrap();
let data = tensor.view().t().into_owned();

for detection in data.axis_iter(Axis(0)) {
Expand Down
70 changes: 70 additions & 0 deletions tools/flame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use bevy::prelude::*;

use bevy_ort::{
BevyOrtPlugin,
models::flame::{
FlameInput,
FlameOutput,
flame_inference,
Flame,
FlamePlugin,
},
Onnx,
};


/// Entry point of the FLAME demo: sets up the Bevy app with the default,
/// ort and FLAME plugins, loads the model at startup and attempts inference
/// during `Update`.
fn main() {
    let mut app = App::new();

    app.add_plugins((
        DefaultPlugins,
        BevyOrtPlugin,
        FlamePlugin,
    ));
    app.add_systems(Startup, load_flame);
    app.add_systems(Update, inference);

    app.run();
}


/// Startup system: asks the asset server to load `models/flame.onnx` and
/// stores the resulting handle in the [`Flame`] resource.
fn load_flame(
    asset_server: Res<AssetServer>,
    mut flame: ResMut<Flame>,
) {
    flame.onnx = asset_server.load::<Onnx>("models/flame.onnx");
}


/// Update system: runs FLAME inference once, as soon as the ONNX asset is
/// available, then spawns a default 3D camera and marks itself complete.
///
/// `complete` is system-local state; once set, the system returns
/// immediately on every later run.
fn inference(
    mut commands: Commands,
    flame: Res<Flame>,
    onnx_assets: Res<Assets<Onnx>>,
    mut complete: Local<bool>,
) {
    if *complete {
        return;
    }

    // The asset is absent until it has finished loading. That is a normal
    // wait state, so return quietly and retry on the next run instead of
    // printing an error on every `Update` tick.
    let Some(onnx) = onnx_assets.get(&flame.onnx) else {
        return;
    };

    let flame_output: Result<FlameOutput, String> = (|| {
        let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
        let session = session_lock
            .as_ref()
            .ok_or("failed to get session from ONNX asset")?;

        Ok(flame_inference(
            session,
            &FlameInput::default(),
        ))
    })();

    match flame_output {
        Ok(_flame_output) => {
            // TODO: insert mesh
            // TODO: insert pan orbit camera
            commands.spawn(Camera3dBundle::default());
            *complete = true;
        }
        Err(e) => {
            eprintln!("inference failed: {}", e);
        }
    }
}

0 comments on commit 7eff3a9

Please sign in to comment.