diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 8ffd5243507..bca7cf24acf 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -266,27 +266,25 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
 
 class DecisionTransformerOnnxConfig(OnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
-    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
-        act_dim="act_dim", state_dim="state_dim", max_ep_len="max_ep_len", hidden_size="hidden_size", allow_new=True
-    )
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
+            "states": {0: "batch_size", 1: "sequence_length"},
+            "actions": {0: "batch_size", 1: "sequence_length"},
             "timesteps": {0: "batch_size", 1: "sequence_length"},
             "returns_to_go": {0: "batch_size", 1: "sequence_length"},
             "attention_mask": {0: "batch_size", 1: "sequence_length"},
-            "actions": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
-            "states": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
         }
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
-            "action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
+            "state_preds": {0: "batch_size", 1: "sequence_length"},
+            "action_preds": {0: "batch_size", 1: "sequence_length"},
             "return_preds": {0: "batch_size", 1: "sequence_length"},
-            "last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
         }
 
 
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index a6ce07bab32..0ac1805f97d 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -513,11 +513,11 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
     """
 
     SUPPORTED_INPUT_NAMES = (
+        "states",
         "actions",
         "timesteps",
-        "attention_mask",
         "returns_to_go",
-        "states",
+        "attention_mask",
     )
 
     def __init__(self, *args, **kwargs):
@@ -531,6 +531,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
             shape = [self.batch_size, self.sequence_length, self.state_dim]
         elif input_name == "actions":
             shape = [self.batch_size, self.sequence_length, self.act_dim]
+        elif input_name == "rewards":
+            shape = [self.batch_size, self.sequence_length, 1]
         elif input_name == "returns_to_go":
             shape = [self.batch_size, self.sequence_length, 1]
         elif input_name == "attention_mask":