Skip to content

Commit

Permalink
Fix instruct models UI issue (#78)
Browse files Browse the repository at this point in the history
* feat(tgi): allow top_k = 0 and top_p = 1 when do_sample = True

This might not be the most elegant solution, but it will allow the
server to keep working when the web ui gives a request with these
parameters for instruct models.

* chore: update version to 0.1.4
  • Loading branch information
tengomucho authored Jul 23, 2024
1 parent da2d1ad commit 7f5b0cc
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 7 deletions.
5 changes: 0 additions & 5 deletions optimum/tpu/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper"
Returns:
a `FusedLogitsWarper` or None if neither top-k nor top-p are configured.
"""
if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0:
raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.")
return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p)

def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
Expand All @@ -59,9 +57,6 @@ def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.
do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1]
do_top_p = self.top_p < 1.0 and self.top_p > 0.0

if not do_top_k and not do_top_p:
return logits, None

if do_top_k:
sorted_logits, sorted_indices = torch.topk(logits, self.top_k)
else:
Expand Down
2 changes: 1 addition & 1 deletion optimum/tpu/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
from pkg_resources import parse_version


__version__ = "0.1.3"
__version__ = "0.1.4"
VERSION = parse_version(__version__)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pkg_resources import parse_version


__version__ = "0.1.3"
__version__ = "0.1.4"
VERSION = parse_version(__version__)

0 comments on commit 7f5b0cc

Please sign in to comment.