Skip to content

Commit

Permalink
Range on TDIM output i64 instead of TDIM
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 11, 2024
1 parent e78057f commit b1c73c4
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions core/src/ops/array/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ impl Range {
values: &SymbolValues,
) -> TractResult<Tensor> {
if start.datum_type() == TDim::datum_type() {
let start = start.to_scalar::<TDim>()?.eval(values).to_i64()?;
let step = step.to_scalar::<TDim>()?.eval(values).to_i64()?;
let len = {
let start = start.to_scalar::<TDim>()?.eval(values).to_i64()?;
let end = end.to_scalar::<TDim>()?.eval(values).to_i64()?;
let step = step.to_scalar::<TDim>()?.eval(values).to_i64()?;
#[allow(clippy::cast_abs_to_unsigned)]
((end - start).abs() as usize).divceil(step.abs() as usize)
};
Self::make_t::<TDim>(start, step, len)
Self::make_t::<i64>(&tensor0(start), &tensor0(step), len)
} else {
let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
self, start, end, step
Expand Down Expand Up @@ -116,22 +116,23 @@ impl TypedOp for Range {
ensure!(end.shape.volume().is_one());
ensure!(step.shape.volume().is_one());
if let (Some(start), Some(end), Some(step)) = (&start.konst, &end.konst, &step.konst) {
let len = if start.datum_type() == TDim::datum_type() {
if start.datum_type() == TDim::datum_type() {
let start = start.to_scalar::<TDim>()?;
let end = end.to_scalar::<TDim>()?;
let step = step.cast_to_scalar::<i64>()?;
if step < 0 {
let len = if step < 0 {
(start.clone() - end).divceil(-step as usize)
} else {
(end.clone() - start).divceil(step as usize)
}
};
Ok(tvec!(DatumType::I64.fact([len])))
} else {
dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
self, start, end, step
))?
.to_dim()
};
Ok(tvec!(start.datum_type().fact([len])))
.to_dim();
Ok(tvec!(start.datum_type().fact([len])))
}
} else {
Ok(tvec!(start.datum_type.fact(&[self.len.clone()])))
}
Expand Down

0 comments on commit b1c73c4

Please sign in to comment.