Skip to content

Commit

Permalink
Pix2Struct onnxruntime support (#1296)
Browse files Browse the repository at this point in the history
* add onnxruntime support for pix2struct

* Fix for Pix2Struct ORT

* Fix from_pretrained for pix2struct

* test

* test

* huggingface(#1288)

* Update optimum/onnxruntime/modeling_seq2seq.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/onnxruntime/modeling_seq2seq.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/onnxruntime/modeling_seq2seq.py

Co-authored-by: fxmarty <[email protected]>

* update modeling_seq2seq.py

* update modeling_seq2seq.py

* working ort inference pix2struct

* add documentation

* fix doc

---------

Co-authored-by: ARK <[email protected]>
Co-authored-by: fxmarty <[email protected]>
  • Loading branch information
3 people authored Aug 23, 2023
1 parent edd829b commit 05d20df
Show file tree
Hide file tree
Showing 13 changed files with 834 additions and 74 deletions.
38 changes: 37 additions & 1 deletion docs/source/onnxruntime/package_reference/modeling_ort.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,44 @@ The following ORT classes are available for the following natural language proce

### ORTModelForCausalLM

This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.

[[autodoc]] onnxruntime.ORTModelForCausalLM

### ORTModelForMaskedLM

This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForMaskedLM

### ORTModelForSeq2SeqLM

This class officially supports bart, blenderbot, blenderbot_small, longt5, m2m_100, marian, mbart, mt5, pegasus, t5.

[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM

### ORTModelForSequenceClassification

This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForSequenceClassification

### ORTModelForTokenClassification

This class officially supports albert, bert, bloom, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gpt2, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForTokenClassification

### ORTModelForMultipleChoice

This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForMultipleChoice

### ORTModelForQuestionAnswering

This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gptj, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForQuestionAnswering

## Computer vision
Expand All @@ -58,10 +72,14 @@ The following ORT classes are available for the following computer vision tasks.

### ORTModelForImageClassification

This class officially supports beit, convnext, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.

[[autodoc]] onnxruntime.ORTModelForImageClassification

### ORTModelForSemanticSegmentation

This class officially supports segformer.

[[autodoc]] onnxruntime.ORTModelForSemanticSegmentation

## Audio
Expand All @@ -70,22 +88,32 @@ The following ORT classes are available for the following audio tasks.

### ORTModelForAudioClassification

This class officially supports audio_spectrogram_transformer, data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioClassification

### ORTModelForAudioFrameClassification

This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioFrameClassification

### ORTModelForCTC

This class officially supports data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForCTC

### ORTModelForSpeechSeq2Seq

This class officially supports whisper, speech_to_text.

[[autodoc]] onnxruntime.ORTModelForSpeechSeq2Seq

### ORTModelForAudioXVector

This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioXVector

## Multimodal
Expand All @@ -94,8 +122,16 @@ The following ORT classes are available for the following multimodal tasks.

### ORTModelForVision2Seq

This class officially supports trocr and vision-encoder-decoder.

[[autodoc]] onnxruntime.ORTModelForVision2Seq

### ORTModelForPix2Struct

This class officially supports pix2struct.

[[autodoc]] onnxruntime.ORTModelForPix2Struct

## Custom Tasks

The following ORT classes are available for the following custom tasks.
Expand Down Expand Up @@ -129,4 +165,4 @@ The following ORT classes are available for the following custom tasks.

#### ORTStableDiffusionXLImg2ImgPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline
[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline
13 changes: 10 additions & 3 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from ...commands.export.onnx import parse_args_onnx
from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging
from ...utils.save_utils import maybe_save_preprocessors
from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
Expand All @@ -43,7 +43,7 @@
if is_torch_available():
import torch

from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union


if TYPE_CHECKING:
Expand All @@ -62,6 +62,7 @@ def _get_submodels_and_onnx_configs(
custom_onnx_configs: Dict,
custom_architecture: bool,
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
Expand All @@ -72,7 +73,7 @@ def _get_submodels_and_onnx_configs(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=task
)
onnx_config = onnx_config_constructor(model.config)
onnx_config = onnx_config_constructor(model.config, preprocessors=preprocessors)

if (
model.config.is_encoder_decoder
Expand Down Expand Up @@ -359,13 +360,19 @@ def main_export(
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")

# The preprocessors are loaded as they may be useful to export the model. Notably, some of the static input shapes may be stored in the
# preprocessors config.
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=monolith,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
)

if not is_stable_diffusion:
Expand Down
67 changes: 45 additions & 22 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,17 @@ class OnnxConfig(ExportConfig, ABC):
),
}

def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"):
def __init__(
self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None
):
if task not in self._TASK_TO_COMMON_OUTPUTS:
raise ValueError(
f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}"
)
self.task = task

self._config = config
self._preprocessors = preprocessors
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
Expand Down Expand Up @@ -493,6 +496,7 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: Optional[bool] = None,
use_present_in_outputs: Optional[bool] = None,
preprocessors: Optional[List[Any]] = None,
):
self.use_past = use_past
if use_past_in_inputs is None:
Expand All @@ -515,10 +519,12 @@ def __init__(
)
self.is_merged = False
self.use_cache_branch = None
super().__init__(config, task=task)
super().__init__(config, task=task, preprocessors=preprocessors)

@classmethod
def with_past(cls, config: "PretrainedConfig", task: str = "feature-extraction") -> "OnnxConfigWithPast":
def with_past(
cls, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None
) -> "OnnxConfigWithPast":
"""
Instantiates a [`~optimum.exporters.onnx.OnnxConfig`] with `use_past` attribute set to `True`.
Expand All @@ -531,7 +537,7 @@ def with_past(cls, config: "PretrainedConfig", task: str = "feature-extraction")
Returns:
[`~optimum.exporters.onnx.OnnxConfig`]: The onnx config with `.use_past = True`
"""
return cls(config, task=task, use_past=True)
return cls(config, task=task, use_past=True, preprocessors=preprocessors)

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
Expand Down Expand Up @@ -564,24 +570,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both
if (
self.use_past is True
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids"]
):
sequence_length = dummy_input_gen.sequence_length
if "sequence_length" in kwargs and kwargs["sequence_length"] != 1:
logger.info(
f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 1 "
f"will be used with use_past == True for `{input_name}`."
)
dummy_input_gen.sequence_length = 1
dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework)
dummy_input_gen.sequence_length = sequence_length
else:
dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework)
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen, input_name, framework, input_shapes=kwargs
)
input_was_inserted = True
break
if not input_was_inserted:
Expand Down Expand Up @@ -617,6 +608,35 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):

