Commit 9421503

Upgrade vllm version

ProbablyFaiz committed Jul 11, 2024
1 parent 41e65ba commit 9421503

Showing 2 changed files with 23 additions and 94 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ sherlock = [
]
local_llm = [
"rl[llm]",
"vllm==0.4.1; platform_system == 'Linux'",
"vllm==0.5.1; platform_system == 'Linux'",
"peft",
"trl",
"datasets",
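The environment marker on the changed line keeps the vllm pin Linux-only, matching vLLM's official wheel support. A hedged install sketch (the `rl` distribution name is inferred from the `rl[llm]` self-reference above, not stated elsewhere in this diff):

    pip install "rl[local_llm]"
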
115 changes: 22 additions & 93 deletions rl/llm/engines.py
@@ -65,6 +65,17 @@ def _apply_chat_template(tokenizer, messages):
)


+ENGINES = {}
+
+
+def _register_engine(name: str):
+    def decorator(cls):
+        ENGINES[name] = cls
+        return cls
+
+    return decorator


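An aside on the registry added above: `_register_engine` is a class decorator that records each engine class under a string name in the module-level `ENGINES` dict, replacing the hand-built `ENGINES` dict deleted at the bottom of this file. A minimal, self-contained sketch of the same pattern (the class and names here are illustrative, not from this repository):

    REGISTRY: dict[str, type] = {}

    def register(name: str):
        def decorator(cls):
            REGISTRY[name] = cls  # record the class under its string name
            return cls  # return the class unchanged so the definition still binds
        return decorator

    @register("echo")
    class EchoEngine:
        def generate(self, prompt: str) -> str:
            return prompt

    # Lookup by name, as get_inference_engine_cls does at the end of this file.
    assert REGISTRY["echo"]().generate("hi") == "hi"
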
class InferenceEngine(ABC):
NAME: str
llm_config: LLMConfig
Expand Down Expand Up @@ -106,8 +117,8 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
_RESPONSE_CANARY = "### Response template begins now, delete this line. ###"


@_register_engine("manual_edit")
class ManualEditEngine(InferenceEngine):
NAME = "manual_edit"
_EDITOR = os.environ.get("EDITOR", "vim")

tokenizer: PreTrainedTokenizer
Expand Down Expand Up @@ -182,7 +193,6 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:


class OpenAIClientEngine(InferenceEngine, ABC):
-    NAME: str = "openai"
    BASE_URL: str = "https://api.openai.com/v1"
    API_KEY_NAME: str = "OPENAI_API_KEY"
    llm_config: LLMConfig
Expand Down Expand Up @@ -225,27 +235,26 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
)


@_register_engine("together")
class TogetherEngine(OpenAIClientEngine):
NAME = "together"
BASE_URL = "https://api.together.xyz/v1"
API_KEY_NAME = "TOGETHER_API_KEY"


@_register_engine("openai")
class OpenAIEngine(OpenAIClientEngine):
NAME = "openai"
BASE_URL = "https://api.openai.com/v1"
API_KEY_NAME = "OPENAI_API_KEY"


@_register_engine("groq")
class GroqEngine(OpenAIClientEngine):
NAME = "groq"
BASE_URL = "https://api.groq.com/openai/v1"
API_KEY_NAME = "GROQ_API_KEY"


@_register_engine("gemini")
class GeminiEngine(InferenceEngine):
NAME: str = "gemini"

def __init__(self, llm_config: LLMConfig):
super().__init__(llm_config)

Expand Down Expand Up @@ -318,8 +327,8 @@ def _convert_openai_to_gemini(
)


@_register_engine("anthropic")
class AnthropicEngine(ClientEngine):
NAME = "anthropic"
BASE_URL = "https://api.anthropic.com/v1"
API_KEY_NAME = "ANTHROPIC_API_KEY"

Expand Down Expand Up @@ -365,8 +374,8 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
)


@_register_engine("modal")
class ModalEngine(InferenceEngine):
NAME = "modal"
app_name: str
modal_call: modal.Function
tokenizer: PreTrainedTokenizer
Expand Down Expand Up @@ -531,7 +540,7 @@ def _get_vllm_engine(
lora_path = lora_path.resolve()

    generate_kwargs: dict[str, Any] = {
-        "sampling_params": sampling_params,
+        "params": sampling_params,
    }
    if lora_path is not None:
        generate_kwargs["lora_request"] = LoRARequest(
Expand Down Expand Up @@ -571,9 +580,8 @@ def _get_vllm_kwargs(llm_config):
return engine_args_kwargs


@_register_engine("vllm")
class VLLMEngine(InferenceEngine):
NAME = "vllm"

vllm: "LLMEngine"
generate_kwargs: dict

Expand Down Expand Up @@ -621,7 +629,7 @@ def _get_vllm_outputs(self, prompts: list[str]):
        for i, prompt in enumerate(prompts):
            self.vllm.add_request(
                request_id=str(f"{curr_uuid}_{i}"),
-                prompt=prompt,
+                inputs=prompt,
                **self.generate_kwargs,
            )

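For context on the rename in this hunk and the `params` rename in the `_get_vllm_engine` hunk above: vLLM 0.5.x changed `LLMEngine.add_request` to take the prompt as `inputs` and the sampling parameters as `params` (previously `prompt` and `sampling_params` in 0.4.x). A minimal, hedged sketch against vLLM 0.5.1; the model name is illustrative:

    from vllm import EngineArgs, LLMEngine, SamplingParams

    # Build a synchronous engine; EngineArgs mirrors the engine CLI options.
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

    engine.add_request(
        request_id="0",
        inputs="The capital of France is",  # was `prompt=` in vllm 0.4.x
        params=SamplingParams(max_tokens=16),  # was `sampling_params=` in 0.4.x
    )

    # Step the engine until the request finishes, then print the completion.
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)
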
@@ -644,9 +652,8 @@ def _get_vllm_outputs(self, prompts: list[str]):
return [output[1] for output in vllm_outputs]


@_register_engine("server_vllm")
class WorkerVLLMEngine(InferenceEngine):
NAME = "server_vllm"

client: openai.OpenAI

def __init__(self, llm_config: LLMConfig):
@@ -762,84 +769,6 @@ def _find_free_port(self):
return s.getsockname()[1]


-class AsyncVLLMEngine(AsyncInferenceEngine):
-    NAME = "vllm_async"
-
-    vllm: "AsyncLLMEngine"
-    generate_kwargs: dict
-
-    def __init__(self, llm_config: LLMConfig):
-        super().__init__(llm_config)
-
-    def __enter__(self):
-        self.vllm, self.generate_kwargs = _get_vllm_engine(
-            self.llm_config, use_async=True
-        )
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.llm_config.tokenizer_name_or_path
-        )
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        del self.vllm
-
-    async def stream(
-        self, prompt: InferenceInput
-    ) -> AsyncGenerator[InferenceOutput, bool]:  # type: ignore
-        formatted_prompt: str = (
-            prompt
-            if isinstance(prompt, str)
-            else _apply_chat_template(self.tokenizer, prompt)
-        )
-        curr_uuid = str(uuid.uuid4())
-        result_generator = self.vllm.generate(
-            formatted_prompt,
-            **self.generate_kwargs,
-            request_id=curr_uuid,
-        )
-        async for request_output in result_generator:
-            curr_res = self._wrap_output(request_output)
-            abort = yield curr_res
-            if abort:
-                await self.vllm.abort(curr_uuid)
-                break
-
-    async def generate(self, prompt: InferenceInput) -> InferenceOutput:
-        if isinstance(prompt, list):
-            prompt = _apply_chat_template(self.tokenizer, prompt)
-        res = None
-        async for res in self.stream(prompt):  # type: ignore
-            pass
-        return res
-
-    def _wrap_output(self, req_output) -> InferenceOutput:
-        return InferenceOutput(
-            prompt=req_output.prompt,
-            text=req_output.outputs[0].text,
-            metadata={
-                "model_name_or_path": self.llm_config.model_name_or_path,
-                "lora_name_or_path": self.llm_config.lora_name_or_path,
-            },
-        )
-
-
-ENGINES = {
-    e.NAME: e
-    for e in (
-        VLLMEngine,
-        AsyncVLLMEngine,
-        WorkerVLLMEngine,
-        OpenAIEngine,
-        TogetherEngine,
-        GroqEngine,
-        AnthropicEngine,
-        ModalEngine,
-        GeminiEngine,
-        ManualEditEngine,
-    )
-}


def get_inference_engine_cls(engine_name: str = "vllm") -> type[InferenceEngine]:
    return ENGINES[engine_name]

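A hedged usage sketch of the lookup above, using the "openai" name registered earlier in this file (the config value and the prompt shape are illustrative):

    engine_cls = get_inference_engine_cls("openai")  # -> OpenAIEngine
    engine = engine_cls(llm_config)  # llm_config: LLMConfig, as defined in this module
    result = engine.generate([{"role": "user", "content": "Hello!"}])
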
