Skip to content

Commit

Permalink
Remove if branches in arange | fix(torchlib) (#1097)
Browse files Browse the repository at this point in the history
Replace if with where for performance.
  • Loading branch information
justinchuby authored Oct 23, 2023
1 parent 778e8e8 commit b0a0b87
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,8 @@ def _adjust_args_for_arange_int_dtype(
end = op.Cast(end, to=FLOAT.dtype)
step = op.Cast(step, to=FLOAT.dtype)

if start < zero:
start = op.Ceil(start)

if step < zero:
start = op.Floor(start)
start = op.Where(op.Less(start, zero), op.Ceil(start), start)
start = op.Where(op.Less(step, zero), op.Floor(start), start)

return (start, end, step)

Expand Down

0 comments on commit b0a0b87

Please sign in to comment.