forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pytorch_onnx_onnxruntime_cuda.py
153 lines (129 loc) · 4.78 KB
/
test_pytorch_onnx_onnxruntime_cuda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Owner(s): ["module: onnx"]
import unittest
import onnx_test_common
import onnxruntime # noqa: F401
import parameterized
from onnx_test_common import MAX_ONNX_OPSET_VERSION, MIN_ONNX_OPSET_VERSION
from pytorch_test_common import (
skipIfNoBFloat16Cuda,
skipIfNoCuda,
skipIfUnsupportedMinOpsetVersion,
skipScriptTest,
)
from test_pytorch_onnx_onnxruntime import _parameterized_class_attrs_and_values
import torch
from torch.cuda.amp import autocast
from torch.testing._internal import common_utils
@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values(
        MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION
    ),
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
    """CUDA-only ONNX export tests checked with ONNX Runtime.

    ``parameterized_class`` expands this template into concrete test classes,
    one per attribute set produced by ``_parameterized_class_attrs_and_values``
    for the opset range ``MIN_ONNX_OPSET_VERSION``..``MAX_ONNX_OPSET_VERSION``;
    generated class names come from ``onnx_test_common.parameterize_class_name``.

    Each test builds a small ``torch.nn.Module`` plus CUDA inputs and hands
    them to ``run_test`` (inherited from ``_TestONNXRuntime``). Presumably
    ``run_test`` exports the model and compares PyTorch against ONNX Runtime
    outputs within ``rtol``/``atol`` -- confirm in ``onnx_test_common``.
    """

    # GELU applied to a half-precision CUDA tensor; skipped below opset 9.
    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    def test_gelu_fp16(self):
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x)

        cuda = torch.device("cuda")
        x = torch.randn(
            (2, 4, 5, 6), dtype=torch.float16, device=cuda, requires_grad=True
        )
        model = GeluModel()
        self.run_test(model, x, rtol=1e-3, atol=1e-5)

    # LayerNorm over the trailing [10, 10] dims, with forward wrapped in CUDA
    # autocast and fed an fp16 input. Script mode is skipped for this test.
    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    @skipScriptTest()
    def test_layer_norm_fp16(self):
        class LayerNormModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm([10, 10])

            @autocast()
            def forward(self, x):
                return self.layer_norm(x)

        cuda = torch.device("cuda")
        x = torch.randn(
            (20, 5, 10, 10), dtype=torch.float16, device=cuda, requires_grad=True
        )
        model = LayerNormModel().cuda()
        self.run_test(model, x, rtol=1e-3, atol=1e-5)

    # LogSoftmax followed by NLLLoss(reduction="none") under CUDA autocast.
    # Some targets are set to NLLLoss's default ignore_index (-100) so that
    # value is exercised too. Requires opset >= 12; script mode is skipped.
    @skipIfUnsupportedMinOpsetVersion(12)
    @skipIfNoCuda
    @skipScriptTest()
    def test_softmaxCrossEntropy_fusion_fp16(self):
        class FusionModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="none")
                self.m = torch.nn.LogSoftmax(dim=1)

            @autocast()
            def forward(self, input, target):
                return self.loss(self.m(2 * input), target)

        cuda = torch.device("cuda")
        batch_size, label_upper_bound = 5, 4
        scores = torch.randn(batch_size, 16, dtype=torch.float16, device=cuda)
        # Labels are drawn uniformly from [0, label_upper_bound).
        target = torch.empty(batch_size, dtype=torch.long, device=cuda)
        target.random_(0, label_upper_bound)
        # Every label equal to 1 becomes the default ignore_index.
        target[target == 1] = -100
        self.run_test(FusionModel(), (scores, target))

    # Linear model prepared by apex's amp.initialize at opt_level "O2".
    # The test is skipped (not failed) if importing apex raises anything.
    @skipIfNoCuda
    @skipScriptTest()
    def test_apex_o2(self):
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        try:
            from apex import amp as apex_amp
        except Exception as exc:
            raise unittest.SkipTest("Apex is not available") from exc

        x = torch.randn(3, 3, device=torch.device("cuda"))
        model = apex_amp.initialize(LinearModel(), opt_level="O2")
        self.run_test(model, x)

    # bfloat16 Add/Sub/Mul computed on CUDA. ONNX supports bfloat16 only from
    # opset 13, and (per the original note) ONNX Runtime's Add, Sub and Mul
    # have no bfloat16 CPU support -- hence the CUDA-only bfloat16 gate.
    @skipIfUnsupportedMinOpsetVersion(13)
    @skipIfNoBFloat16Cuda
    def test_arithmetic_bfp16(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(
                    3, 4, dtype=torch.bfloat16, device=torch.device("cuda")
                )
                x = x.type_as(y)
                total = torch.add(x, y)
                difference = torch.sub(x, y)
                return torch.mul(total, difference).to(dtype=torch.float16)

        cuda = torch.device("cuda")
        x = torch.ones((3, 4), dtype=torch.float16, device=cuda, requires_grad=True)
        self.run_test(MyModule(), x, rtol=1e-3, atol=1e-5)

    # Model whose parameters sit on different devices: a CPU weight and a
    # CUDA bias. Judging by the test name, this guards initializer
    # deduplication against merging across devices -- TODO confirm.
    @skipIfNoCuda
    def test_deduplicate_initializers_diff_devices(self):
        cpu, cuda = torch.device("cpu"), torch.device("cuda")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.ones(2, 3, device=cpu))
                self.b = torch.nn.Parameter(torch.ones(3, device=cuda))

            def forward(self, x, y):
                return torch.matmul(self.w, x), y + self.b

        x = torch.randn(3, 3, device=cpu)
        y = torch.randn(3, 3, device=cuda)
        self.run_test(Model(), (x, y))
if __name__ == "__main__":
    # Entry point when the file is executed directly: defer to the shared
    # runner in torch.testing._internal.common_utils.
    common_utils.run_tests()