Generate: fix logits processors doctests (huggingface#29718)
* fix norm

* fix logits processors doctests
gante authored Apr 2, 2024
1 parent 9b0a8ea commit 5080ab1
Showing 2 changed files with 28 additions and 56 deletions.
76 changes: 26 additions & 50 deletions src/transformers/generation/logits_process.py
@@ -261,8 +261,8 @@ class TemperatureLogitsWarper(LogitsWarper):
>>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
>>> outputs = model.generate(**inputs, **generate_kwargs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
- ['Hugging Face Company is a joint venture between GEO Group, one of',
- 'Hugging Face Company is not an exact science – but what we believe does']
+ ['Hugging Face Company is one of these companies that is going to take a',
+ "Hugging Face Company is a brand created by Brian A. O'Neil"]
>>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
>>> generate_kwargs["temperature"] = 0.0001
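For reference on the hunk above: `TemperatureLogitsWarper` divides the logits by the temperature before sampling, so low temperatures sharpen the distribution and values near zero approach greedy decoding. A minimal standalone sketch of that warp (illustrative only, not the library class itself):

```python
import torch

def temperature_warp(scores: torch.FloatTensor, temperature: float) -> torch.FloatTensor:
    # Temperatures below 1.0 sharpen the softmax distribution; values near 0 approach argmax.
    return scores / temperature

logits = torch.tensor([[2.0, 1.0, 0.5]])
print(torch.softmax(temperature_warp(logits, 0.5), dim=-1))  # more peaked than softmax(logits)
```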
@@ -419,7 +419,7 @@ class TopPLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
- >>> set_seed(0)
+ >>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
@@ -428,7 +428,9 @@ class TopPLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
- A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+ A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
+ <BLANKLINE>
+ <BLANKLINE>
>>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
>>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
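For reference, `top_p` (nucleus) filtering keeps the smallest set of tokens whose cumulative probability exceeds `p` and masks the rest. A simplified sketch of that step for a 2D batch of logits (it omits the `min_tokens_to_keep` safeguard the real warper has):

```python
import torch

def top_p_filter(scores: torch.FloatTensor, top_p: float, filter_value: float = float("-inf")) -> torch.FloatTensor:
    # Sort ascending so the least likely tokens come first
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Tokens whose cumulative mass stays below 1 - top_p fall outside the nucleus
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```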
@@ -483,7 +485,7 @@ class TopKLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
- >>> set_seed(0)
+ >>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
@@ -492,7 +494,7 @@ class TopKLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
- A sequence: A, B, C, D, G, H, I. A, M
+ A sequence: A, B, C, D, E — S — O, P — R
>>> # With `top_k` sampling, the output gets restricted the k most likely tokens.
>>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
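Similarly, `top_k` filtering keeps only the k highest-scoring tokens at each step; a minimal sketch (again without the `min_tokens_to_keep` handling):

```python
import torch

def top_k_filter(scores: torch.FloatTensor, top_k: int, filter_value: float = float("-inf")) -> torch.FloatTensor:
    top_k = min(top_k, scores.size(-1))  # guard against k exceeding the vocabulary size
    # Mask every logit smaller than the k-th largest one
    indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
    return scores.masked_fill(indices_to_remove, filter_value)
```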
@@ -624,7 +626,7 @@ class EpsilonLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
- >>> set_seed(0)
+ >>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
@@ -633,7 +635,9 @@ class EpsilonLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
- A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+ A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
+ <BLANKLINE>
+ <BLANKLINE>
>>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
>>> # Top P sampling, which restricts tokens based on their cumulative probability.
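Epsilon sampling uses an absolute probability cutoff instead of a cumulative one: every token whose probability falls below `epsilon` is removed from the pool. A simplified sketch (the real warper also keeps at least `min_tokens_to_keep` tokens):

```python
import torch

def epsilon_filter(scores: torch.FloatTensor, epsilon: float, filter_value: float = float("-inf")) -> torch.FloatTensor:
    probabilities = scores.softmax(dim=-1)
    # Drop any token whose probability is below the fixed cutoff
    indices_to_remove = probabilities < epsilon
    return scores.masked_fill(indices_to_remove, filter_value)
```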
@@ -701,7 +705,7 @@ class EtaLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
- >>> set_seed(0)
+ >>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
@@ -710,7 +714,9 @@ class EtaLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
- A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+ A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
+ <BLANKLINE>
+ <BLANKLINE>
>>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
>>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
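Eta sampling makes that cutoff entropy-dependent, roughly `eta = min(epsilon, sqrt(epsilon) * exp(-entropy))`, so flatter distributions get a lower threshold. A simplified sketch under that reading (the real warper additionally enforces `min_tokens_to_keep`):

```python
import torch

def eta_filter(scores: torch.FloatTensor, epsilon: float, filter_value: float = float("-inf")) -> torch.FloatTensor:
    probabilities = scores.softmax(dim=-1)
    entropy = torch.distributions.Categorical(logits=scores).entropy()
    # The cutoff shrinks as the distribution gets flatter (higher entropy)
    eta = torch.min(torch.tensor(epsilon), torch.sqrt(torch.tensor(epsilon)) * torch.exp(-entropy))[..., None]
    indices_to_remove = probabilities < eta
    return scores.masked_fill(indices_to_remove, filter_value)
```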
@@ -1211,16 +1217,16 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
>>> # We can contrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
>>> # For instance, we can force an entire entity to be generated when its beginning is detected.
- >>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
+ >>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
... '''
... Attempts to generate 'Bob Marley' when 'Bob' is detected.
... In this case, `batch_id` is not used, but you can set rules for each batch member.
... '''
... if input_ids[-1] == entity[0]:
- ... return entity[1]
+ ... return [entity[1].item()]
... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
- ... return entity[2]
+ ... return [entity[2].item()]
... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens
>>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
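Behind `prefix_allowed_tokens_fn`, the prefix-constrained processor converts the returned token ids into an additive mask over the logits; a rough single-batch sketch of that idea (not the exact beam-aware implementation):

```python
import math
import torch

def apply_allowed_tokens(scores: torch.FloatTensor, allowed_token_ids: list) -> torch.FloatTensor:
    # Forbid every token, then re-open only the ids the callback allowed
    mask = torch.full_like(scores, -math.inf)
    mask[:, allowed_token_ids] = 0
    return scores + mask
```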
@@ -1618,13 +1624,13 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
>>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
- >>> print(torch.sum(torch.exp(outputs.scores[-1])))
- tensor(816.3250)
+ >>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
+ False
>>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
>>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
- >>> print(torch.sum(torch.exp(outputs.scores[-1])))
- tensor(1.0000)
+ >>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
+ True
```
"""

@@ -1655,7 +1661,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate and EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
- >>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
+ >>> print(outputs.scores[0][0, 50256])
tensor(-inf)
>>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
tensor(29.9010)
@@ -1664,7 +1670,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> outputs = model.generate(
... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
... )
- >>> print(outputs.scores[1][0, 50256])
+ >>> print(outputs.scores[0][0, 50256])
tensor(11.2027)
```
"""
@@ -1713,7 +1719,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
>>> # If we disable `suppress_tokens`, we can generate it.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
>>> print(outputs.scores[1][0, 1])
- tensor(5.7738)
+ tensor(6.0678)
```
"""

@@ -1735,36 +1741,6 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
indices that will be forced before generation. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
- Examples:
- ```python
- >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
- >>> from datasets import load_dataset
- >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
- >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
- >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
- >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
- >>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
- >>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
- >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
- >>> print(
- ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
- ... )
- True
- >>> print(outputs.scores[0][0, 50362])
- tensor(0.)
- >>> # If we disable `forced_decoder_ids`, we stop seeing that effect
- >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
- >>> print(
- ... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
- ... )
- False
- >>> print(outputs.scores[0][0, 50362])
- tensor(19.3140)
- ```
"""

def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
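As the description above says, the force-tokens processor masks everything except the forced token at the mapped generation indices; a rough sketch of that behavior (illustrative, not the exact class):

```python
import torch

def force_token(scores: torch.FloatTensor, generation_idx: int, force_token_map: dict) -> torch.FloatTensor:
    forced_id = force_token_map.get(generation_idx)
    if forced_id is not None:
        # Everything except the forced token is set to -inf, so it is always picked
        scores[:, :] = -float("inf")
        scores[:, forced_id] = 0
    return scores
```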
8 changes: 2 additions & 6 deletions src/transformers/models/whisper/generation_whisper.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import copy
import math
import warnings
import zlib
@@ -474,11 +473,8 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
- # 1. copy generation config
- if generation_config is None:
-     generation_config = copy.deepcopy(self.generation_config)
- else:
-     generation_config = copy.deepcopy(generation_config)
+ # 1. prepare generation config
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)

# 2. set global generate variables
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
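The replacement above hands the copying work to `self._prepare_generation_config`, which returns both the prepared config and the remaining kwargs. A hedged sketch of roughly what such a step does (an approximation, not the actual helper):

```python
import copy

def prepare_generation_config(passed_config, default_config, **kwargs):
    # Copy whichever config applies, then overlay generation-related kwargs onto the copy;
    # anything the config does not know about is handed back for the model call.
    generation_config = copy.deepcopy(passed_config if passed_config is not None else default_config)
    unused_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(generation_config, key):
            setattr(generation_config, key, value)  # e.g. max_new_tokens, temperature
        else:
            unused_kwargs[key] = value
    return generation_config, unused_kwargs
```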
