From 269dfba3d84c93e79a0e648f5bbcfb23c174e2d1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 21 Nov 2024 13:51:58 +0100 Subject: [PATCH] remove unncessary dynamic axes --- optimum/exporters/onnx/model_configs.py | 14 ++++++-------- optimum/utils/input_generators.py | 6 ++++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 8ffd5243507..bca7cf24acf 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -266,27 +266,25 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig): class DecisionTransformerOnnxConfig(OnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,) - NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( - act_dim="act_dim", state_dim="state_dim", max_ep_len="max_ep_len", hidden_size="hidden_size", allow_new=True - ) + NORMALIZED_CONFIG_CLASS = NormalizedConfig @property def inputs(self) -> Dict[str, Dict[int, str]]: return { + "states": {0: "batch_size", 1: "sequence_length"}, + "actions": {0: "batch_size", 1: "sequence_length"}, "timesteps": {0: "batch_size", 1: "sequence_length"}, "returns_to_go": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, - "actions": {0: "batch_size", 1: "sequence_length", 2: "act_dim"}, - "states": {0: "batch_size", 1: "sequence_length", 2: "state_dim"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { - "state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"}, - "action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"}, + "state_preds": {0: "batch_size", 1: "sequence_length"}, + "action_preds": {0: "batch_size", 1: "sequence_length"}, "return_preds": {0: "batch_size", 1: "sequence_length"}, - "last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"}, + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, } diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index a6ce07bab32..0ac1805f97d 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -513,11 +513,11 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator): """ SUPPORTED_INPUT_NAMES = ( + "states", "actions", "timesteps", - "attention_mask", "returns_to_go", - "states", + "attention_mask", ) def __init__(self, *args, **kwargs): @@ -531,6 +531,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int shape = [self.batch_size, self.sequence_length, self.state_dim] elif input_name == "actions": shape = [self.batch_size, self.sequence_length, self.act_dim] + elif input_name == "rewards": + shape = [self.batch_size, self.sequence_length, 1] elif input_name == "returns_to_go": shape = [self.batch_size, self.sequence_length, 1] elif input_name == "attention_mask":