Skip to content

Commit

Permalink
Swin transformer conformance test
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 8, 2024
1 parent faeb768 commit cf4aece
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 7 deletions.
6 changes: 6 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ torchvision/resnet18_backend_CUDA_TORCH:
metric_value: 0.69152
torchvision/resnet18_backend_FX_TORCH:
metric_value: 0.6946
torchvision/swin_v2_s_backend_FP32:
metric_value: 0.83712
torchvision/swin_v2_s_backend_OV:
metric_value: 0.83638
torchvision/swin_v2_s_backend_FX_TORCH:
metric_value: 0.82908
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
Expand Down
11 changes: 11 additions & 0 deletions tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@
"backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
"batch_size": 128,
},
{
"reported_name": "torchvision/swin_v2_s",
"model_id": "swin_v2_s",
"pipeline_cls": ImageClassificationTorchvision,
"compression_params": {
"model_type": ModelType.TRANSFORMER,
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
},
"backends": [BackendType.FX_TORCH, BackendType.OV],
"batch_size": 1,
},
# Timm models
{
"reported_name": "timm/crossvit_9_240",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ def prepare_model(self) -> None:
model = model_cls(weights=self.model_weights)
model.eval()

self.static_input_size = [self.batch_size, 3, 224, 224]
default_input_size = [self.batch_size, 3, 224, 224]
self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size))
self.static_input_size = list(self.dummy_tensor.shape)

self.input_size = self.static_input_size.copy()
if self.batch_size > 1: # Dynamic batch_size shape export
self.input_size[0] = -1

self.dummy_tensor = torch.rand(self.static_input_size)

if self.backend == BackendType.FX_TORCH:
with torch.no_grad():
with disable_patching():
Expand Down
9 changes: 5 additions & 4 deletions tests/post_training/test_quantize_conformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def test_ptq_quantization(
err_msg = None
test_model_param = None
start_time = time.perf_counter()
try:
if True:
# try:
if test_case_name not in ptq_reference_data:
raise nncf.ValidationError(f"{test_case_name} does not exist in 'reference_data.yaml'")
test_model_param = PTQ_TEST_CASES[test_case_name]
Expand All @@ -271,9 +272,9 @@ def test_ptq_quantization(
)
pipeline: BaseTestPipeline = pipeline_cls(**pipeline_kwargs)
pipeline.run()
except Exception as e:
err_msg = str(e)
traceback.print_exc()
# except Exception as e:
# err_msg = str(e)
# traceback.print_exc()

if pipeline is not None:
pipeline.cleanup_cache()
Expand Down

0 comments on commit cf4aece

Please sign in to comment.