Skip to content

Commit

Permalink
Fix perceiver tests and dummy inputs for ONNX (#1449)
Browse files Browse the repository at this point in the history
Co-authored-by: bas <[email protected]>
  • Loading branch information
baskrahmer and bk-jc authored Oct 16, 2023
1 parent 8f33e0e commit 38b0809
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
53 changes: 39 additions & 14 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,12 +1034,37 @@ class Data2VecAudioOnnxConfig(AudioOnnxConfig):


class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
num_channels=num_channels,
width=width,
height=height,
**kwargs,
)

from transformers.onnx.utils import get_preprocessor

preprocessor = get_preprocessor(normalized_config._name_or_path)
if preprocessor is not None and hasattr(preprocessor, "size"):
self.height = preprocessor.size.get("height", self.height)
self.width = preprocessor.size.get("width", self.width)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
input_ = super().generate(
input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
)
# if input_name == "pixel_values":
# input_ = input_[None, :]
return input_


Expand Down Expand Up @@ -1074,22 +1099,22 @@ def inputs_name(self):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
# TODO: validate that.
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
return {
self.inputs_name: dynamic_axis,
# TODO: should we add the attention_mask?
# This breaks things for image-classification, suspected bug is the DummyInputGenerators not having the
# same num_channels / sequence_length.
# "attention_mask": dynamic_axis,
}
if self.inputs_name in ["input_ids", "inputs"]:
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
return {
"input_ids": dynamic_axis,
"attention_mask": dynamic_axis,
}
else:
dynamic_axis = {0: "batch_size", 1: "sequence_length", 2: "width", 3: "height"}
return {
"pixel_values": dynamic_axis,
}

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
self.is_generating_dummy_inputs = True
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
specialized_inputs_name = self.inputs_name
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
dummy_inputs[self.inputs_name] = dummy_inputs.pop(self.inputs_name)
return dummy_inputs


Expand Down
21 changes: 6 additions & 15 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,18 +1236,15 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin):
"flaubert",
"ibert",
"mobilebert",
# "perceiver",
"perceiver_text",
"roberta",
"roformer",
"squeezebert",
"xlm",
"xlm_roberta",
]

ARCH_MODEL_MAP = {
# TODO: fix non passing test
# "perceiver": "hf-internal-testing/tiny-random-language_perceiver",
}
ARCH_MODEL_MAP = {} # TODO remove

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
ORTMODEL_CLASS = ORTModelForMaskedLM
Expand Down Expand Up @@ -1398,18 +1395,15 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
"mbart",
"mobilebert",
"nystromformer",
# "perceiver",
"perceiver_text",
"roberta",
"roformer",
"squeezebert",
"xlm",
"xlm_roberta",
]

ARCH_MODEL_MAP = {
# TODO: fix non passing test
# "perceiver": "hf-internal-testing/tiny-random-language_perceiver",
}
ARCH_MODEL_MAP = {} # TODO remove

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
ORTMODEL_CLASS = ORTModelForSequenceClassification
Expand Down Expand Up @@ -2375,18 +2369,15 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
"mobilenet_v1",
"mobilenet_v2",
"mobilevit",
# "perceiver",
"perceiver_vision",
"poolformer",
"resnet",
"segformer",
"swin",
"vit",
]

ARCH_MODEL_MAP = {
# TODO: fix non passing test
# "perceiver": "hf-internal-testing/tiny-random-vision_perceiver_conv",
}
ARCH_MODEL_MAP = {} # TODO remove

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
ORTMODEL_CLASS = ORTModelForImageClassification
Expand Down
2 changes: 2 additions & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"perceiver_text": "hf-internal-testing/tiny-random-language_perceiver",
"perceiver_vision": "hf-internal-testing/tiny-random-vision_perceiver_conv",
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
Expand Down

0 comments on commit 38b0809

Please sign in to comment.