Skip to content

Commit

Permalink
fix: undo additional error checking. this will be pushed into the iph…
Browse files Browse the repository at this point in the history
…one app instead.
  • Loading branch information
ccrutchf committed Jun 19, 2024
1 parent 51b6cda commit e114c8c
Showing 1 changed file with 12 additions and 49 deletions.
61 changes: 12 additions & 49 deletions src/fish/fish_segmentation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use ndarray::{s, Array2, Array3, ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr};
use opencv::core::{Mat, Point2i, Size, VectorToVec, CV_8UC1};
use opencv::imgproc::{fill_poly, find_contours_with_hierarchy, resize_def, CHAIN_APPROX_NONE, LINE_8, RETR_CCOMP};
use opencv::types::{VectorOfPoint, VectorOfVec4i, VectorOfVectorOfPoint};
use ort::{GraphOptimizationLevel, Session, SessionOutputs};
use ort::{GraphOptimizationLevel, Session};
use reqwest::blocking::get;
use cv_convert::TryIntoCv;

Expand Down Expand Up @@ -275,8 +275,17 @@ impl FishSegmentation {
}
}

fn parse_results_from_inference(&self, outputs: SessionOutputs) ->
fn do_inference(&self, img: &Array3<f32>, model: &Session) ->
Result<(ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>), ort::Error> {
let mut clone = img.clone();
clone.swap_axes(2, 1);
clone.swap_axes(1, 0);

println!("RUST: clone: {}, {}, {}", clone.shape()[0], clone.shape()[1], clone.shape()[2]);

println!("RUST: Before Run");
let outputs = model.run(ort::inputs!["argument_1.1" => clone.view()]?)?;
println!("RUST: After Run");

// boxes=tensor18, classes=pred_classes, masks=5232, scores=2339, img_size=onnx::Split_174
println!("RUST: Before parsing results");
Expand All @@ -292,45 +301,6 @@ impl FishSegmentation {
Ok((boxes, masks, scores))
}

fn do_inference(&self, img: &Array3<f32>, model: &Session) ->
Result<(ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>), SegmentationError> {
let mut clone = img.clone();
clone.swap_axes(2, 1);
clone.swap_axes(1, 0);

println!("RUST: clone: {}, {}, {}", clone.shape()[0], clone.shape()[1], clone.shape()[2]);

println!("RUST: Before Run");
match ort::inputs!["argument_1.1" => clone.view()] {
Ok(input) => {
match model.run(input) {
Ok(outputs) => {
println!("RUST: After Run");

match self.parse_results_from_inference(outputs) {
Ok((boxes, masks, scores)) => Ok((boxes, masks, scores)),
Err(err) => {
match err {
ort::Error::PointerShouldBeNull(err) => {
if err == "GetTensorMutableData" {
Err(SegmentationError::FishNotFound)
}
else {
Err(SegmentationError::OrtErr(ort::Error::PointerShouldBeNull(err)))
}
},
other => Err(SegmentationError::OrtErr(other))
}
}
}
},
Err(err) => Err(SegmentationError::OrtErr(err))
}
},
Err(err) => Err(SegmentationError::OrtErr(err))
}
}

fn do_paste_mask(&self, masks: &Array2<f32>, img_h: u32, img_w: u32) -> Result<Array2<f32>, SegmentationError> {
let masks_unsqueezed = masks.clone().insert_axis(Axis(2));

Expand Down Expand Up @@ -513,14 +483,7 @@ impl FishSegmentation {

Ok(masks)
}
Err(error) => {
match error {
SegmentationError::FishNotFound => {
Ok(Array2::<u8>::zeros((orig_height, orig_width)))
},
other => Err(other)
}
}
Err(error) => Err(SegmentationError::OrtErr(error))
}
}
}
Expand Down

0 comments on commit e114c8c

Please sign in to comment.