
The translation of the model Helsinki-NLP/opus-mt-es-en in ONNX is poor. #17677

Closed
Zapotecatl opened this issue Sep 24, 2023 · 5 comments
Labels
platform:windows issues related to the Windows platform

Comments

@Zapotecatl

Describe the issue

I have exported the Marian model for translating Spanish to English (Helsinki-NLP/opus-mt-es-en) to ONNX. I then wrote a Python program that runs the translation with onnxruntime. However, the translation is not very good. For example:

Input in Spanish:
"La novela Crimen y castigo está dividida en seis partes más el epílogo. Mucho se ha comentado de la noción de dualismo en la obra, sugiriéndose la existencia de cierto grado de simetría en ella. Los episodios clave se distribuyen primero en una mitad y luego de nuevo en la otra."

Output in English:
The The Crime and Punishment novel is divided into six parts more the epi. has been commented on the notion dual in, suggesting the existence of a degree of sym in. Key episodes are first distributed in one half and again in the in....

As I said, the translation is poor. Is there any way to improve it?

To reproduce

I exported the opus-mt-es-en model with this command:

python -m transformers.onnx --model=Helsinki-NLP/opus-mt-es-en --feature=seq2seq-lm --atol=1e-04 D:\\Marian

The Python program that performs the Spanish-to-English translation is this:

import onnxruntime as rt
import numpy as np
from transformers import AutoTokenizer

session = rt.InferenceSession('D:\\Marian\\model.onnx')
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")
encoded_input = tokenizer("La novela Crimen y castigo está dividida en seis partes más el epílogo. Mucho se ha comentado de la noción de dualismo en la obra, sugiriéndose la existencia de cierto grado de simetría en ella. Los episodios clave se distribuyen primero en una mitad y luego de nuevo en la otra.")

input_ids = np.array(encoded_input.input_ids).astype(np.int64).reshape(1, -1)
attention_mask = np.array(encoded_input.attention_mask).astype(np.int64).reshape(1, -1)

size = 100
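# 65000 is this model's pad token id; Marian also uses it as the decoder start token.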
decoder_input_ids = np.full((1, size), 65000).astype(np.int64)
decoder_attention_mask = np.array([np.zeros(size)]).astype(np.int64)
decoder_attention_mask[0][0] = 1

model_input = {
    'input_ids': input_ids,
    'attention_mask': attention_mask,
    'decoder_input_ids': decoder_input_ids,
    'decoder_attention_mask': decoder_attention_mask,
}

for i in range(1, size):
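    # Rerun the full graph, take the argmax at position i, and unmask that position.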
    logits = session.run(None, model_input)[0]
    tokens = logits.argmax(axis=2)[0]
    model_input["decoder_input_ids"][0, i] = tokens[i]
    model_input["decoder_attention_mask"][0, i] = 1
    print(tokens[i])

predicted_sequence = model_input["decoder_input_ids"].reshape(-1)
print(model_input["decoder_input_ids"])
decoded_output = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
print("Result: ")
print(decoded_output)
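
For reference, the usual greedy loop grows decoder_input_ids one token at a time instead of pre-filling the whole buffer with pad tokens, and picks the argmax at the last position only. A minimal sketch of that variant, assuming the same exported model.onnx and that the export has a dynamic decoder sequence axis (untested here):

import onnxruntime as rt
import numpy as np
from transformers import AutoTokenizer

session = rt.InferenceSession('D:\\Marian\\model.onnx')
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")

encoded = tokenizer("La novela Crimen y castigo está dividida en seis partes más el epílogo.", return_tensors="np")
input_ids = encoded.input_ids.astype(np.int64)
attention_mask = encoded.attention_mask.astype(np.int64)

# Start from the decoder start token alone (65000, the pad token id for Marian).
decoder_input_ids = np.array([[65000]], dtype=np.int64)

for _ in range(100):
    logits = session.run(None, {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': decoder_input_ids,
        'decoder_attention_mask': np.ones_like(decoder_input_ids),
    })[0]
    next_token = int(logits[0, -1].argmax())  # greedy pick for the next token
    decoder_input_ids = np.concatenate(
        [decoder_input_ids, np.array([[next_token]], dtype=np.int64)], axis=1)
    if next_token == tokenizer.eos_token_id:  # stop at </s>
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))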

Urgency

No response

Platform

Windows

OS Version

Windows 10 64 bits

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@github-actions github-actions bot added the platform:windows issues related to the Windows platform label Sep 24, 2023
@faxu
Contributor

faxu commented Sep 25, 2023

Did you try running the original model to see if the translation is different? It would help to pinpoint whether this is a model issue, an ONNX conversion issue, or an onnxruntime issue.

@Zapotecatl
Author

Hi @faxu, thanks for your suggestion.

Yes, I ran this program:

from transformers import AutoTokenizer, MarianMTModel

model_name = "Helsinki-NLP/opus-mt-es-en"
model = MarianMTModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

sample_text = "La novela Crimen y castigo está dividida en seis partes más el epílogo. Mucho se ha comentado de la noción de dualismo en la obra, sugiriéndose la existencia de cierto grado de simetría en ella. Los episodios clave se distribuyen primero en una mitad y luego de nuevo en la otra."

batch = tokenizer([sample_text], return_tensors="pt")
generated_ids = model.generate(**batch)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

The output is good:
The novel Crime and Punishment is divided into six parts plus the epilogue. Much has been commented on the notion of dualism in the work, suggesting the existence of a certain degree of symmetry in it. Key episodes are first distributed in one half and then again in the other.

@wschin
Contributor

wschin commented Oct 2, 2023

I am not sure if the mapping from

for i in range(1, size):
    logits = session.run(None, model_input)[0]
    tokens = logits.argmax(axis=2)[0]
    model_input["decoder_input_ids"][0, i] = tokens[i]
    model_input["decoder_attention_mask"][0, i] = 1
    print(tokens[i])

to

generated_ids = model.generate(**batch)

is correct. The first token generated by onnxruntime matches the first token generated by huggingface, so the computation for the first token (i.e., the whole model's forward pass) seems correct in onnxruntime. That makes me think there could be a bug in the sequence-generating part. I can take a look at model.generate if you can share a link to its code. If you use the same for-loop to feed data token-by-token into the huggingface model, does it generate text similar to onnxruntime's output?
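
Replicated against the PyTorch model, that experiment might look something like this (a sketch that deliberately mirrors the onnxruntime loop above, including its indexing, so the two paths can be compared directly):

import torch
from transformers import AutoTokenizer, MarianMTModel

model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-es-en")
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")

# Same Spanish input as above (shortened here to the first sentence).
batch = tokenizer(["La novela Crimen y castigo está dividida en seis partes más el epílogo."], return_tensors="pt")

size = 100
decoder_input_ids = torch.full((1, size), 65000, dtype=torch.long)  # pad-filled buffer, as in the onnx loop
decoder_attention_mask = torch.zeros((1, size), dtype=torch.long)
decoder_attention_mask[0, 0] = 1

with torch.no_grad():
    for i in range(1, size):
        logits = model(input_ids=batch.input_ids,
                       attention_mask=batch.attention_mask,
                       decoder_input_ids=decoder_input_ids,
                       decoder_attention_mask=decoder_attention_mask).logits
        tokens = logits.argmax(dim=-1)[0]  # same greedy pick as the onnx loop
        decoder_input_ids[0, i] = tokens[i]
        decoder_attention_mask[0, i] = 1

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))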

@Zapotecatl
Author

@wschin
Thanks for your answer.

"If you use the same for-loop to feed data token-by-token into huggingface model, does it generate text similar to onnxruntime's output?"

I haven't tried it; I don't know how to do that with the Hugging Face model :(.

From what I have investigated, exporting with python -m transformers.onnx is now deprecated. The new way is:

optimum-cli export onnx --model Helsinki-NLP/opus-mt-es-en D:\\Marian

I wrote a program using the same logic with the exported encoder and decoder:
huggingface/transformers#26523

However, I also got a wrong translation.

It seems there is currently a problem with Marian models on Hugging Face.
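
For what it's worth, optimum also ships ORTModelForSeq2SeqLM, which wraps the exported encoder/decoder and reuses transformers' own generate() loop. A minimal sketch, assuming the directory produced by the optimum-cli command above:

from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

# Load the encoder/decoder ONNX files exported by optimum-cli.
model = ORTModelForSeq2SeqLM.from_pretrained("D:\\Marian")
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")

batch = tokenizer(["La novela Crimen y castigo está dividida en seis partes más el epílogo."], return_tensors="pt")
generated_ids = model.generate(**batch)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])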

@wschin
Contributor

wschin commented Oct 6, 2023

Since this doesn't sound like an onnxruntime bug, I will close it for now. Feel free to reopen if huggingface concludes ORT has a problem. Thank you!

@wschin wschin closed this as completed Oct 6, 2023