Update model configurations and enhance LLM chat handling
- Update `common-llms-sealion` image to version 7 and adjust resource limits
- Add new deployment `common-llms-sarvam-2b` with specific resource limits
- Improve `llm_chat` function to handle fallback chat templates
- Modify `LLMChatInputs` to accept text inputs as a list or string
- Ensure stripping of stop strings and eos tokens from generated output
- Refactor `load_pipe` to explicitly return `TextGenerationPipeline`
devxpy committed Aug 13, 2024
1 parent d775038 commit 0773b08
Showing 3 changed files with 46 additions and 15 deletions.
chart/model-values.yaml (17 additions, 4 deletions)

@@ -365,16 +365,29 @@ deployments:
       ESRGAN_MODEL_IDS: |-
         RealESRGAN_x2plus
-  - name: "common-llms-sealion"
-    image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:5"
+  - name: "common-llms-sealion-v2"
+    image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:7"
+    limits_gpu: "30Gi"
     limits:
-      memory: "45Gi"
+      memory: "80Gi" # (220 / 80) * 30
       cpu: "2"
     env:
       IMPORTS: |-
         common.llms
       LLM_MODEL_IDS: |-
-        aisingapore/sea-lion-7b-instruct
+        aisingapore/llama3-8b-cpt-sea-lionv2-instruct
+
+  - name: "common-llms-sarvam-2b"
+    image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:7"
+    limits_gpu: "6Gi"
+    limits:
+      memory: "16Gi" # (220 / 80) * 6
+      cpu: "2"
+    env:
+      IMPORTS: |-
+        common.llms
+      LLM_MODEL_IDS: |-
+        sarvamai/sarvam-2b-v0.5
 ## Dependencies
 nfs-server-provisioner:
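The memory comments encode a simple sizing rule: scale each deployment's host-RAM limit in proportion to its GPU-memory slice. A minimal sketch of that arithmetic, assuming (per the "(220 / 80) * N" comments) nodes that pair roughly 220Gi of RAM with an 80Gi GPU:

# Sketch of the "(220 / 80) * N" sizing rule from the comments above; the
# 220Gi-RAM / 80Gi-GPU node shape is an assumption read off those comments.
def host_memory_limit_gi(limits_gpu_gi: float, node_ram_gi: float = 220, node_gpu_gi: float = 80) -> float:
    # scale host RAM in proportion to the deployment's GPU slice
    return node_ram_gi / node_gpu_gi * limits_gpu_gi

print(host_memory_limit_gi(30))  # 82.5, trimmed to "80Gi" for common-llms-sealion-v2
print(host_memory_limit_gi(6))   # 16.5, trimmed to "16Gi" for common-llms-sarvam-2b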
common/llms.py (28 additions, 10 deletions)

@@ -6,6 +6,7 @@
 import transformers
 from pydantic import BaseModel
 from transformers import AutoTokenizer
+from transformers.models.auto.tokenization_auto import get_tokenizer_config
 
 import gooey_gpu
 from celeryconfig import app, setup_queues
@@ -14,10 +15,11 @@
 class PipelineInfo(BaseModel):
     model_id: str
     seed: int = None
+    fallback_chat_template_from: str | None
 
 
 class LLMChatInputs(BaseModel):
-    messages: typing.List[dict]
+    text_inputs: typing.List[dict] | str
     max_new_tokens: int
     stop_strings: typing.Optional[typing.List[str]]
     temperature: float = 1
@@ -33,8 +35,15 @@ class LLMChatOutput(BaseModel):
 @gooey_gpu.endpoint
 def llm_chat(pipeline: PipelineInfo, inputs: LLMChatInputs) -> LLMChatOutput:
     pipe = load_pipe(pipeline.model_id)
-    return pipe(
-        inputs.messages,
+
+    if pipeline.fallback_chat_template_from and not pipe.tokenizer.chat_template:
+        # if the tokenizer does not have a chat template, use the provided fallback
+        config = get_tokenizer_config(pipeline.fallback_chat_template_from)
+        pipe.tokenizer.chat_template = config.get("chat_template")
+
+    # for a list of parameters, see https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
+    ret = pipe(
+        inputs.text_inputs,
         max_new_tokens=inputs.max_new_tokens,
         stop_strings=inputs.stop_strings,
         temperature=inputs.temperature,
@@ -44,17 +53,26 @@ def llm_chat(pipeline: PipelineInfo, inputs: LLMChatInputs) -> LLMChatOutput:
         eos_token_id=pipe.tokenizer.eos_token_id,
     )[0]
 
+    # strip stop strings & eos token from final output
+    for s in (inputs.stop_strings or []) + [pipe.tokenizer.eos_token]:
+        ret["generated_text"] = ret["generated_text"].split(s, 1)[0]
+
+    return ret
+
 
 @lru_cache
-def load_pipe(model_id: str):
+def load_pipe(model_id: str) -> transformers.TextGenerationPipeline:
     print(f"Loading llm model {model_id!r}...")
     # this should return a TextGenerationPipeline
-    pipe = transformers.pipeline(
-        "text-generation",
-        model=model_id,
-        device=gooey_gpu.DEVICE_ID,
-        torch_dtype=torch.float16,
-        trust_remote_code=True,
+    pipe = typing.cast(
+        transformers.TextGenerationPipeline,
+        transformers.pipeline(
+            "text-generation",
+            model=model_id,
+            device=gooey_gpu.DEVICE_ID,
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+        ),
     )
     if not pipe.tokenizer:
         pipe.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
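Why the fallback matters: a base checkpoint such as sarvamai/sarvam-2b-v0.5 ships without a chat_template in its tokenizer config, so chat-formatted text_inputs would otherwise fail to render. A sketch of how a caller might exercise the new field, invoking the task function directly (in the deployed worker this presumably runs through the celery queues); the fallback model id is an illustrative assumption, since any repo whose tokenizer_config.json defines a chat_template would do:

# Illustrative only: the donor model for the fallback template is an assumption.
pipeline = PipelineInfo(
    model_id="sarvamai/sarvam-2b-v0.5",
    fallback_chat_template_from="HuggingFaceH4/zephyr-7b-beta",  # hypothetical donor
)
inputs = LLMChatInputs(
    text_inputs=[{"role": "user", "content": "Namaste! Introduce yourself."}],
    max_new_tokens=128,
    stop_strings=None,
)
ret = llm_chat(pipeline=pipeline, inputs=inputs)
print(ret["generated_text"])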
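The output-stripping loop leans on str.split(s, 1)[0], which keeps everything before the first occurrence of a stop marker and leaves the text untouched when the marker is absent. A self-contained illustration with invented strings:

# Invented strings; mirrors the stripping loop added to llm_chat above.
stop_strings = ["Q:"]
eos_token = "</s>"
generated_text = "Paris is the capital of France.</s>Q: next question"
for s in (stop_strings or []) + [eos_token]:
    # keep only the text before the first stop marker, if any
    generated_text = generated_text.split(s, 1)[0]
print(generated_text)  # -> "Paris is the capital of France."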
scripts/run-dev.sh (1 addition, 1 deletion)

@@ -74,7 +74,7 @@ docker run \
   RealESRGAN_x2plus
 "\
 -e LLM_MODEL_IDS="
-  aisingapore/sea-lion-7b-instruct
+  aisingapore/llama3-8b-cpt-sea-lionv2-instruct
 "\
 -e C_FORCE_ROOT=1 \
 -e BROKER_URL=${BROKER_URL:-"amqp://"} \
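Both the chart and this script feed model ids through a newline-separated LLM_MODEL_IDS variable. The consuming side lives in common/llms.py's queue setup, which this diff doesn't touch; a plausible sketch of splitting such a block, assuming that shape:

import os

# Hypothetical parsing of the newline-separated LLM_MODEL_IDS block; the real
# handling is in common/llms.py and is not shown on this page.
model_ids = [
    line.strip()
    for line in os.environ.get("LLM_MODEL_IDS", "").splitlines()
    if line.strip()
]
print(model_ids)  # e.g. ["aisingapore/llama3-8b-cpt-sea-lionv2-instruct"]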
