Skip to content

Commit

Permalink
remove unncessary dynamic axes
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 21, 2024
1 parent 6ce8071 commit 269dfba
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
14 changes: 6 additions & 8 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,27 +266,25 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):

class DecisionTransformerOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
act_dim="act_dim", state_dim="state_dim", max_ep_len="max_ep_len", hidden_size="hidden_size", allow_new=True
)
NORMALIZED_CONFIG_CLASS = NormalizedConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"states": {0: "batch_size", 1: "sequence_length"},
"actions": {0: "batch_size", 1: "sequence_length"},
"timesteps": {0: "batch_size", 1: "sequence_length"},
"returns_to_go": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"actions": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
"states": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
"action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
"state_preds": {0: "batch_size", 1: "sequence_length"},
"action_preds": {0: "batch_size", 1: "sequence_length"},
"return_preds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}


Expand Down
6 changes: 4 additions & 2 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
"""

SUPPORTED_INPUT_NAMES = (
"states",
"actions",
"timesteps",
"attention_mask",
"returns_to_go",
"states",
"attention_mask",
)

def __init__(self, *args, **kwargs):
Expand All @@ -531,6 +531,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
shape = [self.batch_size, self.sequence_length, self.state_dim]
elif input_name == "actions":
shape = [self.batch_size, self.sequence_length, self.act_dim]
elif input_name == "rewards":
shape = [self.batch_size, self.sequence_length, 1]
elif input_name == "returns_to_go":
shape = [self.batch_size, self.sequence_length, 1]
elif input_name == "attention_mask":
Expand Down

0 comments on commit 269dfba

Please sign in to comment.