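🐛 Describe the bug

Bug Report: Data Type Mismatch Error in `to_edge` Function

Description

Converting a PyTorch model from the ATen dialect to the Edge dialect with the `to_edge` function fails with a data type mismatch error.

Steps to Reproduce

1. Export the model to the ATen dialect with the `export` function.
2. Convert the exported program to the Edge dialect with the `to_edge` function.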
Sample Code

```python
import torch
from torch import nn
from executorch.exir import to_edge
from torch.export import export

# Create a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.encoder = nn.Linear(80, 64)  # Example simple layer

    def forward(self, x):
        return self.encoder(x)

# Initialize the simple model
model = SimpleModel()
model.eval()

# Define custom example inputs (using random tensors)
example_inputs = (
    torch.randn(1, 80),                       # Input tensor of shape (1, 80)
    torch.tensor([72], dtype=torch.float32),  # Length tensor
    torch.randn(12, 1, 25, 512),              # Example additional input tensor
    torch.randn(12, 1, 512, 15),              # Another example tensor
    torch.tensor([15], dtype=torch.float32),  # Another length tensor
)

# Export model to ATen dialect
___audio_time = torch.export.Dim('___audio_time', min=128, max=256)
# Note: dict keys in dynamic_shapes must match the parameter names of forward()
dynamic_shapes = {
    "audio_signal": {2: 4 * ___audio_time},
    "length": {0: torch.export.Dim.STATIC},
    "cache_last_channel": {},
    "cache_last_time": {},
    "cache_last_channel_len": {},
}

# Perform the model export
aten_model = export(model, example_inputs, dynamic_shapes=dynamic_shapes)

# Attempt to convert ATen dialect model to Edge dialect
edge_model = to_edge(aten_model)
```
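The `dynamic_shapes` entry `{2: 4 * ___audio_time}` makes dim 2 of `audio_signal` a derived dimension. The plain-Python sketch below only illustrates the arithmetic that spec implies for `min=128, max=256`; it is not torch.export's implementation, and the helper names `derived_dim_bounds` and `is_valid_size` are hypothetical.

```python
def derived_dim_bounds(multiplier: int, lo: int, hi: int) -> tuple[int, int]:
    """Bounds of a derived dim `multiplier * base`, with base in [lo, hi]."""
    return multiplier * lo, multiplier * hi


def is_valid_size(size: int, multiplier: int, lo: int, hi: int) -> bool:
    """True if `size` == multiplier * base for some integer base in [lo, hi]."""
    if size % multiplier != 0:
        return False
    base = size // multiplier
    return lo <= base <= hi


# Values from the report: Dim('___audio_time', min=128, max=256), derived as 4 * dim
MULT, LO, HI = 4, 128, 256

print(derived_dim_bounds(MULT, LO, HI))  # (512, 1024)
for size in (512, 514, 1024, 1028):
    # 514 is not a multiple of 4; 1028 would need base 257 > max
    print(size, is_valid_size(size, MULT, LO, HI))
```

Under this spec, an example `audio_signal` needs at least three dimensions, with dim 2 a multiple of 4 between 512 and 1024 inclusive.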
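Error Message

```
torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes: defaultdict(<class 'dict'>, {<EdgeOpOverload: aten.div.Tensor_mode>: schema = aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor: {'self': torch.int32, 'other': torch.int64, '__ret_0': torch.int32}})
```

The failing operator is `aten.div.Tensor_mode` (the two-tensor `torch.div` overload that takes a `rounding_mode`): its `self` argument is `int32` while `other` is `int64`.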
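The message says the Edge verifier expects the tensor inputs of `aten.div.Tensor_mode` to agree on dtype. The toy model below mimics that rule on a node transcribed from the error and shows one common remedy: casting the mismatched input to the dtype of `self` before the op (in PyTorch, a `.to(self_tensor.dtype)` on `other` ahead of the `torch.div` call). The `Node` record, the string dtypes, and the helpers `find_mismatched` and `cast_inputs_to` are hypothetical simplifications; the real check lives in ExecuTorch's Edge dialect verifier and is more involved.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a graph node: op name plus the dtype of each
    tensor argument and of the output ('__ret_0')."""
    op: str
    tensor_dtypes: dict = field(default_factory=dict)


def find_mismatched(nodes):
    """Return {op: {arg: dtype}} for nodes whose tensor inputs disagree on dtype."""
    bad = defaultdict(dict)
    for node in nodes:
        # Only inputs are compared; the output dtype is reported but not checked.
        inputs = {k: v for k, v in node.tensor_dtypes.items() if k != "__ret_0"}
        if len(set(inputs.values())) > 1:
            bad[node.op] = dict(node.tensor_dtypes)
    return bad


def cast_inputs_to(node, ref_arg="self"):
    """Copy of `node` with every tensor input cast to the dtype of `ref_arg`,
    i.e. the effect of inserting a `.to(...)` before the op."""
    target = node.tensor_dtypes[ref_arg]
    fixed = {k: (v if k == "__ret_0" else target)
             for k, v in node.tensor_dtypes.items()}
    return Node(node.op, fixed)


# The failing node, transcribed from the reported error.
div = Node("aten.div.Tensor_mode",
           {"self": "int32", "other": "int64", "__ret_0": "int32"})

print(dict(find_mismatched([div])))
fixed = cast_inputs_to(div)
print(dict(find_mismatched([fixed])))  # {}
```

Matching `other` to `self` keeps the output at the `int32` shown in the error; promoting `self` to `int64` instead would also satisfy the rule but change the result dtype.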
cc @larryliu0820 @Gasoonjia @manuelcandales on op dtype issues