Skip to content

Commit

Permalink
Fix chat input typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Aug 13, 2024
1 parent bf32dc8 commit 5745ff3
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncGenerator, Iterator, Union, cast
from typing import Any, AsyncGenerator, Iterator, TypedDict, Union, cast

import google.generativeai as genai
import huggingface_hub
Expand All @@ -24,8 +24,7 @@
from anthropic import Anthropic
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer

import rl.llm.modal_utils
import rl.utils.click as click
Expand All @@ -49,7 +48,12 @@
from vllm.lora.request import LoRARequest


ChatInput = list[ChatCompletionMessageParam]
class ChatMessage(TypedDict):
role: str
content: str


ChatInput = list[ChatMessage]
InferenceInput = Union[str, ChatInput]


Expand Down Expand Up @@ -816,18 +820,22 @@ def get_inference_engine_cls(engine_name: str = "vllm") -> type[InferenceEngine]
return ENGINES[engine_name]


def get_inference_engine(llm_config: LLMConfig) -> InferenceEngine:
assert llm_config.engine_name in ENGINES, (
def get_inference_engine(
llm_config: LLMConfig, engine_name: str | None = None
) -> InferenceEngine:
engine_name = engine_name or llm_config.engine_name
assert engine_name in ENGINES, (
f"Engine {llm_config.engine_name} not found. "
f"Available engines: {', '.join(ENGINES.keys())}"
)
engine_cls = get_inference_engine_cls(llm_config.engine_name)
engine_cls = get_inference_engine_cls(engine_name)
return engine_cls(llm_config)


# A decorator which injects engine configuration click options, reads them, then constructs the engine
# and passes it to the decorated function.
def inject_llm_engine(defaults: dict[str, Any] | None):
"""A decorator which injects engine configuration click options, reads them,
then constructs the engine and passes it to the decorated function."""

def decorator(func):
@click.option(
"--engine-name",
Expand Down

0 comments on commit 5745ff3

Please sign in to comment.