Skip to content

Commit

Permalink
fix typo in timestep
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 18, 2024
1 parent 7f03a0e commit 395a4f7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyTransformerTextInputGenerator,
DummyTransformerTimestpsInputGenerator,
DummyTransformerTimestepInputGenerator,
DummyTransformerVisionInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
Expand Down Expand Up @@ -1207,7 +1207,7 @@ class SD3TransformerOnnxConfig(VisionOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
DummyTransformerTimestepInputGenerator,
DummyTransformerVisionInputGenerator,
DummyTransformerTextInputGenerator,
)
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def torch_to_onnx_output_map(self) -> Dict[str, str]:

class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
DummyTransformerTimestepInputGenerator,
DummyFluxTransformerVisionInputGenerator,
DummyFluxTransformerTextInputGenerator,
)
Expand Down Expand Up @@ -2124,9 +2124,9 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
# for Speech2text, we need to name the second axis as
# encoder_sequence_length / 2 * self._config.num_conv_layers as the axis name is
# used for dummy input generation
common_outputs["last_hidden_state"][
1
] = f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}"
common_outputs["last_hidden_state"][1] = (
f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}"
)
return common_outputs


Expand Down
2 changes: 1 addition & 1 deletion optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyTransformerTextInputGenerator,
DummyTransformerTimestpsInputGenerator,
DummyTransformerTimestepInputGenerator,
DummyTransformerVisionInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
Expand Down
2 changes: 1 addition & 1 deletion optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ def generate(
return self.random_int_tensor(shape=(1,), min_value=20, max_value=22, framework=framework, dtype=int_dtype)


class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
class DummyTransformerTimestepInputGenerator(DummyTimestepInputGenerator):
SUPPORTED_INPUT_NAMES = ("timestep",)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
Expand Down

0 comments on commit 395a4f7

Please sign in to comment.