# test_symbolic_helper.py -- copy taken from a fork of pytorch/pytorch.
# Owner(s): ["module: onnx"]
"""Unit tests on `torch.onnx.symbolic_helper`."""
import re

import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils
class TestHelperFunctions(common_utils.TestCase):
    """Tests for `symbolic_helper.check_training_mode`.

    Every test overwrites `GLOBALS.training_mode`, which lives on a
    module-level object imported from `torch.onnx._globals`. The value is
    therefore captured in `setUp` and put back in `tearDown` so one test
    cannot leak its export mode into the next.
    """

    def setUp(self) -> None:
        """Remember the export training mode that was set before the test."""
        super().setUp()
        self._initial_training_mode = GLOBALS.training_mode

    def tearDown(self) -> None:
        """Put back the saved export training mode, then run base cleanup."""
        GLOBALS.training_mode = self._initial_training_mode
        # setUp chains to the base class, so tearDown must as well; otherwise
        # any cleanup the base TestCase performs would be skipped.
        super().tearDown()

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
            ),
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.EVAL],
                name="modes_match_op_train_mode_0_export_mode_eval",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.TRAINING],
                name="modes_match_op_train_mode_1_export_mode_training",
            ),
        ],
    )
    def test_check_training_mode_does_not_warn_when(
        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
    ) -> None:
        """Assert that `check_training_mode` stays silent for this pairing.

        Args:
            op_train_mode: Operator-level training flag passed to the helper
                (the parametrization uses 0 and 1).
            export_mode: Value assigned to `GLOBALS.training_mode` before the
                helper is called.
        """
        GLOBALS.training_mode = export_mode
        self.assertNotWarn(
            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
        )

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.TRAINING],
                name="modes_do_not_match_op_train_mode_0_export_mode_training",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.EVAL],
                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
            ),
        ],
    )
    def test_check_training_mode_warns_when(
        self,
        op_train_mode: int,
        export_mode: torch.onnx.TrainingMode,
    ) -> None:
        """Assert that `check_training_mode` emits a `UserWarning` here.

        Args:
            op_train_mode: Operator-level training flag passed to the helper
                (the parametrization uses 0 and 1).
            export_mode: Value assigned to `GLOBALS.training_mode` before the
                helper is called; it is also expected to appear in the
                warning text.
        """
        # Configure the global outside the capture context so that only the
        # call under test can satisfy the warning assertion.
        GLOBALS.training_mode = export_mode
        # The formatted mode contains regex metacharacters (e.g. the "." in
        # "TrainingMode.TRAINING"), so escape it to match the text literally.
        expected_message = re.escape(f"ONNX export mode is set to {export_mode}")
        with self.assertWarnsRegex(UserWarning, expected_message):
            symbolic_helper.check_training_mode(op_train_mode, "testop")
# The methods decorated with `common_utils.parametrize` take extra arguments,
# so they must be handed to `instantiate_parametrized_tests` before a test
# runner can call them.
common_utils.instantiate_parametrized_tests(TestHelperFunctions)


if __name__ == "__main__":
    common_utils.run_tests()