Add logprob support to vllm engine
ProbablyFaiz committed Aug 26, 2024
1 parent 384474e commit 3dd285b
Showing 1 changed file with 33 additions and 7 deletions.
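
The change lets callers opt in to per-token log probabilities through the engine's feature flags and read them back from InferenceOutput.logprobs. The sketch below shows the intended call pattern as far as it can be inferred from the diff that follows; the LLMConfig fields other than features, the VLLMEngine constructor call, and the module path are assumptions, not code from this commit.

    # Hypothetical usage sketch. Only EngineFeature.RETURN_LOGPROBS,
    # batch_generate, and InferenceOutput.logprobs appear in the diff below;
    # the config fields and constructor call are assumed for illustration.
    from rl.llm.engines import EngineFeature, LLMConfig, VLLMEngine

    llm_config = LLMConfig(
        model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed field name
        features={EngineFeature.RETURN_LOGPROBS},
    )
    engine = VLLMEngine(llm_config)  # assumed constructor signature

    # Prompts are InferenceInput; plain strings are assumed to be accepted here.
    outputs = engine.batch_generate(["What is 2 + 2?"])
    for out in outputs:
        # out.logprobs: one {token_id: logprob} dict per generated token,
        # with up to 20 candidate tokens each (see _VLLM_NUM_LOGPROBS below).
        print(out.text, out.logprobs[0])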
rl/llm/engines.py (40 changes: 33 additions & 7 deletions)
@@ -30,6 +30,7 @@
 import modal
 import openai
 import vllm
+import vllm.sequence
 from transformers import PreTrainedTokenizer


@@ -599,6 +600,12 @@ async def batch_generate(
         return await tqdm.asyncio.tqdm.gather(*tasks)
 
 
+# There's little value in making this configurable, so we
+# just set it to 20, which is OpenAI's default.
+# https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_logprobs
+_VLLM_NUM_LOGPROBS = 20
+
+
 def _get_vllm_engine(
     llm_config: LLMConfig,
 ) -> tuple["VLLMEngine", dict]:
@@ -644,6 +651,8 @@ def _get_vllm_engine(
         frequency_penalty=llm_config.frequency_penalty,
         top_p=1.0,
     )
+    if EngineFeature.RETURN_LOGPROBS in llm_config.features:
+        sampling_params.logprobs = _VLLM_NUM_LOGPROBS
 
     lora_path = None
     if llm_config.lora_name_or_path:
@@ -720,7 +729,23 @@ def _get_vllm_kwargs(llm_config):
     return engine_args_kwargs
 
 
-@_register_engine("vllm", required_modules=("vllm",))
+def _parse_vllm_logprobs(
+    logprobs: list[dict[int, "vllm.sequence.Logprob"]] | None,
+) -> list[dict[int, float]]:
+    if logprobs is None:
+        raise ValueError("Expected logprobs in vLLM output but got None")
+
+    output = []
+    for logprob_dict in logprobs:
+        output.append({token_id: lp.logprob for token_id, lp in logprob_dict.items()})
+    return output
+
+
+@_register_engine(
+    "vllm",
+    required_modules=("vllm",),
+    supported_features=(EngineFeature.RETURN_LOGPROBS,),
+)
 class VLLMEngine(InferenceEngine):
     vllm: "vllm.LLMEngine"
     generate_kwargs: dict
@@ -754,13 +779,14 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
 
         inference_outputs = []
         for prompt, output in zip(prompts, vllm_outputs):
-            inference_outputs.append(
-                InferenceOutput(
-                    prompt=prompt,
-                    text=output.outputs[0].text,
-                    metadata={},
-                )
-            )
+            inf_output = InferenceOutput(
+                prompt=prompt,
+                text=output.outputs[0].text,
+                metadata={},
+            )
+            if EngineFeature.RETURN_LOGPROBS in self.llm_config.features:
+                inf_output.logprobs = _parse_vllm_logprobs(output.outputs[0].logprobs)
+            inference_outputs.append(inf_output)
         return inference_outputs
 
     def _get_vllm_outputs(self, prompts: list[str]):
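For reference, here is the shape of the data _parse_vllm_logprobs handles: when SamplingParams.logprobs is set, vLLM attaches, for each generated token, a dict mapping candidate token ids to vllm.sequence.Logprob objects, and the helper keeps only the float .logprob of each. A minimal sketch, assuming vLLM's Logprob dataclass fields (logprob, rank, decoded_token, as in mid-2024 releases) and that the helper is importable from rl.llm.engines:

    import vllm
    import vllm.sequence

    from rl.llm.engines import _parse_vllm_logprobs  # module path assumed

    # The request side: the engine sets sampling_params.logprobs to 20
    # (_VLLM_NUM_LOGPROBS) when RETURN_LOGPROBS is enabled, equivalent to:
    params = vllm.SamplingParams(max_tokens=8, logprobs=20)

    # The response side: one dict per generated token, keyed by candidate
    # token id (example ids and values below are made up).
    raw = [
        {
            9906: vllm.sequence.Logprob(logprob=-0.11, rank=1, decoded_token="Hello"),
            13347: vllm.sequence.Logprob(logprob=-2.53, rank=2, decoded_token="Hi"),
        },
    ]

    parsed = _parse_vllm_logprobs(raw)
    assert parsed == [{9906: -0.11, 13347: -2.53}]

The parsed list is what batch_generate now stores on InferenceOutput.logprobs for each prompt.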
