Refactor self-hosted LLM handling; add new models
Remove the default system message; deprecate SEA-LION-7B-Instruct and add the Llama3 8B CPT SEA-LIONv2 Instruct and Sarvam 2B models. Rename _run_self_hosted_chat to _run_self_hosted_llm and update it to accept plain text input as well as chat messages. Refactor _run_chat_model and _run_text_model to use the updated function.
devxpy committed Aug 13, 2024
1 parent 2bc7859 commit d424854
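
In short, both entry points now funnel into a single helper. A minimal sketch of the two call shapes after this commit (illustrative only: argument values are made up, and the helper is the private function changed in the diff below):

# Text path (_run_text_model): the raw prompt string goes straight in,
# and the string result is wrapped in a one-element list.
texts = [
    _run_self_hosted_llm(
        model="sarvamai/sarvam-2b-v0.5",
        text_inputs="Write a haiku about monsoon rain.",
        max_tokens=128,
        temperature=0.7,
        avoid_repetition=False,
        stop=None,
    )
]

# Chat path (_run_chat_model): the message list goes in, and the string
# result is wrapped back into an assistant entry.
reply = [
    {
        "role": "assistant",  # CHATML_ROLE_ASSISTANT in the codebase
        "content": _run_self_hosted_llm(
            model="aisingapore/llama3-8b-cpt-sea-lionv2-instruct",
            text_inputs=[{"role": "user", "content": "Hello!"}],
            max_tokens=128,
            temperature=0.7,
            avoid_repetition=False,
            stop=None,
        ),
    },
]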
Showing 2 changed files with 73 additions and 25 deletions.
78 changes: 53 additions & 25 deletions daras_ai_v2/language_model.py
@@ -32,7 +32,6 @@
 )
 from functions.recipe_functions import LLMTools
 
-DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."
 DEFAULT_JSON_PROMPT = (
     "Please respond directly in JSON format. "
     "Don't output markdown or HTML, instead print the JSON object directly without formatting."
@@ -308,11 +307,26 @@ class LargeLanguageModels(Enum):
     )
 
     sea_lion_7b_instruct = LLMSpec(
-        label="SEA-LION-7B-Instruct (aisingapore)",
+        label="SEA-LION-7B-Instruct [Deprecated] (aisingapore)",
         model_id="aisingapore/sea-lion-7b-instruct",
         llm_api=LLMApis.self_hosted,
         context_window=2048,
         price=1,
+        is_deprecated=True,
     )
+    llama3_8b_cpt_sea_lion_v2_instruct = LLMSpec(
+        label="Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)",
+        model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct",
+        llm_api=LLMApis.self_hosted,
+        context_window=8192,
+        price=1,
+    )
+    sarvam_2b = LLMSpec(
+        label="Sarvam 2B (sarvamai)",
+        model_id="sarvamai/sarvam-2b-v0.5",
+        llm_api=LLMApis.self_hosted,
+        context_window=2048,
+        price=1,
+    )
 
     # https://platform.openai.com/docs/models/gpt-3
@@ -452,7 +466,6 @@ def run_language_model(
     if prompt and not messages:
         # convert text prompt to chat messages
         messages = [
-            format_chat_entry(role=CHATML_ROLE_SYSTEM, content=DEFAULT_SYSTEM_MSG),
             format_chat_entry(role=CHATML_ROLE_USER, content=prompt),
         ]
         if not model.is_vision_model:
@@ -599,6 +612,17 @@ def _run_text_model(
                 temperature=temperature,
                 stop=stop,
             )
+        case LLMApis.self_hosted:
+            return [
+                _run_self_hosted_llm(
+                    model=model,
+                    text_inputs=prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    avoid_repetition=avoid_repetition,
+                    stop=stop,
+                )
+            ]
         case _:
             raise UserError(f"Unsupported text api: {api}")

@@ -674,14 +698,19 @@ def _run_chat_model(
                 stop=stop,
             )
         case LLMApis.self_hosted:
-            return _run_self_hosted_chat(
-                model=model,
-                messages=messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                avoid_repetition=avoid_repetition,
-                stop=stop,
-            )
+            return [
+                {
+                    "role": CHATML_ROLE_ASSISTANT,
+                    "content": _run_self_hosted_llm(
+                        model=model,
+                        text_inputs=messages,
+                        max_tokens=max_tokens,
+                        temperature=temperature,
+                        avoid_repetition=avoid_repetition,
+                        stop=stop,
+                    ),
+                },
+            ]
         # case LLMApis.together:
         #     if tools:
         #         raise UserError("Only OpenAI chat models support Tools")
@@ -697,32 +726,36 @@ def _run_chat_model(
             raise UserError(f"Unsupported chat api: {api}")
 
 
-def _run_self_hosted_chat(
+def _run_self_hosted_llm(
     *,
     model: str,
-    messages: list[ConversationEntry],
+    text_inputs: list[ConversationEntry] | str,
     max_tokens: int,
     temperature: float,
    avoid_repetition: bool,
     stop: list[str] | None,
-) -> list[dict]:
+) -> str:
     from usage_costs.cost_utils import record_cost_auto
     from usage_costs.models import ModelSku
 
     # sea lion doesnt support system prompt
-    if model == LargeLanguageModels.sea_lion_7b_instruct.model_id:
-        for i, entry in enumerate(messages):
+    if (
+        not isinstance(text_inputs, str)
+        and model == LargeLanguageModels.sea_lion_7b_instruct.model_id
+    ):
+        for i, entry in enumerate(text_inputs):
             if entry["role"] == CHATML_ROLE_SYSTEM:
-                messages[i]["role"] = CHATML_ROLE_USER
-                messages.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
+                text_inputs[i]["role"] = CHATML_ROLE_USER
+                text_inputs.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
 
     ret = call_celery_task(
         "llm.chat",
         pipeline=dict(
             model_id=model,
             fallback_chat_template_from="meta-llama/Llama-2-7b-chat-hf",
         ),
         inputs=dict(
-            messages=messages,
+            text_inputs=text_inputs,
             max_new_tokens=max_tokens,
             stop_strings=stop,
             temperature=temperature,
Expand All @@ -742,12 +775,7 @@ def _run_self_hosted_chat(
quantity=usage["completion_tokens"],
)

return [
{
"role": CHATML_ROLE_ASSISTANT,
"content": ret["generated_text"],
}
]
return ret["generated_text"]


def _run_anthropic_chat(
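A hedged usage sketch for the new models; run_language_model's full signature isn't shown in this diff, so the keyword arguments beyond prompt and messages are assumptions based on the hunk context:

from daras_ai_v2.language_model import LargeLanguageModels, run_language_model

# Note: with DEFAULT_SYSTEM_MSG gone, a bare `prompt` is now converted to a
# single user message -- pass `messages` explicitly if you want a system
# prompt. (Keyword names other than `prompt`/`messages` are assumed.)
out = run_language_model(
    model=LargeLanguageModels.llama3_8b_cpt_sea_lion_v2_instruct.name,
    prompt="Summarise this commit in one sentence.",
    max_tokens=256,
    temperature=0.7,
)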
20 changes: 20 additions & 0 deletions scripts/init_llm_pricing.py
@@ -647,6 +647,26 @@ def run():
         notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
     )
 
+    llm_pricing_create(
+        model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct",
+        model_name=LargeLanguageModels.llama3_8b_cpt_sea_lion_v2_instruct.name,
+        unit_cost_input=5,
+        unit_cost_output=15,
+        unit_quantity=10**6,
+        provider=ModelProvider.aks,
+        notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
+    )
+
+    llm_pricing_create(
+        model_id="sarvamai/sarvam-2b-v0.5",
+        model_name=LargeLanguageModels.sarvam_2b.name,
+        unit_cost_input=5,
+        unit_cost_output=15,
+        unit_quantity=10**6,
+        provider=ModelProvider.aks,
+        notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
+    )
+
 
 def llm_pricing_create(
     model_id: str,
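
For reference, a sketch of what these rows imply for billing, assuming (as the field names suggest; the charging formula itself is not shown in this diff) that unit costs apply per unit_quantity tokens:

def estimated_price(prompt_tokens: int, completion_tokens: int) -> float:
    # unit_cost_input=5, unit_cost_output=15, unit_quantity=10**6 come from
    # the pricing rows above; the formula is an assumption for illustration.
    return prompt_tokens / 10**6 * 5 + completion_tokens / 10**6 * 15

# e.g. 2_000 prompt + 500 completion tokens:
# 2000/1e6*5 + 500/1e6*15 = 0.01 + 0.0075 = 0.0175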
