Data Type Mismatch Error in to_edge Function #7112

Open
99004327-Sourabh opened this issue Nov 27, 2024 · 1 comment
Labels: bug (Something isn't working); module: exir (Issues related to Export IR)

Comments

99004327-Sourabh commented Nov 27, 2024

🐛 Describe the bug

Bug Report: Data Type Mismatch Error in to_edge Function

Description

When converting an exported PyTorch model from the ATen dialect to the Edge dialect with the to_edge function, the verifier raises a SpecViolationError reporting that an operator receives Tensor inputs with mismatched dtypes.

Steps to Reproduce

  1. Load a pre-trained PyTorch model.
  2. Export the model to the ATen dialect using the export function.
  3. Attempt to convert the ATen dialect model to the Edge dialect using the to_edge function.

Sample Code

import torch
from torch import nn
from executorch.exir import to_edge
from torch.export import export

# Create a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.encoder = nn.Linear(80, 64)  # Example simple layer

    def forward(self, x):
        return self.encoder(x)

# Initialize the simple model
model = SimpleModel()
model.eval()

# Define custom example inputs (using random tensors)
example_inputs = (
    torch.randn(1, 80),  # Input tensor of shape (1, 80)
    torch.tensor([72], dtype=torch.float32),  # Length tensor
    torch.randn(12, 1, 25, 512),  # Example additional input tensor
    torch.randn(12, 1, 512, 15),  # Another example tensor
    torch.tensor([15], dtype=torch.float32)  # Another length tensor
)

# Export model to ATen dialect
___audio_time = torch.export.Dim('___audio_time', min=128, max=256)
dynamic_shapes = {
    "audio_signal": {2: 4 * ___audio_time},
    "length": {0: torch.export.Dim.STATIC},
    "cache_last_channel": {},
    "cache_last_time": {},
    "cache_last_channel_len": {}
}

# Perform the model export
aten_model = export(model, example_inputs, dynamic_shapes=dynamic_shapes)

# Attempt to convert ATen dialect model to Edge dialect
edge_model = to_edge(aten_model)
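
Note on the reduced sample above: SimpleModel.forward takes a single argument x, while example_inputs holds five tensors and dynamic_shapes is keyed by names (audio_signal, length, ...) that appear to come from the original model's signature. As far as I understand, torch.export matches dict-style dynamic_shapes keys against the parameter names of forward(), so the sample may fail at export before to_edge is reached. The snippet below is a minimal sketch of a sanity check for this; check_export_args is a hypothetical helper, not part of torch.export or ExecuTorch, and it uses placeholders instead of real tensors.

```python
import inspect


def check_export_args(forward, example_inputs, dynamic_shapes):
    """Hypothetical helper (not part of torch.export): list argument mismatches."""
    # Positional parameters of forward(), ignoring `self`.
    params = [
        name
        for name, p in inspect.signature(forward).parameters.items()
        if name != "self"
        and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    problems = []
    if len(example_inputs) != len(params):
        problems.append(
            f"forward() takes {len(params)} input(s), "
            f"got {len(example_inputs)} example input(s)"
        )
    unknown = sorted(set(dynamic_shapes) - set(params))
    if unknown:
        problems.append(f"dynamic_shapes keys not in forward(): {unknown}")
    return problems


# Stand-in with the same signature as SimpleModel.forward above.
def forward(self, x):
    return x


# Placeholders stand in for the five tensors; only names and counts matter here.
example_inputs = (None,) * 5
dynamic_shapes = {
    "audio_signal": {}, "length": {}, "cache_last_channel": {},
    "cache_last_time": {}, "cache_last_channel_len": {},
}
problems = check_export_args(forward, example_inputs, dynamic_shapes)
for problem in problems:
    print(problem)
```

The check deliberately ignores default values and keyword-only parameters, so treat it as a quick diagnostic for a reduced repro rather than a full validation of export arguments.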

Error Message

torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes: defaultdict(<class 'dict'>, {<EdgeOpOverload: aten.div.Tensor_mode>: schema = aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor: {'self': torch.int32, 'other': torch.int64, '__ret_0': torch.int32}})
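
The dtype triple in the message (self: int32, other: int64, result: int32) is consistent with PyTorch's type promotion rules when `other` is a 0-dim tensor: a 0-dim tensor does not promote a dimensioned tensor of the same category, so the result stays int32 in eager mode even though the input dtypes differ. The following is a minimal sketch of that behavior in plain PyTorch. The tensor names and values are illustrative assumptions, not taken from the original model, and whether the explicit cast also satisfies the Edge dialect verifier for the real model has not been confirmed here.

```python
import torch

# Illustrative values only; the real model's tensors are not shown in this report.
lengths = torch.tensor([72], dtype=torch.int32)  # 1-D int32 tensor
factor = torch.tensor(4)                         # 0-dim tensor, int64 by default

# Eager PyTorch accepts the mixed dtypes; the 0-dim int64 operand does not
# promote the int32 result.
out = torch.div(lengths, factor, rounding_mode="floor")
print(lengths.dtype, factor.dtype, out.dtype)  # torch.int32 torch.int64 torch.int32

# Possible workaround (assumption): cast `other` so both operands share a dtype.
out_fixed = torch.div(lengths, factor.to(lengths.dtype), rounding_mode="floor")
print(out_fixed.dtype)  # torch.int32
```

If the division lives in third-party model code, applying the cast there (or converting the length tensors to int64 before they reach the division) would be the places to try first.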

Versions

Environment

  • PyTorch version: 2.5.0+cu124
  • ExecuTorch version: 0.4.0
  • Python version: 3.10.12
  • Operating System: Linux (specifically, Amazon Linux 2)
@JacobSzwejbka
Contributor

cc @larryliu0820 @Gasoonjia @manuelcandales on op dtype issues

@JacobSzwejbka added the "module: exir" and "bug" labels on Dec 2, 2024