From 76d183edd7ab34ff7cd563f04e050c5d7a7ca123 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 4 Aug 2024 10:43:46 +0200 Subject: [PATCH 1/3] Fix missing type in _add_attribute_to_torchscript_node for Deberta models Signed-off-by: Xavier Dupre --- .../graph_building/_graph_building_torch.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 4fac129ef..81b8e035c 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -459,8 +459,19 @@ def _add_attribute_to_torchscript_node( return node.fs_(key, list(value)) # type: ignore[arg-type] if isinstance(value[0], int): return node.is_(key, list(value)) # type: ignore[attr-defined] - raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'") - raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'") + raise TypeError( + f"Unsupported sequence type '{type(value)}' for attribute '{key}' in " + f"node={node!r}, value is {value!r}" + ) + if "TensorProtoDataType" in str(type(value)): + # torch._C._onnx.TensorProtoDataType + itype = int(value) + return node.i_(key, value) + + raise TypeError( + f"Unsupported attribute type '{type(value)}' for attribute '{key}' " + f"in node={node!r}, value is {value!r}" + ) @runtime_typing.checked From 1badd80dbcbe25e897a3a91c9b9ed4d9a7c92674 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 5 Aug 2024 11:58:29 +0000 Subject: [PATCH 2/3] fix lint --- .../torch_lib/graph_building/_graph_building_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 81b8e035c..fc45a7394 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -466,7 +466,7 @@ def _add_attribute_to_torchscript_node( if "TensorProtoDataType" in str(type(value)): # torch._C._onnx.TensorProtoDataType itype = int(value) - return node.i_(key, value) + return node.i_(key, itype) raise TypeError( f"Unsupported attribute type '{type(value)}' for attribute '{key}' " From 6248b7491d7ce43b9ceb783c997618b4751d0f19 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 5 Aug 2024 19:01:51 +0200 Subject: [PATCH 3/3] address comments Signed-off-by: Xavier Dupre --- .../torch_lib/graph_building/_graph_building_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index fc45a7394..bef78a799 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -465,8 +465,7 @@ def _add_attribute_to_torchscript_node( ) if "TensorProtoDataType" in str(type(value)): # torch._C._onnx.TensorProtoDataType - itype = int(value) - return node.i_(key, itype) + return node.i_(key, int(value)) raise TypeError( f"Unsupported attribute type '{type(value)}' for attribute '{key}' "