From 05d20df3e6602e26d01cf3994a108de5b097a719 Mon Sep 17 00:00:00 2001 From: Athul Raj K <79792953+krathul@users.noreply.github.com> Date: Wed, 23 Aug 2023 21:34:14 +0530 Subject: [PATCH] Pix2Struct onnxruntime support (#1296) * add onnxruntime support for pix2struct * Fix for Pix2Struct ORT * Fix from_pretrained for pix2struct * test * test * huggingface(#1288) * Update optimum/onnxruntime/modeling_seq2seq.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * Update optimum/onnxruntime/modeling_seq2seq.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * Update optimum/onnxruntime/modeling_seq2seq.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * update modeling_seq2seq.py * update modeling_seq2seq.py * working ort inference pix2struct * add documentation * fix doc --------- Co-authored-by: ARK Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> --- .../package_reference/modeling_ort.mdx | 38 ++- optimum/exporters/onnx/__main__.py | 13 +- optimum/exporters/onnx/base.py | 67 ++-- optimum/exporters/onnx/config.py | 8 +- optimum/exporters/onnx/model_configs.py | 119 ++++++- optimum/gptq/utils.py | 4 +- optimum/onnxruntime/__init__.py | 14 +- optimum/onnxruntime/base.py | 36 +- optimum/onnxruntime/modeling_seq2seq.py | 250 +++++++++++++- optimum/utils/input_generators.py | 10 +- optimum/utils/normalized_config.py | 7 +- tests/onnxruntime/test_modeling.py | 323 ++++++++++++++++++ tests/onnxruntime/utils_onnxruntime_tests.py | 19 +- 13 files changed, 834 insertions(+), 74 deletions(-) diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx index ebbfa1736e..43703a0ad9 100644 --- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx +++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx @@ -26,30 +26,44 @@ The following ORT classes are available for the following natural language proce ### ORTModelForCausalLM +This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama. + [[autodoc]] onnxruntime.ORTModelForCausalLM ### ORTModelForMaskedLM +This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta. + [[autodoc]] onnxruntime.ORTModelForMaskedLM ### ORTModelForSeq2SeqLM +This class officially supports bart, blenderbot, blenderbot_small, longt5, m2m_100, marian, mbart, mt5, pegasus, t5. + [[autodoc]] onnxruntime.ORTModelForSeq2SeqLM ### ORTModelForSequenceClassification +This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta. + [[autodoc]] onnxruntime.ORTModelForSequenceClassification ### ORTModelForTokenClassification +This class officially supports albert, bert, bloom, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gpt2, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta. + [[autodoc]] onnxruntime.ORTModelForTokenClassification ### ORTModelForMultipleChoice +This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta. 
+ [[autodoc]] onnxruntime.ORTModelForMultipleChoice ### ORTModelForQuestionAnswering +This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gptj, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta. + [[autodoc]] onnxruntime.ORTModelForQuestionAnswering ## Computer vision @@ -58,10 +72,14 @@ The following ORT classes are available for the following computer vision tasks. ### ORTModelForImageClassification +This class officially supports beit, convnext, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. + [[autodoc]] onnxruntime.ORTModelForImageClassification ### ORTModelForSemanticSegmentation +This class officially supports segformer. + [[autodoc]] onnxruntime.ORTModelForSemanticSegmentation ## Audio @@ -70,22 +88,32 @@ The following ORT classes are available for the following audio tasks. ### ORTModelForAudioClassification +This class officially supports audio_spectrogram_transformer, data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer. + [[autodoc]] onnxruntime.ORTModelForAudioClassification ### ORTModelForAudioFrameClassification +This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer. + [[autodoc]] onnxruntime.ORTModelForAudioFrameClassification ### ORTModelForCTC +This class officially supports data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer. + [[autodoc]] onnxruntime.ORTModelForCTC ### ORTModelForSpeechSeq2Seq +This class officially supports whisper, speech_to_text. + [[autodoc]] onnxruntime.ORTModelForSpeechSeq2Seq ### ORTModelForAudioXVector +This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer. + [[autodoc]] onnxruntime.ORTModelForAudioXVector ## Multimodal @@ -94,8 +122,16 @@ The following ORT classes are available for the following multimodal tasks. ### ORTModelForVision2Seq +This class officially supports trocr and vision-encoder-decoder. + [[autodoc]] onnxruntime.ORTModelForVision2Seq +### ORTModelForPix2Struct + +This class officially supports pix2struct. + +[[autodoc]] onnxruntime.ORTModelForPix2Struct + ## Custom Tasks The following ORT classes are available for the following custom tasks. @@ -129,4 +165,4 @@ The following ORT classes are available for the following custom tasks. 
#### ORTStableDiffusionXLImg2ImgPipeline -[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline \ No newline at end of file +[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 6cefc7c571..9b58a8836e 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -24,7 +24,7 @@ from ...commands.export.onnx import parse_args_onnx from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging -from ...utils.save_utils import maybe_save_preprocessors +from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager from .base import OnnxConfigWithPast @@ -43,7 +43,7 @@ if is_torch_available(): import torch -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union if TYPE_CHECKING: @@ -62,6 +62,7 @@ def _get_submodels_and_onnx_configs( custom_onnx_configs: Dict, custom_architecture: bool, fn_get_submodels: Optional[Callable] = None, + preprocessors: Optional[List[Any]] = None, ): is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: @@ -72,7 +73,7 @@ def _get_submodels_and_onnx_configs( onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=model, exporter="onnx", task=task ) - onnx_config = onnx_config_constructor(model.config) + onnx_config = onnx_config_constructor(model.config, preprocessors=preprocessors) if ( model.config.is_encoder_decoder @@ -359,6 +360,11 @@ def main_export( possible_synonyms = "" logger.info(f"Automatic task detection to {task}{possible_synonyms}.") + # The preprocessors are loaded as they may be useful to export the model. Notably, some of the static input shapes may be stored in the + # preprocessors config. 
+ preprocessors = maybe_load_preprocessors( + model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code + ) onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( model=model, task=task, @@ -366,6 +372,7 @@ def main_export( custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, custom_architecture=custom_architecture, fn_get_submodels=fn_get_submodels, + preprocessors=preprocessors, ) if not is_stable_diffusion: diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 3c50389726..09ea0b4466 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -176,7 +176,9 @@ class OnnxConfig(ExportConfig, ABC): ), } - def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): + def __init__( + self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None + ): if task not in self._TASK_TO_COMMON_OUTPUTS: raise ValueError( f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}" @@ -184,6 +186,7 @@ def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction") self.task = task self._config = config + self._preprocessors = preprocessors self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: @@ -493,6 +496,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, + preprocessors: Optional[List[Any]] = None, ): self.use_past = use_past if use_past_in_inputs is None: @@ -515,10 +519,12 @@ def __init__( ) self.is_merged = False self.use_cache_branch = None - super().__init__(config, task=task) + super().__init__(config, task=task, preprocessors=preprocessors) @classmethod - def with_past(cls, config: "PretrainedConfig", task: str = "feature-extraction") -> "OnnxConfigWithPast": + def with_past( + cls, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None + ) -> "OnnxConfigWithPast": """ Instantiates a [`~optimum.exporters.onnx.OnnxConfig`] with `use_past` attribute set to `True`. @@ -531,7 +537,7 @@ def with_past(cls, config: "PretrainedConfig", task: str = "feature-extraction") Returns: [`~optimum.exporters.onnx.OnnxConfig`]: The onnx config with `.use_past = True` """ - return cls(config, task=task, use_past=True) + return cls(config, task=task, use_past=True, preprocessors=preprocessors) @property def outputs(self) -> Dict[str, Dict[int, str]]: @@ -564,24 +570,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): input_was_inserted = False for dummy_input_gen in dummy_inputs_generators: if dummy_input_gen.supports_input(input_name): - # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name - # while models from TextDecoderOnnxConfig use input_ids, hence the check for both - if ( - self.use_past is True - and self.use_cache_branch is not False - and input_name in ["decoder_input_ids", "input_ids"] - ): - sequence_length = dummy_input_gen.sequence_length - if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: - logger.info( - f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 1 " - f"will be used with use_past == True for `{input_name}`." 
- ) - dummy_input_gen.sequence_length = 1 - dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework) - dummy_input_gen.sequence_length = sequence_length - else: - dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework) + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, input_name, framework, input_shapes=kwargs + ) input_was_inserted = True break if not input_was_inserted: @@ -617,6 +608,35 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): return dummy_inputs + def overwrite_shape_and_generate_input( + self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict + ): + """ + The shape passed to the dummy input generator may not always be correct for all of the inputs it manages. This method allows + to overwrite some shapes, and generate the dummy input. This should probably be refactored more elegantly. + """ + + # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name + # while models from TextDecoderOnnxConfig use input_ids, hence the check for both + if ( + self.use_past is True + and self.use_cache_branch is not False + and input_name in ["decoder_input_ids", "input_ids"] + ): + sequence_length = dummy_input_gen.sequence_length + if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1: + logger.info( + f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 " + f"will be used with use_past == True for `{input_name}`." + ) + dummy_input_gen.sequence_length = 1 + dummy_input = dummy_input_gen.generate(input_name, framework=framework) + dummy_input_gen.sequence_length = sequence_length + else: + dummy_input = dummy_input_gen.generate(input_name, framework=framework) + + return dummy_input + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): """ Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. 
@@ -703,6 +723,7 @@ def __init__( use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, ): super().__init__( config, @@ -710,6 +731,7 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, use_present_in_outputs=use_present_in_outputs, + preprocessors=preprocessors, ) self._behavior = behavior self.override_attributes_for_behavior() @@ -746,6 +768,7 @@ def with_behavior( task=self.task, use_past=use_past, behavior=behavior, + preprocessors=self._preprocessors, ) @property diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index e8aef99649..28d32a55fb 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -278,6 +278,7 @@ def __init__( use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, ): super().__init__( config, @@ -286,6 +287,7 @@ def __init__( use_past_in_inputs=use_past_in_inputs, use_present_in_outputs=use_present_in_outputs, behavior=behavior, + preprocessors=preprocessors, ) from ..tasks import TasksManager @@ -296,7 +298,7 @@ def __init__( encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type ) - self._encoder_onnx_config = encoder_onnx_config_constructor(config.encoder) + self._encoder_onnx_config = encoder_onnx_config_constructor(config.encoder, preprocessors=preprocessors) self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config if self._behavior is not ConfigBehavior.ENCODER: @@ -316,7 +318,9 @@ def __init__( "past key values." ) - self._decoder_onnx_config = decoder_onnx_config_constructor(config.decoder, **kwargs) + self._decoder_onnx_config = decoder_onnx_config_constructor( + config.decoder, preprocessors=preprocessors, **kwargs + ) if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast): self._decoder_onnx_config = self._decoder_onnx_config.with_behavior( self._behavior, use_past=kwargs["use_past"] diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 4d41ea1e73..eb5824316b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -852,8 +852,10 @@ class OwlViTOnnxConfig(CLIPOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 MIN_TORCH_VERSION = version.parse("2.1") - def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): - super().__init__(config, task) + def __init__( + self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None + ): + super().__init__(config, task, preprocessors=preprocessors) if task == "zero-shot-object-detection": logger.warning( "The batch size of this model will not be dynamic because non-maximum suppression is performed. 
" @@ -957,8 +959,10 @@ class PerceiverOnnxConfig(TextAndVisionOnnxConfig): PerceiverDummyInputGenerator, ) + TextAndVisionOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES - def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): - super().__init__(config, task=task) + def __init__( + self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None + ): + super().__init__(config, task=task, preprocessors=preprocessors) self.is_generating_dummy_inputs = False @property @@ -1177,8 +1181,11 @@ def __init__( use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, ): - super().__init__(config, task, use_past, use_past_in_inputs, use_present_in_outputs, behavior) + super().__init__( + config, task, use_past, use_past_in_inputs, use_present_in_outputs, behavior, preprocessors=preprocessors + ) # TODO: Check modeling code to fix the issue with use_cache for trocr if config.decoder.model_type == "trocr": @@ -1215,8 +1222,10 @@ class SamOnnxConfig(OnnxConfig): DEFAULT_ONNX_OPSET = 12 # einsum op not supported with opset 11 MIN_TORCH_VERSION = version.parse("2.0.99") # See: https://github.com/huggingface/optimum/pull/1301 - def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): - super().__init__(config, task) + def __init__( + self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None + ): + super().__init__(config, task, preprocessors=preprocessors) self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig(self._config.vision_config) @property @@ -1261,10 +1270,10 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self): common_inputs = {} - common_inputs["attention_mask"] = {0: "batch_size", 1: "max_patches"} + common_inputs["attention_mask"] = {0: "batch_size"} if self._behavior is not ConfigBehavior.DECODER: - common_inputs["flattened_patches"] = {0: "batch_size", 1: "max_patches", 2: "patch_size"} + common_inputs["flattened_patches"] = {0: "batch_size"} if self._behavior is not ConfigBehavior.ENCODER: if self.use_past_in_inputs: @@ -1276,12 +1285,46 @@ def inputs(self): if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") - common_inputs["encoder_outputs"] = {0: "batch_size", 1: "max_patches"} + common_inputs["encoder_outputs"] = {0: "batch_size"} + # Contrary to other seq2seq archs as t5 and bart, Pix2Struct DO make use of the decoder_attention_mask input. common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} return common_inputs + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior is ConfigBehavior.ENCODER: + common_outputs = { + "last_hidden_state": {0: "batch_size"} + } # The last hidden state dim=1 is constant, no need for it to be dynamic. + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + + # Renaming the outputs axes properly. 
+ for name, axes_names in common_outputs.items(): + if self._behavior is ConfigBehavior.ENCODER or "encoder" in name: + sequence_name = "encoder_sequence_length" + else: + sequence_name = "decoder_sequence_length" + + new_axes_names = {} + for axis_idx, axis_name in axes_names.items(): + if "sequence" in axis_name: + if self.use_past_in_inputs is False or self.is_merged is True: + new_axes_names[axis_idx] = sequence_name + else: + # Trick to force it since ONNX sometimes infer a dynamic axis where it's not. + new_axes_names[axis_idx] = "1" + else: + new_axes_names[axis_idx] = axis_name + common_outputs[name] = new_axes_names + + if self.use_present_in_outputs: + self.add_past_key_values(common_outputs, direction="outputs") + + return common_outputs + @property def torch_to_onnx_input_map(self) -> Dict[str, str]: if self._behavior is ConfigBehavior.DECODER: @@ -1311,3 +1354,59 @@ def generate_dummy_inputs_for_validation( reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] return super().generate_dummy_inputs_for_validation(reference_model_inputs) + + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_inputs_generators = [] + dummy_inputs_generators.append(self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config)) + + if self._preprocessors is None or len(self._preprocessors) != 2: + raise ValueError( + f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" + ) + + encoder_sequence_length = self._preprocessors[1].image_processor.max_patches + # A hack for DummyPix2StructInputGenerator to gain access to the preprocessors. + # TODO: we should probably pass preprocessors to all dummy input generators. + kwargs["preprocessors"] = self._preprocessors + for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]: + dummy_inputs_generators.append( + cls_(self.task, self._normalized_config, encoder_sequence_length=encoder_sequence_length, **kwargs) + ) + + return dummy_inputs_generators + + def overwrite_shape_and_generate_input( + self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict + ): + if self._preprocessors is None or len(self._preprocessors) != 2: + raise ValueError( + f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" + ) + + # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name + # while models from TextDecoderOnnxConfig use input_ids, hence the check for both + if ( + self.use_past is True + and self.use_cache_branch is not False + and input_name in ["decoder_input_ids", "input_ids"] + ): + sequence_length = dummy_input_gen.sequence_length + if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1: + logger.info( + f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 " + f"will be used with use_past == True for `{input_name}`." + ) + dummy_input_gen.sequence_length = 1 + dummy_input = dummy_input_gen.generate(input_name, framework=framework) + dummy_input_gen.sequence_length = sequence_length + elif input_name in ["encoder_outputs", "attention_mask"]: + # pix2struct takes inputs whose so-called sequence length is **static** to max_patches, so we do NOT use + # the passed sequence_length that behaves as a dynamic shape. 
+ original_seq_length = dummy_input_gen.sequence_length + dummy_input_gen.sequence_length = self._preprocessors[1].image_processor.max_patches + dummy_input = dummy_input_gen.generate(input_name, framework=framework) + dummy_input_gen.sequence_length = original_seq_length + else: + dummy_input = dummy_input_gen.generate(input_name, framework=framework) + + return dummy_input diff --git a/optimum/gptq/utils.py b/optimum/gptq/utils.py index b7387561c2..a5f9afdaae 100644 --- a/optimum/gptq/utils.py +++ b/optimum/gptq/utils.py @@ -72,7 +72,7 @@ def get_block_name_with_pattern(model: nn.Module): modules_names = [n for n, _ in model.named_modules()] for pattern_candidate in BLOCK_PATTERNS: pattern_candidate = pattern_candidate - if any([pattern_candidate in name for name in modules_names]): + if any(pattern_candidate in name for name in modules_names): return pattern_candidate raise ValueError("Block pattern could not be match. Pass `block_name_to_quantize` argument in `quantize_model`") @@ -105,7 +105,7 @@ def get_device(obj: Union[torch.Tensor, nn.Module]): def get_seqlen(model: nn.Module): if hasattr(model, "config"): model_config = model.config.to_dict() - if any([k in model_config for k in SEQLEN_KEYS_TRANFORMERS]): + if any(k in model_config for k in SEQLEN_KEYS_TRANFORMERS): for key in SEQLEN_KEYS_TRANFORMERS: if key in model_config: return model_config[key] diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 62e32cfe71..c6bb12916a 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -45,7 +45,12 @@ "ORTModelForSequenceClassification", "ORTModelForTokenClassification", ], - "modeling_seq2seq": ["ORTModelForSeq2SeqLM", "ORTModelForSpeechSeq2Seq", "ORTModelForVision2Seq"], + "modeling_seq2seq": [ + "ORTModelForSeq2SeqLM", + "ORTModelForSpeechSeq2Seq", + "ORTModelForVision2Seq", + "ORTModelForPix2Struct", + ], "modeling_decoder": ["ORTModelForCausalLM"], "optimization": ["ORTOptimizer"], "quantization": ["ORTQuantizer"], @@ -104,7 +109,12 @@ ORTModelForSequenceClassification, ORTModelForTokenClassification, ) - from .modeling_seq2seq import ORTModelForSeq2SeqLM, ORTModelForSpeechSeq2Seq + from .modeling_seq2seq import ( + ORTModelForPix2Struct, + ORTModelForSeq2SeqLM, + ORTModelForSpeechSeq2Seq, + ORTModelForVision2Seq, + ) from .optimization import ORTOptimizer from .quantization import ORTQuantizer from .trainer import ORTTrainer diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index f9091650b2..b87b3ccaf0 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -549,6 +549,7 @@ def forward( self, input_ids: torch.LongTensor, encoder_hidden_states: torch.FloatTensor, + decoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, @@ -572,7 +573,7 @@ def forward( input_ids, past_key_values, use_torch=use_torch ) - if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: + if self.parent_model.use_io_binding: known_output_shapes = self.compute_past_key_values_output_shapes( input_ids, encoder_hidden_states, @@ -587,19 +588,22 @@ def forward( if "encoder_hidden_states" in self.input_names: model_inputs.append(encoder_hidden_states) + if "decoder_attention_mask" in self.input_names: + model_inputs.append(decoder_attention_mask) + if "encoder_attention_mask" in self.input_names: 
model_inputs.append(encoder_attention_mask) if past_key_values is not None: model_inputs += past_key_values - if use_cache_branch_tensor is not None: - model_inputs.append(use_cache_branch_tensor) - if "labels" in self.input_names: model_inputs.append(labels) known_output_shapes.update({"loss": []}) + if use_cache_branch_tensor is not None: + model_inputs.append(use_cache_branch_tensor) + io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( self.session, *model_inputs, @@ -670,14 +674,18 @@ def forward( "input_ids": input_ids.cpu().detach().numpy(), } - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() - # Add the encoder_hidden_states inputs when needed if "encoder_hidden_states" in self.input_names: onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy() + # Add the decoder_attention_mask inputs when needed + if "decoder_attention_mask" in self.input_names: + onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy() + + # Add the encoder_attention_mask inputs when needed + if "encoder_attention_mask" in self.input_names: + onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() + if past_key_values is not None: # Add the past_key_values to the decoder inputs for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): @@ -694,14 +702,18 @@ def forward( "input_ids": input_ids, } - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask - # Add the encoder_hidden_states inputs when needed if "encoder_hidden_states" in self.input_names: onnx_inputs["encoder_hidden_states"] = encoder_hidden_states + # Add the decoder_attention_mask inputs when needed + if "decoder_attention_mask" in self.input_names: + onnx_inputs["decoder_attention_mask"] = decoder_attention_mask + + # Add the encoder_attention_mask inputs when needed + if "encoder_attention_mask" in self.input_names: + onnx_inputs["encoder_attention_mask"] = encoder_attention_mask + if past_key_values is not None: # Add the past_key_values to the decoder inputs for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 2908a1c803..c436a900cb 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -31,6 +31,7 @@ AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig, + Pix2StructForConditionalGeneration, # Pix2struct does not support AutoModel WhisperForConditionalGeneration, ) from transformers.file_utils import add_start_docstrings_to_model_forward @@ -99,6 +100,13 @@ Features extracted from an Image. This tensor should be of shape `(batch_size, num_channels, height, width)`. """ +PIX2STRUCT_INPUTS_DOCSTRING = r""" + Args: + flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`): + Flattened and padded pixel values. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding pixel values. +""" DECODER_INPUTS_DOCSTRING = r""" Args: @@ -133,7 +141,6 @@ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
""" - SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING = r""" Args: input_features (`torch.FloatTensor`): @@ -166,9 +173,36 @@ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. """ +PIX2STRUCT_ONNX_MODEL_DOCSTRING = r""" + Args: + flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`): + Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` = + `num_channels` * `patch_size` * `patch_size` + The process of flattening the pixel patches is done by `Pix2StructProcessor`. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)` + Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding. + The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. +""" + _TOKENIZER_FOR_DOC = "AutoTokenizer" _PROCESSOR_FOR_DOC = "AutoProcessor" -_IMAGE_PROCESSOER_FOR_DOC = "AutoImageProcessor" +_IMAGE_PROCESSER_FOR_DOC = "AutoImageProcessor" TRANSLATION_EXAMPLE = r""" Example of text generation: @@ -281,6 +315,28 @@ ``` """ +PIX2STRUCT_EXAMPLE = r""" + Example of pix2struct: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from PIL import Image + >>> import requests + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}", export=True, use_io_binding=True) + + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud" + >>> inputs = processor(images=image, text=question, return_tensors="pt") + + >>> gen_tokens = model.generate(**inputs) + >>> outputs = processor.batch_decode(gen_tokens, skip_special_tokens=True) + ``` +""" + class ORTEncoderForSpeech(ORTEncoder): """ @@ -385,6 +441,64 @@ def forward( return BaseModelOutput(last_hidden_state=last_hidden_state) +class ORTEncoderForPix2Struct(ORTEncoder): + """ + Encoder model for ONNX Runtime inference for Pix2Struct. 
+ + Args: + session (`ort.InferenceSession`): + The ONNX Runtime inference session associated to the encoder. + """ + + @add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING) + def forward( + self, + flattened_patches: torch.FloatTensor, + attention_mask: torch.LongTensor, + **kwargs, + ) -> BaseModelOutput: + use_torch = isinstance(flattened_patches, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) + + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: + model_inputs = ( + [flattened_patches, attention_mask] if "attention_mask" in self.input_names else [flattened_patches] + ) + io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( + self.session, + *model_inputs, + ordered_input_names=self._ordered_input_names, + ) + + io_binding.synchronize_inputs() + self.session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) + else: + if use_torch: + onnx_inputs = {"flattened_patches": flattened_patches.cpu().detach().numpy()} + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + else: + onnx_inputs = {"flattened_patches": flattened_patches} + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask + + if "attention_mask" in self.input_names: + if self.session.get_inputs()[1].type == "tensor(int64)": + onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64) + + outputs = self.session.run(None, onnx_inputs) + + last_hidden_state = outputs[self.output_names["last_hidden_state"]] + + if use_torch: + last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + + return BaseModelOutput(last_hidden_state=last_hidden_state) + + class ORTModelForConditionalGeneration(ORTModel, ABC): """ Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. 
@@ -982,7 +1096,7 @@ def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoder(session, self) @add_start_docstrings_to_model_forward( - SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + SEQ2SEQ_ONNX_MODEL_DOCSTRING + TRANSLATION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForSeq2SeqLM", @@ -1119,7 +1233,7 @@ def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoderForSpeech(session, self) @add_start_docstrings_to_model_forward( - SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, feature_size, sequence_length") + SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING + AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE.format( processor_class=_PROCESSOR_FOR_DOC, model_class="ORTModelForSpeechSeq2Seq", @@ -1307,9 +1421,9 @@ def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoderForVisionEncoderDecoder(session, self) @add_start_docstrings_to_model_forward( - VISION_ENCODER_DECODER_SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, num_channels, height, width") + VISION_ENCODER_DECODER_SEQ2SEQ_ONNX_MODEL_DOCSTRING + IMAGE_TO_TEXT_EXAMPLE.format( - processor_class=_IMAGE_PROCESSOER_FOR_DOC, + processor_class=_IMAGE_PROCESSER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForVision2Seq", checkpoint="nlpconnect/vit-gpt2-image-captioning", @@ -1394,3 +1508,127 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True + + +class ORTModelForPix2Struct(ORTModelForConditionalGeneration, GenerationMixin): + """ + Pix2struct model with a language modeling head for ONNX Runtime inference. + """ + + # pix2struct cannot be loaded using AutoModel + auto_model_class = Pix2StructForConditionalGeneration + main_input_name = "flattened_patches" + + def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: + return ORTEncoderForPix2Struct(session, self) + + @add_start_docstrings_to_model_forward( + PIX2STRUCT_ONNX_MODEL_DOCSTRING + + PIX2STRUCT_EXAMPLE.format( + processor_class=_PROCESSOR_FOR_DOC, + model_class="ORTModelForPix2Struct", + checkpoint="google/pix2struct-ai2d-base", + ) + ) + def forward( + self, + flattened_patches: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Seq2SeqLMOutput: + # Encode if needed : first prediction pass + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + flattened_patches=flattened_patches, + attention_mask=attention_mask, + ) + + # TODO: for some reason the attention_mask for pix2struct is a float in transformers and not an int64. This messes up with the exporter + # hardcodes int64 input dtype for the attention mask. This workaround is quite ugly, it should be fixed rather in the ONNX exporter. 
+ if isinstance(attention_mask, torch.Tensor): + attention_mask = attention_mask.to(torch.int64) + else: + attention_mask = attention_mask.astype(np.int64) + + # Decode + if past_key_values is None or self.use_cache is False: + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + labels=labels, + ) + elif self.use_merged is True: + decoder_outputs = self.decoder( + input_ids=decoder_input_ids[:, -1:], + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + labels=labels, + ) + else: + decoder_outputs = self.decoder_with_past( + input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + labels=labels, + ) + + return Seq2SeqLMOutput( + loss=decoder_outputs.get("loss", None), + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + flattened_patches: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ) -> Dict: + if decoder_attention_mask is None: + decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) + + return { + "flattened_patches": flattened_patches, + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def get_encoder(self) -> ORTEncoder: + return self.encoder + + # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache + @staticmethod + def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: + ORTModelForSeq2SeqLM._reorder_cache(past, beam_idx) + + def can_generate(self): + """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" + return True diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 1ff2e4bf2f..f4ff908847 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -713,18 +713,20 @@ def __init__( self, task: str, normalized_config: NormalizedConfig, + preprocessors: List[Any], batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], - patch_height: int = 16, - patch_width: int = 16, - max_patches: int = DEFAULT_DUMMY_SHAPES["sequence_length"], num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], **kwargs, ): self.task = task self.batch_size = batch_size + + # looking for static shapes in Pix2StructProcessor + patch_height = preprocessors[1].image_processor.patch_size["height"] + patch_width = preprocessors[1].image_processor.patch_size["width"] self.flattened_patch_size = 2 + patch_height * patch_width * num_channels - 
self.max_patches = max_patches + self.max_patches = preprocessors[1].image_processor.max_patches def generate(self, input_name: str, framework: str = "pt"): shape = [self.batch_size, self.max_patches, self.flattened_patch_size] diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 6da01ff8de..c5f3d5ce4c 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -109,6 +109,11 @@ def __getattr__(self, attr_name): return super().__getattr__(attr_name) +Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args( + text_config="text_config", vision_config="vision_config" +) + + class NormalizedEncoderDecoderConfig(NormalizedConfig): ENCODER_NORMALIZED_CONFIG_CLASS = None DECODER_NORMALIZED_CONFIG_CLASS = None @@ -230,7 +235,7 @@ class NormalizedConfigManager: "nystromformer": NormalizedTextConfig, "opt": NormalizedTextConfig, "pegasus": BartLikeNormalizedTextConfig, - "pix2struct": NormalizedVisionConfig, + "pix2struct": Pix2StructNormalizedTextConfig, "poolformer": NormalizedVisionConfig, "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 5a8cd52a75..f28a3676be 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -53,6 +53,7 @@ AutoModelForVision2Seq, AutoTokenizer, MBartForConditionalGeneration, + Pix2StructForConditionalGeneration, # Pix2Struct does not work with AutoModel PretrainedConfig, set_seed, ) @@ -80,6 +81,7 @@ ORTModelForImageClassification, ORTModelForMaskedLM, ORTModelForMultipleChoice, + ORTModelForPix2Struct, ORTModelForQuestionAnswering, ORTModelForSemanticSegmentation, ORTModelForSeq2SeqLM, @@ -4263,6 +4265,327 @@ def test_compare_to_io_binding(self, *args, **kwargs): gc.collect() +class ORTModelForPix2StructTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["pix2struct"] + + FULL_GRID = { + "model_arch": SUPPORTED_ARCHITECTURES, + "use_cache": [False, True], + "use_merged": [False, True], + } + + ORTMODEL_CLASS = ORTModelForPix2Struct + TASK = "image-to-text" # is it fine as well with visual-question-answering? 
+ + GENERATION_LENGTH = 100 + SPEEDUP_CACHE = 1.1 + + IMAGE = Image.open( + requests.get( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", + stream=True, + ).raw + ) + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForPix2Struct.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_merge_from_transformers_and_save(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = ORTModelForPix2Struct.from_pretrained(model_id, export=True, use_merged=True) + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained(tmpdir) + save_path = os.path.join(tmpdir, ONNX_DECODER_MERGED_NAME) + self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) + + folder_contents = os.listdir(tmpdir) + self.assertTrue(ONNX_ENCODER_NAME in folder_contents) + self.assertTrue(ONNX_DECODER_NAME not in folder_contents) + self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_merge_from_onnx_and_save(self, model_arch): + model_id = MODEL_NAMES[model_arch] + task = "image-to-text-with-past" + + with tempfile.TemporaryDirectory() as tmpdir: + main_export(model_id, tmpdir, task=task) + + model = ORTModelForPix2Struct.from_pretrained(tmpdir) + + self.assertTrue(model.use_merged) + self.assertTrue(model.decoder_with_past is None) + + model.save_pretrained(tmpdir + "_save") + save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME) + self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) + + folder_contents = os.listdir(tmpdir + "_save") + self.assertTrue(ONNX_ENCODER_NAME in folder_contents) + self.assertFalse(ONNX_DECODER_NAME in folder_contents) + self.assertFalse(ONNX_DECODER_WITH_PAST_NAME in folder_contents) + + @parameterized.expand(grid_parameters(FULL_GRID)) + def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): + if use_cache is False and use_merged is True: + self.skipTest("use_cache=False, use_merged=True are uncompatible") + + if use_cache is False: + self.skipTest("skip") + + model_args = { + "test_name": test_name, + "model_arch": model_arch, + "use_cache": use_cache, + "use_merged": use_merged, + } + self._setup(model_args) + + model_id = MODEL_NAMES[model_arch] + onnx_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_cache=use_cache) + + self.assertIsInstance(onnx_model.encoder, ORTEncoder) + if use_merged is False: + model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_NAME) + self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) + self.assertEqual(onnx_model.use_merged, False) + else: + model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_MERGED_NAME) + self.assertTrue(has_onnx_input(model_path, "use_cache_branch")) + self.assertEqual(onnx_model.use_merged, True) + + self.assertIsInstance(onnx_model.decoder, ORTDecoder) + if onnx_model.use_cache is True and onnx_model.use_merged is False: + self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoder) + if onnx_model.use_cache is True and onnx_model.use_merged is True: + self.assertTrue(onnx_model.decoder_with_past is None) + + self.assertIsInstance(onnx_model.config, PretrainedConfig) + + set_seed(SEED) + questions = [ + "Who am I?", + "What 
does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud and this is long long very long and super long my dear", + ] + + transformers_model = Pix2StructForConditionalGeneration.from_pretrained(model_id) + preprocessor = get_preprocessor(model_id) + + inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=questions, padding=True, return_tensors="pt") + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + + decoder_start_token_id = transformers_model.config.decoder_start_token_id + decoder_inputs = { + "decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id, + "decoder_attention_mask": torch.ones((2, 1), dtype=torch.int64), + } + + with torch.no_grad(): + transformers_outputs = transformers_model(**inputs, **decoder_inputs) + + for input_type in ["pt", "np"]: + inputs = preprocessor( + images=[self.IMAGE, self.IMAGE], text=questions, padding=True, return_tensors=input_type + ) + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + + if input_type == "np": + decoder_inputs = { + "decoder_input_ids": np.ones((2, 1), dtype=np.int64) * decoder_start_token_id, + "decoder_attention_mask": np.ones((2, 1), dtype=np.int64), + } + + onnx_outputs = onnx_model(**inputs, **decoder_inputs) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @pytest.mark.gpu_test # mark as GPU test as well to run the without/with cache timing test on the slow tests + def test_compare_with_and_without_past_key_values(self, model_arch: str): + if model_arch == "m2m_100": + return # TODO: this test is failing for m2m_100 + model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False} + self._setup(model_args) + model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True} + self._setup(model_args) + + model_id = MODEL_NAMES[model_arch] + preprocessor = get_preprocessor(model_id) + + question = "What does the label 15 represent? 
(1) lava (2) core (3) tunnel (4) ash cloud" + inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt") + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + + model_with_pkv = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[model_arch + "_True"], use_cache=True + ) + + _ = model_with_pkv.generate(**inputs) # warmup + with Timer() as with_pkv_timer: + outputs_model_with_pkv = model_with_pkv.generate( + **inputs, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + + model_without_pkv = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[model_arch + "_False"], use_cache=False + ) + _ = model_without_pkv.generate(**inputs) # warmup + with Timer() as without_pkv_timer: + outputs_model_without_pkv = model_without_pkv.generate( + **inputs, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) + self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1) + self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1) + + if os.environ.get("TEST_LEVEL", 0) == "1": + self.assertTrue( + without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE, + f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," + f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", + ) + + @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) + def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): + model_args = { + "test_name": test_name + "_True", + "model_arch": model_arch, + "use_cache": use_cache, + "use_merged": True, + } + self._setup(model_args) + model_args = { + "test_name": test_name + "_False", + "model_arch": model_arch, + "use_cache": use_cache, + "use_merged": False, + } + self._setup(model_args) + + model_id = MODEL_NAMES[model_arch] + preprocessor = get_preprocessor(model_id) + + question = "What does the label 15 represent? 
(1) lava (2) core (3) tunnel (4) ash cloud" + inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt") + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + + model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"] + model_merged_dir = self.onnx_model_dirs[test_name + "_True"] + + model_not_merged = ORTModelForPix2Struct.from_pretrained(model_not_merged_dir) + not_merged_onnx_path = Path(model_not_merged_dir, ONNX_DECODER_NAME) + self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch")) + self.assertEqual(model_not_merged.use_merged, False) + + model_merged = ORTModelForPix2Struct.from_pretrained(model_merged_dir) + merged_onnx_path = Path(model_merged_dir, ONNX_DECODER_MERGED_NAME) + self.assertTrue(has_onnx_input(merged_onnx_path, "use_cache_branch")) + self.assertEqual(model_merged.decoder_with_past, None) + self.assertEqual(model_merged.use_merged, True) + + outputs_model_not_merged = model_not_merged.generate(**inputs) + outputs_model_merged = model_merged.generate(**inputs) + + self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]}) + ) + @pytest.mark.gpu_test + def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): + if use_cache is False and use_merged is True: + self.skipTest("use_cache=False, use_merged=True are uncompatible") + + model_args = { + "test_name": test_name, + "model_arch": model_arch, + "use_cache": use_cache, + "use_merged": use_merged, + } + self._setup(model_args) + + model_id = MODEL_NAMES[model_arch] + onnx_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False) + io_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + + preprocessor = get_preprocessor(model_id) + + question = [ + "What does the label 15 represent? 
(1) lava (2) core (3) tunnel (4) ash cloud and this is even longer and longer and longer and longer and hey", + "Who are you?", + ] + inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt") + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + decoder_start_token_id = onnx_model.config.decoder_start_token_id + decoder_inputs = { + "decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id, + "decoder_attention_mask": torch.ones((2, 1), dtype=torch.int64), + } + + onnx_outputs = onnx_model(**inputs, **decoder_inputs) + io_outputs = io_model(**inputs, **decoder_inputs) + + self.assertTrue("logits" in io_outputs) + self.assertIsInstance(io_outputs.logits, torch.Tensor) + + self.assertTrue(torch.allclose(onnx_outputs.logits, io_outputs.logits, atol=1e-4)) + + gc.collect() + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]}) + ) + def test_compare_generation_to_io_binding( + self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool + ): + if use_cache is False and use_merged is True: + self.skipTest("use_cache=False, use_merged=True are uncompatible") + + model_args = { + "test_name": test_name, + "model_arch": model_arch, + "use_cache": use_cache, + "use_merged": use_merged, + } + self._setup(model_args) + + model_id = MODEL_NAMES[model_arch] + onnx_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False) + io_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True) + + preprocessor = get_preprocessor(model_id) + + question = ["What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", "Who are you?"] + inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt") + del inputs["decoder_attention_mask"] + del inputs["decoder_input_ids"] + onnx_outputs = onnx_model.generate(**inputs, num_beams=5) + io_outputs = io_model.generate(**inputs, num_beams=5) + + # compare tensor outputs + self.assertTrue(torch.equal(onnx_outputs, io_outputs)) + + gc.collect() + + class TestBothExportersORTModel(unittest.TestCase): @parameterized.expand( [ diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 09ada4e369..be0f3d0c31 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -57,6 +57,7 @@ "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "groupvit": "hf-internal-testing/tiny-random-groupvit", + "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel", "levit": "hf-internal-testing/tiny-random-LevitModel", "layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel", @@ -73,32 +74,32 @@ "mt5": "lewtun/tiny-random-mt5", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", + "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "segformer": "hf-internal-testing/tiny-random-SegformerModel", + "sew": "hf-internal-testing/tiny-random-SEWModel", 
+ "sew_d": "hf-internal-testing/tiny-random-SEWDModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", + "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "t5": "hf-internal-testing/tiny-random-t5", + "trocr": "microsoft/trocr-small-handwritten", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel", + "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2", "vit": "hf-internal-testing/tiny-random-vit", - "yolos": "hf-internal-testing/tiny-random-YolosModel", "whisper": "openai/whisper-tiny.en", # hf-internal-testing ones are broken - "hubert": "hf-internal-testing/tiny-random-HubertModel", "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model", "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", "wavlm": "hf-internal-testing/tiny-random-WavlmModel", - "sew": "hf-internal-testing/tiny-random-SEWModel", - "sew_d": "hf-internal-testing/tiny-random-SEWDModel", - "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", - "unispeech": "hf-internal-testing/tiny-random-unispeech", - "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", - "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2", - "trocr": "microsoft/trocr-small-handwritten", + "yolos": "hf-internal-testing/tiny-random-YolosModel", } SEED = 42