diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index 747e1396fb..2eaada7dad 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -36,6 +36,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - Data2VecVision
 - Deberta
 - Deberta-v2
+- Decision Transformer
 - Deit
 - Detr
 - DistilBert
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 8984162ee8..bca7cf24ac 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -27,6 +27,7 @@
     BloomDummyPastKeyValuesGenerator,
     DummyAudioInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
+    DummyDecisionTransformerInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
     DummyFluxTransformerTextInputGenerator,
@@ -263,6 +264,30 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
     pass
 
 
+class DecisionTransformerOnnxConfig(OnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "states": {0: "batch_size", 1: "sequence_length"},
+            "actions": {0: "batch_size", 1: "sequence_length"},
+            "timesteps": {0: "batch_size", 1: "sequence_length"},
+            "returns_to_go": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "state_preds": {0: "batch_size", 1: "sequence_length"},
+            "action_preds": {0: "batch_size", 1: "sequence_length"},
+            "return_preds": {0: "batch_size", 1: "sequence_length"},
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
 class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index b4bce4696f..8f28ec42ce 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -217,6 +217,7 @@ class TasksManager:
         "multiple-choice": "AutoModelForMultipleChoice",
         "object-detection": "AutoModelForObjectDetection",
         "question-answering": "AutoModelForQuestionAnswering",
+        "reinforcement-learning": "AutoModel",
         "semantic-segmentation": "AutoModelForSemanticSegmentation",
         "text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
         "text-generation": "AutoModelForCausalLM",
@@ -574,6 +575,11 @@ class TasksManager:
             onnx="DebertaV2OnnxConfig",
             tflite="DebertaV2TFLiteConfig",
         ),
+        "decision-transformer": supported_tasks_mapping(
+            "feature-extraction",
+            "reinforcement-learning",
+            onnx="DecisionTransformerOnnxConfig",
+        ),
         "deit": supported_tasks_mapping(
             "feature-extraction",
             "image-classification",
@@ -2085,6 +2091,9 @@ def get_model_from_task(
         if original_task == "automatic-speech-recognition" or task == "automatic-speech-recognition":
             if original_task == "auto" and config.architectures is not None:
                 model_class_name = config.architectures[0]
+        elif original_task == "reinforcement-learning" or task == "reinforcement-learning":
+            if config.architectures is not None:
+                model_class_name = config.architectures[0]
 
         if library_name == "diffusers":
             config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
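With the task and model-type mappings above in place, the new export path becomes reachable through the regular Python API. A minimal sketch, not part of this diff, assuming the same Hopper checkpoint that the tests below register (the output directory name is illustrative):

```python
# Export sketch: "reinforcement-learning" maps to AutoModel above, and
# get_model_from_task then swaps in the concrete architecture class from
# config.architectures; "decision-transformer" resolves to
# DecisionTransformerOnnxConfig.
from optimum.exporters.onnx import main_export

main_export(
    "edbeeching/decision-transformer-gym-hopper-medium",
    output="decision_transformer_onnx",  # illustrative output directory
    task="reinforcement-learning",
)
```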
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 40d93d298e..e1c2f52a84 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -52,6 +52,7 @@
     DummyAudioInputGenerator,
     DummyBboxInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
+    DummyDecisionTransformerInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
     DummyFluxTransformerTextInputGenerator,
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 148072aa0b..0ac1805f97 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -507,6 +507,43 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
     )
 
 
+class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
+    """
+    Generates dummy Decision Transformer inputs.
+    """
+
+    SUPPORTED_INPUT_NAMES = (
+        "states",
+        "actions",
+        "timesteps",
+        "returns_to_go",
+        "attention_mask",
+    )
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.act_dim = self.normalized_config.config.act_dim
+        self.state_dim = self.normalized_config.config.state_dim
+        self.max_ep_len = self.normalized_config.config.max_ep_len
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "states":
+            shape = [self.batch_size, self.sequence_length, self.state_dim]
+        elif input_name == "actions":
+            shape = [self.batch_size, self.sequence_length, self.act_dim]
+        elif input_name == "rewards":
+            shape = [self.batch_size, self.sequence_length, 1]
+        elif input_name == "returns_to_go":
+            shape = [self.batch_size, self.sequence_length, 1]
+        elif input_name == "attention_mask":
+            shape = [self.batch_size, self.sequence_length]
+        elif input_name == "timesteps":
+            shape = [self.batch_size, self.sequence_length]
+            return self.random_int_tensor(shape=shape, max_value=self.max_ep_len, framework=framework, dtype=int_dtype)
+
+        return self.random_float_tensor(shape, min_value=-2.0, max_value=2.0, framework=framework, dtype=float_dtype)
+
+
 class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
     SUPPORTED_INPUT_NAMES = (
         "decoder_input_ids",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 31059c403d..c56132c384 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -67,6 +67,7 @@
     "data2vec-audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
     "deberta": "hf-internal-testing/tiny-random-DebertaModel",
     "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
+    "decision-transformer": "edbeeching/decision-transformer-gym-hopper-medium",
     "deit": "hf-internal-testing/tiny-random-DeiTModel",
     "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
     "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
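Once exported, the graph can be sanity-checked against the input and output names declared in `DecisionTransformerOnnxConfig`. A hedged sketch, assuming the export above wrote the default `model.onnx` and reading `state_dim`/`act_dim` from the checkpoint config rather than hardcoding them; note that `DummyDecisionTransformerInputGenerator` traces `attention_mask` as a float tensor (only `timesteps` takes the integer branch), so the sketch matches that dtype:

```python
# Sanity-check sketch for the exported graph; the model.onnx path assumes
# main_export's default output filename.
import numpy as np
import onnxruntime
from transformers import AutoConfig

config = AutoConfig.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
batch, seq = 1, 4

session = onnxruntime.InferenceSession("decision_transformer_onnx/model.onnx")
outputs = session.run(
    None,  # fetch every output declared in DecisionTransformerOnnxConfig:
           # state_preds, action_preds, return_preds, last_hidden_state
    {
        "states": np.zeros((batch, seq, config.state_dim), dtype=np.float32),
        "actions": np.zeros((batch, seq, config.act_dim), dtype=np.float32),
        "timesteps": np.tile(np.arange(seq, dtype=np.int64), (batch, 1)),
        "returns_to_go": np.ones((batch, seq, 1), dtype=np.float32),
        # The dummy generator traces the mask as float, so match that here.
        "attention_mask": np.ones((batch, seq), dtype=np.float32),
    },
)
print([o.shape for o in outputs])
```

All four outputs should carry dynamic `batch_size` and `sequence_length` axes, matching the `outputs` property of the new ONNX config.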