Skip to content

Commit

Permalink
fix: flame batch size constant
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed May 5, 2024
1 parent 153a0f0 commit c13bed9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.11.0"
version = "0.11.1"
edition = "2021"
authors = ["mosure <[email protected]>"]
license = "MIT"
Expand Down
41 changes: 18 additions & 23 deletions src/models/flame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,34 @@ fn flame_inference_system(
}


const FLAME_BATCH_SIZE: usize = 1;

#[derive(
Debug,
Clone,
Component,
Reflect,
)]
pub struct FlameInput {
pub shape: [[f32; 100]; 8],
pub pose: [[f32; 6]; 8],
pub expression: [[f32; 50]; 8],
pub neck: [[f32; 3]; 8],
pub eye: [[f32; 6]; 8],
pub shape: [[f32; 100]; FLAME_BATCH_SIZE],
pub pose: [[f32; 6]; FLAME_BATCH_SIZE],
pub expression: [[f32; 50]; FLAME_BATCH_SIZE],
pub neck: [[f32; 3]; FLAME_BATCH_SIZE],
pub eye: [[f32; 6]; FLAME_BATCH_SIZE],
}

impl Default for FlameInput {
fn default() -> Self {
let radian = std::f32::consts::PI / 180.0;

Self {
shape: [[0.0; 100]; 8],
shape: [[0.0; 100]; FLAME_BATCH_SIZE],
pose: [
[0.0, 30.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, -30.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, 85.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, -48.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, 10.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, -15.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0 * radian, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0 * radian, 0.0, 0.0, 0.0, 0.0],
],
expression: [[0.0; 50]; 8],
neck: [[0.0; 3]; 8],
eye: [[0.0; 6]; 8],
expression: [[0.0; 50]; FLAME_BATCH_SIZE],
neck: [[0.0; 3]; FLAME_BATCH_SIZE],
eye: [[0.0; 6]; FLAME_BATCH_SIZE],
}
}
}
Expand Down Expand Up @@ -160,11 +155,11 @@ pub struct PreparedInput {
pub fn prepare_input(
input: &FlameInput,
) -> PreparedInput {
let shape = Array2::from_shape_vec((8, 100), input.shape.concat()).unwrap();
let pose = Array2::from_shape_vec((8, 6), input.pose.concat()).unwrap();
let expression = Array2::from_shape_vec((8, 50), input.expression.concat()).unwrap();
let neck = Array2::from_shape_vec((8, 3), input.neck.concat()).unwrap();
let eye = Array2::from_shape_vec((8, 6), input.eye.concat()).unwrap();
let shape = Array2::from_shape_vec((FLAME_BATCH_SIZE, 100), input.shape.concat()).unwrap();
let pose = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.pose.concat()).unwrap();
let expression = Array2::from_shape_vec((FLAME_BATCH_SIZE, 50), input.expression.concat()).unwrap();
let neck = Array2::from_shape_vec((FLAME_BATCH_SIZE, 3), input.neck.concat()).unwrap();
let eye = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.eye.concat()).unwrap();

PreparedInput {
shape,
Expand All @@ -181,10 +176,10 @@ pub fn post_process(
landmarks: &ort::Value,
) -> FlameOutput {
let vertices_tensor = vertices.try_extract_tensor::<f32>().unwrap();
let vertices_view = vertices_tensor.view(); // [8, 5023, 3]
let vertices_view = vertices_tensor.view(); // [FLAME_BATCH_SIZE, 5023, 3]

let landmarks_tensor = landmarks.try_extract_tensor::<f32>().unwrap();
let landmarks_view = landmarks_tensor.view(); // [8, 68, 3]
let landmarks_view = landmarks_tensor.view(); // [FLAME_BATCH_SIZE, 68, 3]

let vertices = vertices_view.outer_iter()
.flat_map(|subtensor| {
Expand Down

0 comments on commit c13bed9

Please sign in to comment.