Skip to content

Commit

Permalink
[onnx] Fix grad op domain
Browse files Browse the repository at this point in the history
  • Loading branch information
twata committed May 9, 2023
1 parent bfbc9cb commit 8995e70
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/onnx/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _grad( # type: ignore
@staticmethod
def symbolic(g, output, grad_output, *inputs): # type: ignore
return g.op(
"ai.onnx.preview::Gradient",
"ai.onnx.preview.training::Gradient",
*inputs,
xs_s=input_names,
zs_s=[],
Expand Down
7 changes: 4 additions & 3 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def forward(self, x):
assert y.shape == (1, 1, 32, 20)


@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad():
if not pytorch_pfn_extras.requires('1.8.0'):
Expand Down Expand Up @@ -100,6 +100,7 @@ def forward(self, x):
)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
print(actual_onnx)
named_nodes = {n.name: n for n in actual_onnx.graph.node}
if pytorch_pfn_extras.requires("1.13"):
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
Expand All @@ -122,7 +123,7 @@ def forward(self, x):
assert named_nodes["Conv_2"].output[0] == y_in


@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_multiple_times():
if not pytorch_pfn_extras.requires("1.8.0"):
Expand Down Expand Up @@ -201,7 +202,7 @@ def forward(self, x):
assert named_nodes["Conv_7"].output[0] == y1_in


@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_with_multiple_inputs():
if not pytorch_pfn_extras.requires("1.8.0"):
Expand Down

0 comments on commit 8995e70

Please sign in to comment.