From f5a4e472c872d767379d116c35980846d5f5f291 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 21 Mar 2024 17:43:05 +0100 Subject: [PATCH] cast back to tdim in reduce --- hir/src/ops/nn/reduce.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hir/src/ops/nn/reduce.rs b/hir/src/ops/nn/reduce.rs index 149431b40c..10054262d0 100644 --- a/hir/src/ops/nn/reduce.rs +++ b/hir/src/ops/nn/reduce.rs @@ -211,7 +211,7 @@ impl Expansion for Reduce { inputs: &[OutletId], ) -> TractResult> { let mut wire = inputs[0]; - let fact = target.outlet_fact(wire)?; + let fact = target.outlet_fact(wire)?.clone(); let mut axes = self.resolve_axes(fact.rank())?; axes.sort(); if fact.datum_type == TDim::datum_type() { @@ -222,6 +222,13 @@ impl Expansion for Reduce { )?[0]; } wire = self.reducer.wire(axes.clone(), name, target, wire).context("wiring reducer")?; + if fact.datum_type == TDim::datum_type() { + wire = target.wire_node( + format!("{name}.cast_to_tdim"), + tract_core::ops::cast::cast(TDim::datum_type()), + &[wire], + )?[0]; + } if !self.keep_dims { for axis in axes.into_iter().rev() { wire = target.wire_node(