
Early stopping in Hugging Face models #859

Closed
vymao opened this issue Jan 5, 2024 · 7 comments

Comments


vymao commented Jan 5, 2024

I am trying to enable early stopping for models derived from Hugging Face, specifically Whisper. I am curious whether ONNX models generated via Olive respect this setting in the generation config.

If I follow the conditional generation of Whisper on Hugging Face, I get the following:

>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> 
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> 
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> 
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
>>> input_features = inputs.input_features
>>> 
>>> generated_ids = model.generate(inputs=input_features)
/Users/victor/anaconda3/envs/transformers-v2/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:2035: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.
  warnings.warn(
>>> generated_ids
tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
          7062,   465, 21443,    13, 50256]])

As you can see, generation stops at the <|endoftext|> token, 50256.

However, if I run the Olive-optimized model in a C++ environment, I get something like:

<|notimestamps|> to see if this<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>

The <|endoftext|> token seems to be continuously generated. I would have expected the <|endoftext|> token to be generated only once, since it is pointless to continue generation after that. I am not sure whether the ONNX model already stops early and the extra <|endoftext|> tokens are just padding?

I have branched from the Whisper example provided in the examples folder and added the early stopping parameter to the dataset, so my inputs look like this:

import numpy as np

inputs = {
    "input_features": data["input_features"],  # preprocessed log-mel features from the dataset
    "max_length": np.asarray([200], dtype=np.int32),
    "min_length": np.asarray([0], dtype=np.int32),
    "num_beams": np.asarray([1], dtype=np.int32),
    "num_return_sequences": np.asarray([1], dtype=np.int32),
    "length_penalty": np.asarray([1.0], dtype=np.float32),
    "repetition_penalty": np.asarray([1.0], dtype=np.float32),
    "early_stopping": np.asarray([True], dtype=np.bool_),
}
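
For reference, a rough sketch of the equivalent inference call with the ONNX Runtime Python API (my actual runtime is C++; the model path here is a placeholder, and input names not declared by the graph are filtered out):

import onnxruntime as ort

# "whisper_olive.onnx" is a placeholder for the Olive-exported model path.
session = ort.InferenceSession("whisper_olive.onnx")

# Only feed names that the graph actually declares as inputs.
graph_inputs = {i.name for i in session.get_inputs()}
feeds = {name: value for name, value in inputs.items() if name in graph_inputs}

# The beam search output has shape (batch_size, num_return_sequences, max_length).
sequences = session.run(None, feeds)[0]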

Other information

  • OS: MacOS 13.5
  • Olive version: [e.g. 0.4.0 or main]
  • ONNXRuntime package and version: 1.16.2

jambayk commented Jan 5, 2024

Early stopping for the ort beam search op is an attribute, not an input: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftbeamsearch

In Olive, we already enable early stopping: https://github.com/microsoft/Olive/blob/main/olive/passes/onnx/insert_beam_search.py#L145.
So I don't know whether there is anything else that can be done from Olive or during inference. Maybe you could open an issue in the onnxruntime repository if you cannot simply ignore the special tokens at the end of the generation?

For the full Olive model that also has pre/post processing using ort-extensions, early stopping is already enabled: https://github.com/microsoft/Olive/blob/main/olive/passes/utils/whisper_prepost.py#L17.
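
To illustrate the attribute-vs-input distinction, here is a rough sketch of how a BeamSearch contrib node is assembled (illustrative names only, not the exact Olive code):

from onnx import helper

# Graph inputs such as max_length and num_beams are fed at inference time.
beam_search_node = helper.make_node(
    "BeamSearch",
    inputs=["input_features", "max_length", "min_length", "num_beams",
            "num_return_sequences", "length_penalty", "repetition_penalty"],
    outputs=["sequences"],
    domain="com.microsoft",
    name="BeamSearch_node",
)

# Attributes such as early_stopping and eos_token_id are baked in at export time.
beam_search_node.attribute.extend([
    helper.make_attribute("early_stopping", True),
    helper.make_attribute("eos_token_id", 50256),  # <|endoftext|> for whisper-tiny.en
])

So passing "early_stopping" in the inference feed has no effect; it has to be set when the node is created.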


jambayk commented Jan 15, 2024

@vymao do you have any follow-up questions or comments? Or can we close this issue?


edobobo commented Jan 16, 2024

I have the exact same problem: it seems the model does not take the end-of-sequence token id (eos_token_id) into account at all.


jambayk commented Feb 13, 2024

There is a new PR in the onnxruntime repo where the token ids will be set explicitly while creating the beam search node: microsoft/onnxruntime#19509

Previously, only a few token ids were set and the rest were inferred using hard-coded offsets. This did not work for all models since the vocabs are not always the same across the different variants of whisper.

Once the PR is checked in, I will test the changes using the ort-nightly build. Will keep you posted once I get to try it out.

Update: Please ignore the above. The issue is unrelated to the linked PR.

@kunal-vaishnavi

> I am not sure if the ONNX model already does this, and the extra <|endoftext|> tokens are just padding?

Yes, this is exactly what is happening. Once the model generates the EOS token id, early stopping is detected and the output is then automatically padded with the EOS token id until the max length is reached. This is done by design because the output shape is already predefined to the max length. The extra EOS token ids can easily be removed during post-processing.
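
For example, the trailing padding can be dropped with a small post-processing step like the following sketch (assuming eos_token_id is 50256, as for whisper-tiny.en), or simply by decoding with skip_special_tokens=True:

import numpy as np

def strip_eos_padding(token_ids: np.ndarray, eos_token_id: int = 50256) -> np.ndarray:
    # Keep tokens up to and including the first EOS; everything after it is padding.
    eos_positions = np.where(token_ids == eos_token_id)[0]
    if eos_positions.size == 0:
        return token_ids
    return token_ids[: eos_positions[0] + 1]

# Alternatively, the Hugging Face processor drops all special tokens
# (including the padded EOS ids) during decoding:
# processor.batch_decode(generated_ids, skip_special_tokens=True)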


jambayk commented Feb 13, 2024

@kunal-vaishnavi thanks for the clarification! I was not aware of the padding behavior since the final model in Olive uses a post-processor that strips the special tokens.

So the ort PR I linked above is unrelated to this issue, since we have always set the eos_token_id in the beam search node:

helper.make_attribute("eos_token_id", model_config["eos_token_id"]),

For some reason, I thought the offsets were calculated off another token id.


jambayk commented Apr 10, 2024

Closing this issue since early stopping is already enabled.

Please see the response from Kunal above for more clarification.

jambayk closed this as completed on Apr 10, 2024