Skip to content

Commit

Permalink
fill
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 22, 2024
1 parent b7a8105 commit 4130673
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3490,14 +3490,13 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType


@torch_op(("aten::fill.Tensor", "aten::fill.Sclaar"))
def aten_fill(self: TTensor, value: TTensor) -> TTensor:
def aten_fill(self: TTensor, value: TTensor2) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""

# after fill, the self Tensor should keep origianl type
# Cast the value before Expand so it can be constant folded
value = op.CastLike(value, self)
shape = op.Shape(self)
expanded = op.Expand(value, shape)
result = op.CastLike(expanded, self)
return result
return op.Expand(value, shape)


def aten_fix(self: TensorType) -> TensorType:
Expand Down

0 comments on commit 4130673

Please sign in to comment.