
Stablehlo Assertion `!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"' failed. #1235

Open
uazizTT opened this issue Nov 12, 2024 · 1 comment · May be fixed by #1252
uazizTT commented Nov 12, 2024

Example graph:

%c = stablehlo.constant dense<1> : tensor<19xi64>
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<19xi64>) -> tensor<19xi64>
%1 = stablehlo.add %0, %c : tensor<19xi64>
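Note that the `stablehlo.broadcast_in_dim` in this graph is an identity: the operand and result are both `tensor<19xi64>` and the dims mapping is the trivial `[0]`. A hedged, pure-Python sketch (the helper name is hypothetical, not part of tt-mlir) of detecting such a redundant broadcast:

```python
# Hypothetical helper (illustration only, not tt-mlir code): returns True when a
# stablehlo.broadcast_in_dim describes an identity mapping, i.e. the operand
# shape equals the result shape and broadcast_dimensions is [0, 1, ..., rank-1].
def is_identity_broadcast(operand_shape, dims, result_shape):
    if tuple(operand_shape) != tuple(result_shape):
        return False
    # An identity broadcast maps operand dimension i to result dimension i.
    return list(dims) == list(range(len(result_shape)))

# The broadcast in the graph above: (19,) -> (19,) with dims = [0].
print(is_identity_broadcast((19,), [0], (19,)))  # True: the op is redundant
```

A conversion pass that folds such a no-op broadcast while another pattern still holds a reference to it is one plausible way to hit the "attempting to modify a replaced/erased op" assertion.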

Error dump:

ttmlir-opt: /localdev/aknezevic/tt-mlir/env/build/llvm-project-prefix/src/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1589: virtual void mlir::ConversionPatternRewriter::startOpModification(mlir::Operation *): Assertion `!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"' failed.
@uazizTT uazizTT added the stablehlo conversion bug Bugs in StableHLO conversion label Nov 12, 2024
@uazizTT uazizTT added this to the [Third Party] HLO + XLA milestone Nov 12, 2024
@uazizTT uazizTT self-assigned this Nov 12, 2024
@mmanzoorTT commented:
The following PyTorch example generates the same error:

import torch
import torch.nn as nn

# `backend` is the tt-mlir torch.compile backend, defined elsewhere in the test harness.

def test_broadcast_add():
  class Basic(nn.Module):
    def __init__(self):
      super().__init__()

    def forward(self, x, y):
      return x + y

  input1 = torch.tensor([1], dtype=torch.float32)
  input2 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32)

  tt_mod = torch.compile(Basic(), backend=backend)
  out = tt_mod(input1, input2)
  print(f"output: {out}")

Generated StableHLO graph:

module {
  func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> {
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<8xf32>) -> tensor<8xf32>
    %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<8xf32>
    %2 = stablehlo.add %0, %1 : tensor<8xf32>
    return %2 : tensor<8xf32>
  }
}

PyTorch broadcasts both operands, although there is no need to broadcast %arg0 (its result shape is the same as its input shape).
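The result shapes follow the NumPy-style broadcasting rule that PyTorch uses: trailing dimensions are aligned, and a size-1 dimension stretches to match. A minimal pure-Python sketch of the rule (an illustration, not the actual PyTorch or StableHLO implementation):

```python
# Minimal sketch of NumPy/PyTorch-style shape broadcasting:
# align trailing dimensions; a size-1 dimension stretches to match the other.
def broadcast_shapes(a, b):
    a, b = tuple(a), tuple(b)
    # Left-pad the shorter shape with 1s so both shapes have equal rank.
    a = (1,) * (len(b) - len(a)) + a
    b = (1,) * (len(a) - len(b)) + b
    result = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"incompatible shapes {a} and {b}")
        result.append(max(x, y))
    return tuple(result)

# input1 has shape (1,), input2 has shape (8,):
print(broadcast_shapes((1,), (8,)))  # (8,): only %arg1 actually needs a broadcast
print(broadcast_shapes((8,), (8,)))  # (8,): %arg0's broadcast is an identity op
```

This is why the emitted graph contains two `broadcast_in_dim` ops even though only the `tensor<1xf32>` operand changes shape.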
