Skip to content

Commit

Permalink
fix incomplete type analyse for Split13
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jan 22, 2024
1 parent 6a4f610 commit e9d7b8a
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions onnx/src/ops/array/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::pb::*;
pub fn split(
ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let axis = node.get_attr_opt("axis")?.unwrap_or(0);
if ctx.onnx_operator_set_version < 13 || node.input.len() == 1 {
let split = node.get_attr_opt_vec("split")?;
Expand All @@ -23,8 +23,6 @@ struct Split13 {
outputs: usize,
}



impl Expansion for Split13 {
fn name(&self) -> Cow<str> {
"Split13".into()
Expand All @@ -35,8 +33,23 @@ impl Expansion for Split13 {
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult {
) -> InferenceResult {
check_input_arity(inputs, 2)?;
for o in outputs {
s.equals(&inputs[0].rank, &o.rank)?;
s.equals(&inputs[0].datum_type, &o.datum_type)?;
}
s.given(&inputs[0].rank, move |s, rank| {
let axis = (self.axis + if self.axis < 0 { rank as isize } else { 0 }) as usize;
for a in 0..rank as usize {
if a != axis {
for o in outputs {
s.equals(&inputs[0].shape[a], &o.shape[a])?;
}
}
}
Ok(())
})?;
s.given_2(&inputs[0].shape, &inputs[1].value, move |s, shape, splits| {
let splits = splits.cast_to::<TDim>()?;
let splits = splits.as_slice::<TDim>()?;
Expand All @@ -53,9 +66,10 @@ impl Expansion for Split13 {
prefix: &str,
model: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
) -> TractResult<TVec<OutletId>> {
if let Some(splits) = model.outlet_fact(inputs[1])?.konst.as_ref() {
let axis = self.axis + if self.axis < 0 { model.outlet_fact(inputs[0])?.rank() as isize } else { 0 };
let axis = self.axis
+ if self.axis < 0 { model.outlet_fact(inputs[0])?.rank() as isize } else { 0 };
let splits = splits.cast_to::<i64>()?;
let splits = splits.as_slice::<i64>()?.iter().map(|i| *i as usize).collect::<Vec<_>>();
let op = tract_hir::ops::array::Split::new(axis, splits.len(), Some(splits));
Expand All @@ -67,5 +81,4 @@ impl Expansion for Split13 {
fn nboutputs(&self) -> TractResult<usize> {
Ok(self.outputs)
}

}

0 comments on commit e9d7b8a

Please sign in to comment.