Add logprobs to gpt 3.5
ProbablyFaiz committed Sep 5, 2024
1 parent c10e0fd commit a3661b9
Showing 6 changed files with 4,586 additions and 20 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,4 +1,3 @@
 build/
 dist/
 *.spec
-uv.lock
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
 
 Then, on both your local machine and Sherlock (or either one, if you just want some features):
 ```
-uv tool install "rl[sherlock] @ git+https://github.com/ProbablyFaiz/rl.git@v0.8.1"
+uv tool install "rl[sherlock] @ git+https://github.com/ProbablyFaiz/rl.git@v0.9.0"
 ```
 
 ## Setup
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -5,7 +5,7 @@ authors = [
     { name = "Faiz Surani", email = "[email protected]" },
     { name = "Varun Magesh", email="[email protected]" }
 ]
-description = "A CLI for various RegLab / Sherlock tasks"
+description = "Utilities for various RegLab / Sherlock tasks"
 readme = "README.md"
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -25,8 +25,8 @@ dependencies = [
     "python-dotenv",
     "unidecode",
     "StrEnum",
-    "watchdog",
     # TODO: Maybe move into an extra?
+    "watchdog",
     "pydantic>=2",
     "typing-extensions",
     "legal-segmenter @ git+https://github.com/lexeme-dev/legal-segmenter@main",
@@ -42,9 +42,9 @@ sherlock = [
     "paramiko",
 ]
 local_llm = [
+    "rl[llm]",
     "transformers>=4.43",
     "torch",
-    "rl[llm]",
     "vllm~=0.5.3post1; platform_system == 'Linux'",
     "peft",
     "trl",
51 changes: 37 additions & 14 deletions rl/llm/engines/client.py
@@ -44,6 +44,9 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         )
 
 
+_NUM_LOGPROBS = 20
+
+
 class _OAIClientEngine(ClientEngine, ABC):
     BASE_URL: str
     API_KEY_NAME: str
@@ -82,16 +85,27 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
             completion_kwargs["max_tokens"] = self.llm_config.max_new_tokens
         if EngineFeature.JSON_OUTPUT in self.enabled_features:
             completion_kwargs["response_format"] = {"type": "json_object"}
+        if EngineFeature.RETURN_LOGPROBS in self.enabled_features:
+            completion_kwargs["logprobs"] = True
+            completion_kwargs["top_logprobs"] = _NUM_LOGPROBS
 
         response = self.client.chat.completions.create(**completion_kwargs)
-        return InferenceOutput(
+        choice = response.choices[0]
+
+        output = InferenceOutput(
             prompt=prompt,  # type: ignore
-            text=response.choices[0].message.content,
+            text=choice.message.content,
             metadata={
                 "model": self.llm_config.model_name_or_path,
                 "base_url": self.BASE_URL,
             },
         )
+        if EngineFeature.RETURN_LOGPROBS in self.enabled_features:
+            output.logprobs = [
+                {tkn.token: tkn.logprob for tkn in pos.top_logprobs}
+                for pos in choice.logprobs.content
+            ]
+        return output
 
 
 @register_engine(
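For reference, the structure `generate` parses here is the standard OpenAI chat completions logprobs payload. A minimal standalone sketch (not part of this commit; the model name and prompt are arbitrary examples) that produces and flattens the same shape:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # example model
    messages=[{"role": "user", "content": "Answer yes or no: is 7 prime?"}],
    logprobs=True,
    top_logprobs=20,  # 20 is the maximum the API accepts, matching _NUM_LOGPROBS
)

choice = response.choices[0]
# choice.logprobs.content has one entry per generated token; each entry's
# .top_logprobs lists the most likely candidate tokens at that position.
logprobs = [
    {tkn.token: tkn.logprob for tkn in pos.top_logprobs}
    for pos in choice.logprobs.content
]
```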
@@ -107,7 +120,7 @@ class TogetherEngine(_OAIClientEngine):
 @register_engine(
     "openai",
     required_modules=("openai",),
-    supported_features=(EngineFeature.JSON_OUTPUT,),
+    supported_features=(EngineFeature.JSON_OUTPUT, EngineFeature.RETURN_LOGPROBS),
 )
 class OpenAIEngine(_OAIClientEngine):
     BASE_URL = "https://api.openai.com/v1"
@@ -322,7 +335,7 @@ def _get_modal_app_name(self, model_name: str) -> str:
 @register_engine(
     "batch_openai",
     required_modules=("openai",),
-    supported_features=(EngineFeature.JSON_OUTPUT,),
+    supported_features=(EngineFeature.JSON_OUTPUT, EngineFeature.RETURN_LOGPROBS),
 )
 class BatchOpenAIEngine(InferenceEngine):
     def __init__(self, llm_config: LLMConfig):
@@ -343,6 +356,8 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
 
     def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
+        from openai.types.chat import ChatCompletion
+
         with tempfile.NamedTemporaryFile(
             mode="w+", suffix=".jsonl", delete=False
         ) as temp_file:
@@ -357,6 +372,9 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
                     body_kwargs["temperature"] = self.llm_config.temperature
                 if EngineFeature.JSON_OUTPUT in self.enabled_features:
                     body_kwargs["response_format"] = {"type": "json_object"}
+                if EngineFeature.RETURN_LOGPROBS in self.enabled_features:
+                    body_kwargs["logprobs"] = True
+                    body_kwargs["top_logprobs"] = _NUM_LOGPROBS
                 request = {
                     "custom_id": f"request-{i}",
                     "method": "POST",
@@ -396,17 +414,22 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
 
         outputs = []
         for result in results:
-            response = result["response"]["body"]
-            outputs.append(
-                InferenceOutput(
-                    prompt=prompts[int(result["custom_id"].split("-")[1])],
-                    text=response["choices"][0]["message"]["content"],
-                    metadata={
-                        "model": self.llm_config.model_name_or_path,
-                        "base_url": "https://api.openai.com/v1",
-                    },
-                )
-            )
+            parsed_result = ChatCompletion.model_validate(result["response"]["body"])
+            choice = parsed_result.choices[0]
+            output = InferenceOutput(
+                prompt=prompts[int(result["custom_id"].split("-")[1])],
+                text=choice.message.content,
+                metadata={
+                    "model": self.llm_config.model_name_or_path,
+                    "base_url": "https://api.openai.com/v1",
+                },
+            )
+            if EngineFeature.RETURN_LOGPROBS in self.enabled_features:
+                output.logprobs = [
+                    {tkn.token: tkn.logprob for tkn in pos.top_logprobs}
+                    for pos in choice.logprobs.content
+                ]
+            outputs.append(output)
 
         Path(temp_file_path).unlink()
 
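The values stored in `output.logprobs` are natural-log probabilities, so a downstream caller can recover per-token probabilities with `math.exp`. A small hypothetical helper, not part of this commit:

```python
import math

def top_token_probs(position: dict[str, float]) -> dict[str, float]:
    """Convert one position's {token: logprob} map into {token: probability}."""
    return {token: math.exp(logprob) for token, logprob in position.items()}

# e.g., candidate probabilities for the first generated token of the first output:
# top_token_probs(outputs[0].logprobs[0])
```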
5 changes: 4 additions & 1 deletion rl/llm/engines/core.py
@@ -29,7 +29,10 @@ class InferenceOutput(BaseModel):
     prompt: InferenceInput
     text: str
 
-    logprobs: list[dict[int, float]] | None = None
+    # TODO: OpenAI gives us logprobs in terms of the string representation of the token
+    # but vLLM gives us logprobs in terms of the token itself. We should probably
+    # standardize on one representation.
+    logprobs: list[dict[int | str, float]] | None = None
     metadata: dict[str, Any] = Field(default_factory=dict)
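One way to resolve the TODO above is to decode vLLM's integer token ids into strings so both engines expose `str`-keyed maps. A hypothetical sketch, assuming a HuggingFace tokenizer that matches the serving model:

```python
from transformers import AutoTokenizer

def normalize_logprobs(
    logprobs: list[dict[int | str, float]], tokenizer_name: str
) -> list[dict[str, float]]:
    # Decode integer token ids (vLLM) to strings; string keys (OpenAI) pass through.
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    return [
        {(tok.decode([k]) if isinstance(k, int) else k): lp for k, lp in pos.items()}
        for pos in logprobs
    ]
```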