forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_op_consistency.py
425 lines (337 loc) · 12.3 KB
/
test_op_consistency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Owner(s): ["module: onnx"]
"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.
Usage:
    pytest test/onnx/test_op_consistency.py
To run tests on a specific operator (e.g. torch.ceil):
    pytest test/onnx/test_op_consistency.py -k ceil
Read more on Running and writing tests:
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
Note:
When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
ALLOWLIST_OP lists. See "Modify this section"
"""
import copy
import dataclasses
import unittest
from typing import (
AbstractSet,
Callable,
Collection,
Iterable,
Optional,
Sequence,
Tuple,
Union,
)
import onnx_test_common
import torch
from torch.onnx import _constants
from torch.testing._internal import (
common_device_type,
common_methods_invocations,
common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core
# Opset versions exercised by the consistency tests; both bounds are inclusive.
MIN_ONNX_OPSET_VERSION = 9
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)

# dtype groups, used to express skip/xfail rules compactly.
BOOL_TYPES = (torch.bool,)
INT_TYPES = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
QINT_TYPES = (torch.qint8, torch.quint8)
FLOAT_TYPES = (torch.float16, torch.float32, torch.float64)
COMPLEX_TYPES = (torch.complex32, torch.complex64, torch.complex128)

# dtypes the tests are instantiated for: boolean, integer and floating point.
# Quantized and complex dtypes are deliberately left out.
SUPPORTED_DTYPES = (*BOOL_TYPES, *INT_TYPES, *FLOAT_TYPES)
@dataclasses.dataclass
class DecorateMeta:
    """One skip/xfail rule for the OpInfo-based consistency tests.

    Adapted from functorch: functorch/test/common_utils.py

    Attributes:
        op_name: Name of the OpInfo operator the rule targets.
        variant_name: OpInfo variant name ("" for the default variant).
        decorator: Decorator applied to the matching test cases.
        opsets: Opsets the rule applies to. Each entry is either an exact
            opset number or a predicate taking an opset; ``None`` means the
            rule applies to every opset.
        dtypes: dtypes the rule applies to; add_decorate_info forwards this
            as ``DecorateInfo(dtypes=...)``.
        reason: Human-readable justification for the rule.
    """

    op_name: str
    variant_name: str
    decorator: Callable
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
    dtypes: Optional[Collection[torch.dtype]]
    reason: str

    def contains_opset(self, opset: int) -> bool:
        """Returns True when this rule applies to ``opset``."""
        if self.opsets is None:
            return True
        for spec in self.opsets:
            # An int matches exactly; anything else is treated as a predicate.
            matched = opset == spec if isinstance(spec, int) else spec(opset)
            if matched:
                return True
        return False
def xfail(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
):
    """Declares that an OpInfo test is expected to fail.

    The matching tests are decorated with ``unittest.expectedFailure``.

    Args:
        op_name: Name of the OpInfo operator.
        variant_name: OpInfo variant name; "" selects the default variant.
        reason: Why the test is expected to fail.
        opsets: Opsets the failure is expected on, given as exact numbers
            and/or predicates, e.g. [9, 10] or [opsets_before(11)]. ``None``
            means every opset.
        dtypes: dtypes the failure is expected on (forwarded to DecorateInfo).

    Returns:
        The DecorateMeta describing the expected failure.
    """
    return DecorateMeta(
        reason=reason,
        op_name=op_name,
        variant_name=variant_name,
        opsets=opsets,
        dtypes=dtypes,
        decorator=unittest.expectedFailure,
    )
def dont_care(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
):
    """Skips an OpInfo test case whose outcome we don't care about.

    Typically used because ONNX does not support the use case, or the
    behavior differs by design.

    Args:
        op_name: Name of the OpInfo operator.
        variant_name: OpInfo variant name; "" selects the default variant.
        reason: Why the case is irrelevant; also embedded in the skip message.
        opsets: Opsets to skip on, given as exact numbers and/or predicates,
            e.g. [9, 10] or [opsets_before(11)]. ``None`` means every opset.
        dtypes: dtypes to skip (forwarded to DecorateInfo).

    Returns:
        The DecorateMeta describing the skip.
    """
    skip_decorator = unittest.skip(f"Don't care: {reason}")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=skip_decorator,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
    )
def fixme(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
):
    """Skips an OpInfo test case that should eventually be fixed.

    Prefer this over ``xfail`` when the failure is not consistent (e.g. the
    test is flaky), since an expected failure that passes is itself reported.

    Args:
        op_name: Name of the OpInfo operator.
        variant_name: OpInfo variant name; "" selects the default variant.
        reason: What needs fixing; also embedded in the skip message.
        opsets: Opsets to skip on, given as exact numbers and/or predicates,
            e.g. [9, 10] or [opsets_before(11)]. ``None`` means every opset.
        dtypes: dtypes to skip (forwarded to DecorateInfo).

    Returns:
        The DecorateMeta describing the skip.
    """
    skip_decorator = unittest.skip(f"To fix: {reason}")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=skip_decorator,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
    )
def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    opset: int,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Attaches skip/xfail decorators to OpInfos for one generated test method.

    For every rule in ``skip_or_xfails`` that applies to ``opset``, the OpInfo
    with the rule's (op_name, variant_name) gets an extra ``DecorateInfo``
    scoped to ``test_class_name`` / ``base_test_name`` and the rule's dtypes.
    The OpInfo objects are updated in place (their ``decorators`` tuple is
    replaced).

    Args:
        all_opinfos: All OpInfos that rules may refer to.
        test_class_name: Name of the generic test class.
        base_test_name: Name of the test method the decorators target.
        opset: The opset being decorated for; rules that do not apply to it
            are ignored.
        skip_or_xfails: The DecorateMeta rules.

    Returns:
        A no-op decorator that returns its argument unchanged.

    Raises:
        AssertionError: If an applicable rule names an op/variant pair that
            is not in ``all_opinfos``.
    """
    lookup = {(info.name, info.variant_test_name): info for info in all_opinfos}

    applicable = (meta for meta in skip_or_xfails if meta.contains_opset(opset))
    for meta in applicable:
        opinfo = lookup.get((meta.op_name, meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {meta}. Did you need to specify variant_name?"
        extra = opinfo_core.DecorateInfo(
            meta.decorator,
            test_class_name,
            base_test_name,
            dtypes=meta.dtypes,
        )
        opinfo.decorators = (*opinfo.decorators, extra)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped
def opsets_before(opset: int) -> Callable[[int], bool]:
    """Returns a predicate that is True for opsets strictly lower than ``opset``.

    Meant for the ``opsets`` argument of xfail/dont_care/fixme, e.g.
    ``opsets=[opsets_before(11)]`` matches opset 10 but not 11.
    """
    return lambda other_opset: other_opset < opset
def opsets_after(opset: int) -> Callable[[int], bool]:
    """Returns a predicate that is True for opsets strictly greater than ``opset``.

    Meant for the ``opsets`` argument of xfail/dont_care/fixme, e.g.
    ``opsets=[opsets_after(13)]`` matches opset 14 but not 13.
    """
    return lambda other_opset: other_opset > opset
def reason_onnx_runtime_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Builds a skip reason stating ONNX Runtime lacks support for ``operator``.

    Args:
        operator: ONNX operator name, e.g. "Ceil".
        dtypes: Short dtype names to mention; when empty or None the
            generic word "dtypes" is used instead.
    """
    subject = dtypes if dtypes else "dtypes"
    return f"{operator} on {subject} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Builds a skip reason stating the ONNX spec lacks support for ``operator``.

    Args:
        operator: ONNX operator name, e.g. "Sqrt".
        dtypes: Short dtype names to mention; when empty or None the phrase
            "certain dtypes" is used instead.
    """
    subject = dtypes if dtypes else "certain dtypes"
    return f"{operator} on {subject} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
    """Builds a skip reason for a JIT tracer error described by ``info``."""
    prefix = "JIT tracer error on"
    return f"{prefix} {info}"
def reason_flaky() -> str:
    """Returns the canned skip reason used for flaky tests."""
    message = "flaky test"
    return message
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to ALLOWLIST_OP then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.

# OpInfo names (all variants) tested for numerical consistency between ONNX
# and PyTorch. Keep the entries sorted alphabetically.
ALLOWLIST_OP: AbstractSet[str] = frozenset(
    {
        "ceil",
        "sqrt",
        "t",
    }
)
# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures and skips for ONNX export, sorted alphabetically by op name.
#
# Q: When should I use fixme vs dont_care vs xfail?
# A: Use fixme when we want to fix the test eventually but it doesn't fail
#    consistently, e.g. the test is flaky or only some samples fail.
#    Use dont_care when the test passing doesn't matter, e.g. ONNX doesn't
#    support the usage.
#    Use xfail when a test fails consistently now and we want to fix it eventually.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
    dont_care(
        "ceil",
        dtypes=(*BOOL_TYPES, *INT_TYPES),
        reason=reason_onnx_does_not_support("Ceil"),
    ),
    fixme(
        "ceil",
        dtypes=[torch.float64],
        reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"]),
    ),
    dont_care(
        "sqrt",
        dtypes=BOOL_TYPES,
        reason=reason_onnx_does_not_support("Sqrt"),
    ),
)
# fmt: on
# END OF SECTION TO MODIFY #####################################################
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
class SingleOpModel(torch.nn.Module):
    """Wraps one callable op in an nn.Module so it can be exported.

    Keyword arguments are fixed when the module is built; positional inputs
    are supplied on each call.
    """

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        op_fn, fixed_kwargs = self.operator, self.kwargs
        return op_fn(*args, **fixed_kwargs)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite. ``parameterize_opsets`` adds one
    ``test_output_match_opset_{opset}`` method per opset, and
    ``common_device_type.ops`` then expands each of those per allowlisted
    OpInfo and dtype.
    """

    @classmethod
    def create_test_base(cls, opset: int):
        """Returns the base test method for the given opset.

        Args:
            opset: The ONNX opset version the generated test exports with.

        Returns:
            A test function named ``test_output_match_opset_{opset}``. The name
            must match the ``base_test_name`` given to ``add_decorate_info`` so
            the opset-specific skips/xfails attach to it.
        """

        def _output_match_base(self, device: str, dtype: torch.dtype, op):
            """Base test method for testing each opset, used by instantiate_device_type_tests."""
            # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
            assert device == "cpu"

            # Previously `opset` was used only as a subTest label and never reached
            # run_test, so every test_output_match_opset_{N} exported with the same
            # (class default) opset. Pin this test instance to its own opset.
            # NOTE(review): assumes _TestONNXRuntime.run_test exports with
            # self.opset_version — confirm in onnx_test_common.
            self.opset_version = opset

            samples = op.sample_inputs(
                device,
                dtype,
                requires_grad=False,
            )

            for i, cpu_sample in enumerate(samples):
                # Provide the repr to subtest because tensors are not serializable in parallel test runs
                with self.subTest(
                    opset=opset,
                    sample_num=i,
                    input=repr(cpu_sample.input),
                    args=repr(cpu_sample.args),
                    kwargs=repr(cpu_sample.kwargs),
                ):
                    # Keyword arguments are bound into the module; only the
                    # positional inputs are handed to run_test.
                    model = SingleOpModel(op, cpu_sample.kwargs)
                    model.eval()

                    # Run the test
                    inputs = (cpu_sample.input, *cpu_sample.args)
                    self.run_test(model, inputs)

        test_name = f"test_output_match_opset_{opset}"
        _output_match_base.__name__ = test_name
        return _output_match_base

    @classmethod
    def parameterize_opsets(cls, opsets: Sequence[int]):
        """Parametrizes the TestOnnxModelOutputConsistency class with the given opsets.

        For each opset, this registers that opset's skip/xfail rules on the
        OPS_DB OpInfos. It then attaches a ``common_device_type.ops``-decorated
        test method to ``cls``.

        Args:
            opsets: The opset versions to generate test methods for.
        """
        # The allowlist filter does not depend on the opset, so build it once.
        # add_decorate_info updates these same OpInfo objects in place, so the
        # decorators registered inside the loop remain visible through this list.
        filtered_ops = [op for op in OPS_DB if op.name in ALLOWLIST_OP]

        for opset in opsets:
            # Generate a test method for each opset
            base_method = cls.create_test_base(opset)
            # Important to rename the test method so that DecorateInfo can find it
            test_name = base_method.__name__

            # Update the ops to skip in the OpInfo database
            add_decorate_info(
                OPS_DB,
                cls.__name__,
                test_name,
                opset=opset,
                skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
            )

            # Create parameterized tests for each op
            decorated = common_device_type.ops(
                filtered_ops,
                allowed_dtypes=SUPPORTED_DTYPES,
            )(base_method)

            setattr(cls, test_name, decorated)
# Attach one test_output_match_opset_{N} method per tested opset, then let the
# device-type machinery expand them into concrete CPU-only test cases.
TestOnnxModelOutputConsistency.parameterize_opsets(TESTED_OPSETS)
common_device_type.instantiate_device_type_tests(
    TestOnnxModelOutputConsistency,
    globals(),
    only_for="cpu",
)

if __name__ == "__main__":
    common_utils.run_tests()