fix mpt
echarlaix committed Nov 15, 2024
1 parent 64e6d5b commit 067587c
Showing 2 changed files with 5 additions and 4 deletions.
optimum/exporters/onnx/model_configs.py (1 addition, 0 deletions)
@@ -343,6 +343,7 @@ def patch_model_for_export(
 class MPTOnnxConfig(TextDecoderOnnxConfig):
     # MPT does not require position_ids input.
     DEFAULT_ONNX_OPSET = 13
+    MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
     )
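
The new MIN_TRANSFORMERS_VERSION field declares that the MPT ONNX config is only supported on transformers v4.41.0 or later. A minimal sketch of exercising this export path through optimum's ONNX Runtime integration follows; the checkpoint name and output directory are illustrative and not part of this commit:

    from optimum.onnxruntime import ORTModelForCausalLM

    # export=True converts the PyTorch checkpoint to ONNX on the fly;
    # for MPT this path now expects transformers >= 4.41 to be installed.
    model = ORTModelForCausalLM.from_pretrained("mosaicml/mpt-7b", export=True)
    model.save_pretrained("mpt-onnx/")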
tests/onnxruntime/test_modeling.py (4 additions, 4 deletions)
@@ -2325,7 +2325,6 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gptj",
"llama",
"mistral",
"mpt",
"opt",
]

@@ -2335,8 +2334,9 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
if check_if_transformers_greater("4.38"):
SUPPORTED_ARCHITECTURES.append("gemma")

# TODO: fix "mpt" for which inference fails for transformers < v4.41
if check_if_transformers_greater("4.41"):
SUPPORTED_ARCHITECTURES.append("phi3")
SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"])

FULL_GRID = {
"model_arch": SUPPORTED_ARCHITECTURES,
@@ -2449,7 +2449,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         transformers_model = transformers_model.eval()
         tokenizer = get_preprocessor(model_id)
-        tokens = tokenizer("This is a sample output", return_tensors="pt")
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
         position_ids = None
         if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
             input_shape = tokens["input_ids"].shape
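
The hunk ends just before position_ids is filled in for architectures that require it. A hypothetical sketch of the usual construction, shown for context only since those lines are cut off here (MPT itself is exempt, per the comment in model_configs.py):

    import torch

    input_shape = tokens["input_ids"].shape
    # One position index per token, broadcast across the batch (hypothetical reconstruction).
    position_ids = torch.arange(0, input_shape[-1]).unsqueeze(0).expand(input_shape[0], -1)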
@@ -2471,7 +2471,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
         # Compare batched generation.
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
-        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True)
         onnx_model.generation_config.eos_token_id = None
         transformers_model.generation_config.eos_token_id = None
         onnx_model.config.eos_token_id = None
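
Left padding is what makes the batched comparison meaningful: decoder-only models continue from the final token, so shorter prompts must be padded on the left, and EOS is disabled so both models emit the same number of tokens. A runnable sketch of the same comparison outside the test harness, with an illustrative checkpoint and generation length:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    model_id = "gpt2"  # illustrative; the suite uses tiny per-architecture checkpoints
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"

    pt_model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    onnx_model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

    # Disable early stopping on EOS so both runs generate max_new_tokens tokens.
    pt_model.generation_config.eos_token_id = None
    onnx_model.generation_config.eos_token_id = None

    tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True)
    with torch.no_grad():
        pt_out = pt_model.generate(**tokens, max_new_tokens=10, do_sample=False)
    onnx_out = onnx_model.generate(**tokens, max_new_tokens=10, do_sample=False)
    assert torch.equal(pt_out, onnx_out)  # greedy decoding should match token for token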
