From 2b5173d7936f5f9eed794edad8df4924c17c6ab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 7 Aug 2024 17:05:58 +0200 Subject: [PATCH] Fix missing type in _add_attribute_to_torchscript_node for Deberta models (#1773) Signed-off-by: Xavier Dupre --- .../graph_building/_graph_building_torch.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 4fac129ef..bef78a799 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -459,8 +459,18 @@ def _add_attribute_to_torchscript_node( return node.fs_(key, list(value)) # type: ignore[arg-type] if isinstance(value[0], int): return node.is_(key, list(value)) # type: ignore[attr-defined] - raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'") - raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'") + raise TypeError( + f"Unsupported sequence type '{type(value)}' for attribute '{key}' in " + f"node={node!r}, value is {value!r}" + ) + if "TensorProtoDataType" in str(type(value)): + # torch._C._onnx.TensorProtoDataType + return node.i_(key, int(value)) + + raise TypeError( + f"Unsupported attribute type '{type(value)}' for attribute '{key}' " + f"in node={node!r}, value is {value!r}" + ) @runtime_typing.checked