Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onnx-ignore-output-shapes on by default #1188

Merged
merged 2 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unreleased
* ONNX ignoring output shapes is now the default

# 0.20.18 - 2023-08-30
* [intel] fix in AVX512F matrix vector product
* [tflite] alpha, embryonic support. some convolutional models working.
Expand Down
5 changes: 0 additions & 5 deletions api/py/tests/mobilenet_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def test_inference_model():
def test_set_output_names_on_inference_model():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
model.set_output_names(["mobilenetv20_output_pred_fwd"])
assert str(model.output_fact(0)) == "B,1000,1,1,F32"
Expand All @@ -100,7 +99,6 @@ def test_set_output_names():
def test_concretize():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
typed = model.into_typed().into_decluttered()
assert str(typed.input_fact(0)) == "B,3,224,224,F32"
Expand All @@ -112,7 +110,6 @@ def test_concretize():
def test_pulse():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
typed = model.into_typed().into_decluttered()
assert str(typed.input_fact(0)) == "B,3,224,224,F32"
Expand All @@ -128,7 +125,6 @@ def test_pulse():
def test_half():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "1,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
typed = model.into_typed().into_decluttered()
typed.half()
Expand All @@ -138,7 +134,6 @@ def test_half():
def test_typed_model_to_nnef_and_back():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
typed = model.into_typed()
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down
4 changes: 0 additions & 4 deletions api/tests/mobilenet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ fn test_concretize() -> anyhow::Result<()> {
ensure_models()?;
let mut model = onnx()?.model_for_path("mobilenetv2-7.onnx")?;
model.set_input_fact(0, "B,3,224,224,f32")?;
model.set_output_fact(0, None)?;
model.analyse()?;
let mut typed = model.into_typed()?.into_decluttered()?;
assert_eq!(typed.input_fact(0)?.to_string(), "B,3,224,224,F32");
Expand All @@ -166,7 +165,6 @@ fn test_pulse() -> anyhow::Result<()> {
ensure_models()?;
let mut model = onnx()?.model_for_path("mobilenetv2-7.onnx")?;
model.set_input_fact(0, "B,3,224,224,f32")?;
model.set_output_fact(0, None)?;
model.analyse()?;
let mut typed = model.into_typed()?.into_decluttered()?;
assert_eq!(typed.input_fact(0)?.to_string(), "B,3,224,224,F32");
Expand All @@ -186,7 +184,6 @@ fn test_half() -> anyhow::Result<()> {
ensure_models()?;
let mut model = onnx()?.model_for_path("mobilenetv2-7.onnx")?;
model.set_input_fact(0, "B,3,224,224,f32")?;
model.set_output_fact(0, None)?;
model.analyse()?;
let mut typed = model.into_typed()?.into_decluttered()?;
typed.half()?;
Expand All @@ -200,7 +197,6 @@ fn test_typed_model_to_nnef_and_back() -> anyhow::Result<()> {
ensure_models()?;
let mut model = onnx()?.model_for_path("mobilenetv2-7.onnx")?;
model.set_input_fact(0, "B,3,224,224,f32")?;
model.set_output_fact(0, None)?;
model.analyse()?;
let typed = model.into_typed()?;
let dir = tempfile::tempdir()?;
Expand Down
Binary file modified examples/keras-tract-tf2/example.onnx
Binary file not shown.
Binary file modified examples/keras-tract-tf2/io.npz
Binary file not shown.
2 changes: 0 additions & 2 deletions examples/keras-tract-tf2/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ fn main() -> TractResult<()> {
let model = tract_onnx::onnx()
// load the model
.model_for_path("example.onnx")?
// kill over-spcified output fact in ONNX
.with_output_fact(0, Default::default())?
// optimize graph
.into_optimized()?
// make the model runnable and fix its inputs and outputs
Expand Down
6 changes: 3 additions & 3 deletions onnx/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ impl<'a> ParsingContext<'a> {
let mut outputs = vec![];
for output in graph.output.iter() {
let mut fact = InferenceFact::default();
if !self.framework.ignore_output_shapes {
if self.framework.use_output_shapes {
if let Some(f) = output.r#type.as_ref().and_then(|t| t.value.as_ref()) {
let pb::type_proto::Value::TensorType(f) = f;
fact = translate_inference_fact(&ctx, f)?
Expand Down Expand Up @@ -222,7 +222,7 @@ impl OnnxOpRegister {
#[derive(Clone, Default)]
pub struct Onnx {
pub op_register: OnnxOpRegister,
pub ignore_output_shapes: bool,
pub use_output_shapes: bool,
pub ignore_output_types: bool,
}

Expand Down Expand Up @@ -263,7 +263,7 @@ impl Onnx {
}

pub fn with_ignore_output_shapes(self, ignore: bool) -> Onnx {
Self { ignore_output_shapes: ignore, ..self }
Self { use_output_shapes: !ignore, ..self }
}

pub fn with_ignore_output_types(self, ignore: bool) -> Onnx {
Expand Down