From b1c73c47ae322b211e32389aa0e1733fcfeece1a Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Wed, 11 Dec 2024 17:31:34 +0100 Subject: [PATCH] Range on TDIM output i64 instead of TDIM --- core/src/ops/array/range.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/ops/array/range.rs b/core/src/ops/array/range.rs index 5d9cae5327..7d33b7ad80 100644 --- a/core/src/ops/array/range.rs +++ b/core/src/ops/array/range.rs @@ -59,14 +59,14 @@ impl Range { values: &SymbolValues, ) -> TractResult { if start.datum_type() == TDim::datum_type() { + let start = start.to_scalar::()?.eval(values).to_i64()?; + let step = step.to_scalar::()?.eval(values).to_i64()?; let len = { - let start = start.to_scalar::()?.eval(values).to_i64()?; let end = end.to_scalar::()?.eval(values).to_i64()?; - let step = step.to_scalar::()?.eval(values).to_i64()?; #[allow(clippy::cast_abs_to_unsigned)] ((end - start).abs() as usize).divceil(step.abs() as usize) }; - Self::make_t::(start, step, len) + Self::make_t::(&tensor0(start), &tensor0(step), len) } else { let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())( self, start, end, step @@ -116,22 +116,23 @@ impl TypedOp for Range { ensure!(end.shape.volume().is_one()); ensure!(step.shape.volume().is_one()); if let (Some(start), Some(end), Some(step)) = (&start.konst, &end.konst, &step.konst) { - let len = if start.datum_type() == TDim::datum_type() { + if start.datum_type() == TDim::datum_type() { let start = start.to_scalar::()?; let end = end.to_scalar::()?; let step = step.cast_to_scalar::()?; - if step < 0 { + let len = if step < 0 { (start.clone() - end).divceil(-step as usize) } else { (end.clone() - start).divceil(step as usize) - } + }; + Ok(tvec!(DatumType::I64.fact([len]))) } else { - dispatch_numbers!(Self::len_for_numbers(start.datum_type())( + let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())( self, start, end, step ))? - .to_dim() - }; - Ok(tvec!(start.datum_type().fact([len]))) + .to_dim(); + Ok(tvec!(start.datum_type().fact([len]))) + } } else { Ok(tvec!(start.datum_type.fact(&[self.len.clone()]))) }