From e114c8c039886dcd2663fb7c9118f0421edeef64 Mon Sep 17 00:00:00 2001 From: "Christopher L. Crutchfield" Date: Wed, 19 Jun 2024 10:46:22 -0700 Subject: [PATCH] fix: undo additional error checking. this will be pushed into the iphone app instead. --- src/fish/fish_segmentation.rs | 61 +++++++---------------------------- 1 file changed, 12 insertions(+), 49 deletions(-) diff --git a/src/fish/fish_segmentation.rs b/src/fish/fish_segmentation.rs index ee02517..33ef84c 100644 --- a/src/fish/fish_segmentation.rs +++ b/src/fish/fish_segmentation.rs @@ -10,7 +10,7 @@ use ndarray::{s, Array2, Array3, ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr}; use opencv::core::{Mat, Point2i, Size, VectorToVec, CV_8UC1}; use opencv::imgproc::{fill_poly, find_contours_with_hierarchy, resize_def, CHAIN_APPROX_NONE, LINE_8, RETR_CCOMP}; use opencv::types::{VectorOfPoint, VectorOfVec4i, VectorOfVectorOfPoint}; -use ort::{GraphOptimizationLevel, Session, SessionOutputs}; +use ort::{GraphOptimizationLevel, Session}; use reqwest::blocking::get; use cv_convert::TryIntoCv; @@ -275,8 +275,17 @@ impl FishSegmentation { } } - fn parse_results_from_inference(&self, outputs: SessionOutputs) -> + fn do_inference(&self, img: &Array3, model: &Session) -> Result<(ArrayBase, Dim>, ArrayBase, Dim>, ArrayBase, Dim>), ort::Error> { + let mut clone = img.clone(); + clone.swap_axes(2, 1); + clone.swap_axes(1, 0); + + println!("RUST: clone: {}, {}, {}", clone.shape()[0], clone.shape()[1], clone.shape()[2]); + + println!("RUST: Before Run"); + let outputs = model.run(ort::inputs!["argument_1.1" => clone.view()]?)?; + println!("RUST: After Run"); // boxes=tensor18, classes=pred_classes, masks=5232, scores=2339, img_size=onnx::Split_174 println!("RUST: Before parsing results"); @@ -292,45 +301,6 @@ impl FishSegmentation { Ok((boxes, masks, scores)) } - fn do_inference(&self, img: &Array3, model: &Session) -> - Result<(ArrayBase, Dim>, ArrayBase, Dim>, ArrayBase, Dim>), SegmentationError> { - let mut clone = img.clone(); - clone.swap_axes(2, 1); - clone.swap_axes(1, 0); - - println!("RUST: clone: {}, {}, {}", clone.shape()[0], clone.shape()[1], clone.shape()[2]); - - println!("RUST: Before Run"); - match ort::inputs!["argument_1.1" => clone.view()] { - Ok(input) => { - match model.run(input) { - Ok(outputs) => { - println!("RUST: After Run"); - - match self.parse_results_from_inference(outputs) { - Ok((boxes, masks, scores)) => Ok((boxes, masks, scores)), - Err(err) => { - match err { - ort::Error::PointerShouldBeNull(err) => { - if err == "GetTensorMutableData" { - Err(SegmentationError::FishNotFound) - } - else { - Err(SegmentationError::OrtErr(ort::Error::PointerShouldBeNull(err))) - } - }, - other => Err(SegmentationError::OrtErr(other)) - } - } - } - }, - Err(err) => Err(SegmentationError::OrtErr(err)) - } - }, - Err(err) => Err(SegmentationError::OrtErr(err)) - } - } - fn do_paste_mask(&self, masks: &Array2, img_h: u32, img_w: u32) -> Result, SegmentationError> { let masks_unsqueezed = masks.clone().insert_axis(Axis(2)); @@ -513,14 +483,7 @@ impl FishSegmentation { Ok(masks) } - Err(error) => { - match error { - SegmentationError::FishNotFound => { - Ok(Array2::::zeros((orig_height, orig_width))) - }, - other => Err(other) - } - } + Err(error) => Err(SegmentationError::OrtErr(error)) } } }