return dummy_inputs

def overwrite_shape_and_generate_input(
self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
):
"""
The shape passed to the dummy input generator may not always be correct for all of the inputs it manages. This method allows
to overwrite some shapes, and generate the dummy input. This should probably be refactored more elegantly.
"""

# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both
if (
self.use_past is True
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids"]
):
sequence_length = dummy_input_gen.sequence_length
if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1:
logger.info(
f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 "
f"will be used with use_past == True for `{input_name}`."
)
dummy_input_gen.sequence_length = 1
dummy_input = dummy_input_gen.generate(input_name, framework=framework)
dummy_input_gen.sequence_length = sequence_length
else:
dummy_input = dummy_input_gen.generate(input_name, framework=framework)

return dummy_input

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
Expand Down Expand Up @@ -703,13 +723,15 @@ def __init__(
use_past_in_inputs: Optional[bool] = None,
use_present_in_outputs: Optional[bool] = None,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config,
task=task,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
use_present_in_outputs=use_present_in_outputs,
preprocessors=preprocessors,
)
self._behavior = behavior
self.override_attributes_for_behavior()
Expand Down Expand Up @@ -746,6 +768,7 @@ def with_behavior(
task=self.task,
use_past=use_past,
behavior=behavior,
preprocessors=self._preprocessors,
)

@property
Expand Down
8 changes: 6 additions & 2 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def __init__(
use_past_in_inputs: Optional[bool] = None,
use_present_in_outputs: Optional[bool] = None,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config,
Expand All @@ -286,6 +287,7 @@ def __init__(
use_past_in_inputs=use_past_in_inputs,
use_present_in_outputs=use_present_in_outputs,
behavior=behavior,
preprocessors=preprocessors,
)

from ..tasks import TasksManager
Expand All @@ -296,7 +298,7 @@ def __init__(
encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
)
self._encoder_onnx_config = encoder_onnx_config_constructor(config.encoder)
self._encoder_onnx_config = encoder_onnx_config_constructor(config.encoder, preprocessors=preprocessors)
self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config

if self._behavior is not ConfigBehavior.ENCODER:
Expand All @@ -316,7 +318,9 @@ def __init__(
"past key values."
)

self._decoder_onnx_config = decoder_onnx_config_constructor(config.decoder, **kwargs)
self._decoder_onnx_config = decoder_onnx_config_constructor(
config.decoder, preprocessors=preprocessors, **kwargs
)
if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast):
self._decoder_onnx_config = self._decoder_onnx_config.with_behavior(
self._behavior, use_past=kwargs["use_past"]
Expand Down
Loading

0 comments on commit 05d20df

Please sign in to comment.