Skip to content

Commit

Permalink
Merge pull request #1356 from sonos/fix-1345
Browse files Browse the repository at this point in the history
cast back to tdim in reduce<sum>
  • Loading branch information
kali authored Mar 22, 2024
2 parents e449faa + f5a4e47 commit 9180657
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion hir/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl Expansion for Reduce {
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let mut wire = inputs[0];
let fact = target.outlet_fact(wire)?;
let fact = target.outlet_fact(wire)?.clone();
let mut axes = self.resolve_axes(fact.rank())?;
axes.sort();
if fact.datum_type == TDim::datum_type() {
Expand All @@ -222,6 +222,13 @@ impl Expansion for Reduce {
)?[0];
}
wire = self.reducer.wire(axes.clone(), name, target, wire).context("wiring reducer")?;
if fact.datum_type == TDim::datum_type() {
wire = target.wire_node(
format!("{name}.cast_to_tdim"),
tract_core::ops::cast::cast(TDim::datum_type()),
&[wire],
)?[0];
}
if !self.keep_dims {
for axis in axes.into_iter().rev() {
wire = target.wire_node(
Expand Down

0 comments on commit 9180657

Please sign in to comment.