Skip to content

Commit

Permalink
Add ONNX Support for Decision Transformer Model (#2038)
Browse files Browse the repository at this point in the history
* Decision Transformer to ONNX V0.1

* Decision Transformer to ONNX V0.2

* Update optimum/exporters/onnx/model_configs.py

* Apply suggestions from code review

* Update optimum/exporters/onnx/base.py

* Update optimum/exporters/onnx/model_configs.py

* Update optimum/utils/input_generators.py

* Update optimum/exporters/onnx/model_configs.py

* Apply suggestions from code review

* Update optimum/exporters/tasks.py

* ONNXToDT: changes to order of OrderedDict elements

* make style changes

* test

* remove custom normalized config

* remove unnecessary dynamic axes

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: IlyasMoutawwakil <[email protected]>
  • Loading branch information
3 people authored Nov 25, 2024
1 parent d2a5a6a commit 65a8a94
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Data2VecVision
- Deberta
- Deberta-v2
- Decision Transformer
- Deit
- Detr
- DistilBert
Expand Down
25 changes: 25 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecisionTransformerInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyFluxTransformerTextInputGenerator,
Expand Down Expand Up @@ -263,6 +264,30 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
pass


class DecisionTransformerOnnxConfig(OnnxConfig):
    """ONNX export configuration for the Decision Transformer model."""

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedConfig

    # Every input/output is dynamic along batch and sequence dimensions.
    _DYNAMIC_AXES = {0: "batch_size", 1: "sequence_length"}

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Input names mapped to their dynamic-axis specification."""
        input_names = ("states", "actions", "timesteps", "returns_to_go", "attention_mask")
        return {name: dict(self._DYNAMIC_AXES) for name in input_names}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Output names mapped to their dynamic-axis specification."""
        output_names = ("state_preds", "action_preds", "return_preds", "last_hidden_state")
        return {name: dict(self._DYNAMIC_AXES) for name in output_names}


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class TasksManager:
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"reinforcement-learning": "AutoModel",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
Expand Down Expand Up @@ -574,6 +575,11 @@ class TasksManager:
onnx="DebertaV2OnnxConfig",
tflite="DebertaV2TFLiteConfig",
),
"decision-transformer": supported_tasks_mapping(
"feature-extraction",
"reinforcement-learning",
onnx="DecisionTransformerOnnxConfig",
),
"deit": supported_tasks_mapping(
"feature-extraction",
"image-classification",
Expand Down Expand Up @@ -2085,6 +2091,9 @@ def get_model_from_task(
if original_task == "automatic-speech-recognition" or task == "automatic-speech-recognition":
if original_task == "auto" and config.architectures is not None:
model_class_name = config.architectures[0]
elif original_task == "reinforcement-learning" or task == "reinforcement-learning":
if config.architectures is not None:
model_class_name = config.architectures[0]

if library_name == "diffusers":
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecisionTransformerInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyFluxTransformerTextInputGenerator,
Expand Down
37 changes: 37 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,43 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
)


class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
    """
    Generates dummy decision transformer inputs (states, actions, rewards,
    timesteps, returns-to-go, and attention mask).
    """

    # "rewards" was handled in `generate` but missing here; added for consistency.
    SUPPORTED_INPUT_NAMES = (
        "states",
        "actions",
        "rewards",
        "timesteps",
        "returns_to_go",
        "attention_mask",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Decision-Transformer-specific dimensions read from the model config.
        self.act_dim = self.normalized_config.config.act_dim
        self.state_dim = self.normalized_config.config.state_dim
        self.max_ep_len = self.normalized_config.config.max_ep_len

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Generate a dummy tensor for `input_name`.

        Timesteps are integer indices bounded by `max_ep_len`; every other
        supported input is a float tensor sampled from [-2.0, 2.0].

        Raises:
            ValueError: if `input_name` is not one of `SUPPORTED_INPUT_NAMES`.
        """
        if input_name == "states":
            shape = [self.batch_size, self.sequence_length, self.state_dim]
        elif input_name == "actions":
            shape = [self.batch_size, self.sequence_length, self.act_dim]
        elif input_name in ("rewards", "returns_to_go"):
            shape = [self.batch_size, self.sequence_length, 1]
        elif input_name == "attention_mask":
            # NOTE(review): generated as a float tensor in [-2, 2] like the other
            # float inputs; masks are typically 0/1 integers — confirm this is
            # intentional for export tracing.
            shape = [self.batch_size, self.sequence_length]
        elif input_name == "timesteps":
            shape = [self.batch_size, self.sequence_length]
            return self.random_int_tensor(shape=shape, max_value=self.max_ep_len, framework=framework, dtype=int_dtype)
        else:
            # Previously an unsupported name fell through and crashed with
            # UnboundLocalError on `shape`; fail with a clear message instead.
            raise ValueError(
                f"Unsupported input name {input_name!r}. Supported input names: {self.SUPPORTED_INPUT_NAMES}."
            )

        return self.random_float_tensor(shape, min_value=-2.0, max_value=2.0, framework=framework, dtype=float_dtype)


class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"decoder_input_ids",
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"data2vec-audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"decision-transformer": "edbeeching/decision-transformer-gym-hopper-medium",
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
Expand Down

0 comments on commit 65a8a94

Please sign in to comment.