Skip to content

Commit

Permalink
openai: better error messages; fix greedy matching (#2327)
Browse files Browse the repository at this point in the history
* better error message; fix greedy matching

* Update lm_eval/models/openai_completions.py

Co-authored-by: Hailey Schoelkopf <[email protected]>

* Update lm_eval/models/openai_completions.py

Co-authored-by: Hailey Schoelkopf <[email protected]>

* pre-commit

---------

Co-authored-by: Hailey Schoelkopf <[email protected]>
  • Loading branch information
baberabb and haileyschoelkopf authored Sep 26, 2024
1 parent 00f5537 commit 1bc6c93
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions lm_eval/models/openai_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def parse_logprobs(
for choice, ctxlen in zip(out["choices"], ctxlens):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens = choice["logprobs"]["token_logprobs"][ctxlen:-1]
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
is_greedy = True
for tok, top in zip(tokens, top_logprobs):
if tok != max(top, key=top.get):
for tok, top in zip(tokens_logprobs, top_logprobs):
if tok != max(top.values()):
is_greedy = False
break
res.append((logprobs, is_greedy))
Expand Down Expand Up @@ -190,14 +190,18 @@ def api_key(self):
key = os.environ.get("OPENAI_API_KEY", None)
if key is None:
raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable."
"API key not found. Please set the `OPENAI_API_KEY` environment variable."
)
return key

def loglikelihood(self, requests, **kwargs):
assert (
self.model != "gpt-3.5-turbo"
), "Loglikelihood is not supported for gpt-3.5-turbo"
self.model
in [
"babbage-002",
"davinci-002",
]
), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
return super().loglikelihood(requests, **kwargs)

def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
Expand Down Expand Up @@ -226,6 +230,11 @@ def api_key(self):
key = os.environ.get("OPENAI_API_KEY", None)
if key is None:
raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable."
"API key not found. Please set the `OPENAI_API_KEY` environment variable."
)
return key

def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
)

0 comments on commit 1bc6c93

Please sign in to comment.