Skip to content

Commit

Permalink
fix for ? in onnx shapes specs
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 22, 2024
1 parent dae9667 commit 7dfcdc9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions onnx/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use std::collections::HashMap;
use tract_hir::internal::*;
use tract_hir::prelude::tract_itertools::Itertools;

use crate::data_resolver::{self, ModelDataResolver};
use crate::pb::type_proto::Value;
use crate::pb::{self, TensorProto, TypeProto};
use crate::data_resolver::{self, ModelDataResolver};
use crate::tensor::{load_tensor, translate_inference_fact};
use prost::Message;

Expand Down Expand Up @@ -88,7 +88,8 @@ impl<'a> ParsingContext<'a> {
let fact = input.r#type.as_ref().unwrap().value.as_ref().unwrap();
#[allow(irrefutable_let_patterns)]
let fact: InferenceFact = if let pb::type_proto::Value::TensorType(fact) = fact {
translate_inference_fact(&ctx, fact, true)?
translate_inference_fact(&ctx, fact, true)
.with_context(|| format!("translating to fact: {:?}", fact))?
} else {
bail!("Can not parse tensor type");
};
Expand Down
7 changes: 4 additions & 3 deletions onnx/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::data_resolver::ModelDataResolver;
use crate::model::ParsingContext;
use crate::pb::tensor_proto::DataType;
use crate::pb::*;
use crate::data_resolver::ModelDataResolver;
use prost::Message;
use std::convert::{TryFrom, TryInto};
use std::path::PathBuf;
Expand Down Expand Up @@ -46,10 +46,11 @@ pub fn translate_inference_fact(
Ok(DimFact::from(v.to_dim()))
}
Some(tensor_shape_proto::dimension::Value::DimParam(v)) => {
if v.starts_with("unk__") && !include_unknown_symbols {
if v == "?" || (v.starts_with("unk__") && !include_unknown_symbols) {
Ok(DimFact::default())
} else {
let dim = parse_tdim(&ctx.symbol_table, v)?;
let dim = parse_tdim(&ctx.symbol_table, v)
.with_context(|| format!("Parsing as TDim: `{v}'"))?;
Ok(DimFact::from(dim))
}
}
Expand Down

0 comments on commit 7dfcdc9

Please sign in to comment.