diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py
index dffe49060..533be6410 100644
--- a/daras_ai_v2/language_model.py
+++ b/daras_ai_v2/language_model.py
@@ -32,7 +32,6 @@
 )
 from functions.recipe_functions import LLMTools
 
-DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."
 DEFAULT_JSON_PROMPT = (
     "Please respond directly in JSON format. "
     "Don't output markdown or HTML, instead print the JSON object directly without formatting."
@@ -308,11 +307,26 @@ class LargeLanguageModels(Enum):
     )
 
     sea_lion_7b_instruct = LLMSpec(
-        label="SEA-LION-7B-Instruct (aisingapore)",
+        label="SEA-LION-7B-Instruct [Deprecated] (aisingapore)",
         model_id="aisingapore/sea-lion-7b-instruct",
         llm_api=LLMApis.self_hosted,
         context_window=2048,
         price=1,
+        is_deprecated=True,
+    )
+    llama3_8b_cpt_sea_lion_v2_instruct = LLMSpec(
+        label="Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)",
+        model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct",
+        llm_api=LLMApis.self_hosted,
+        context_window=8192,
+        price=1,
+    )
+    sarvam_2b = LLMSpec(
+        label="Sarvam 2B (sarvamai)",
+        model_id="sarvamai/sarvam-2b-v0.5",
+        llm_api=LLMApis.self_hosted,
+        context_window=2048,
+        price=1,
     )
 
     # https://platform.openai.com/docs/models/gpt-3
@@ -452,7 +466,6 @@ def run_language_model(
     if prompt and not messages:
         # convert text prompt to chat messages
         messages = [
-            format_chat_entry(role=CHATML_ROLE_SYSTEM, content=DEFAULT_SYSTEM_MSG),
             format_chat_entry(role=CHATML_ROLE_USER, content=prompt),
         ]
         if not model.is_vision_model:
@@ -599,6 +612,17 @@ def _run_text_model(
                 temperature=temperature,
                 stop=stop,
             )
+        case LLMApis.self_hosted:
+            return [
+                _run_self_hosted_llm(
+                    model=model,
+                    text_inputs=prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    avoid_repetition=avoid_repetition,
+                    stop=stop,
+                )
+            ]
         case _:
             raise UserError(f"Unsupported text api: {api}")
 
@@ -674,14 +698,19 @@ def _run_chat_model(
                 stop=stop,
             )
         case LLMApis.self_hosted:
-            return _run_self_hosted_chat(
-                model=model,
-                messages=messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                avoid_repetition=avoid_repetition,
-                stop=stop,
-            )
+            return [
+                {
+                    "role": CHATML_ROLE_ASSISTANT,
+                    "content": _run_self_hosted_llm(
+                        model=model,
+                        text_inputs=messages,
+                        max_tokens=max_tokens,
+                        temperature=temperature,
+                        avoid_repetition=avoid_repetition,
+                        stop=stop,
+                    ),
+                },
+            ]
         # case LLMApis.together:
         #     if tools:
         #         raise UserError("Only OpenAI chat models support Tools")
@@ -697,32 +726,36 @@
             raise UserError(f"Unsupported chat api: {api}")
 
 
-def _run_self_hosted_chat(
+def _run_self_hosted_llm(
     *,
     model: str,
-    messages: list[ConversationEntry],
+    text_inputs: list[ConversationEntry] | str,
     max_tokens: int,
     temperature: float,
     avoid_repetition: bool,
     stop: list[str] | None,
-) -> list[dict]:
+) -> str:
     from usage_costs.cost_utils import record_cost_auto
     from usage_costs.models import ModelSku
 
     # sea lion doesnt support system prompt
-    if model == LargeLanguageModels.sea_lion_7b_instruct.model_id:
-        for i, entry in enumerate(messages):
+    if (
+        not isinstance(text_inputs, str)
+        and model == LargeLanguageModels.sea_lion_7b_instruct.model_id
+    ):
+        for i, entry in enumerate(text_inputs):
             if entry["role"] == CHATML_ROLE_SYSTEM:
-                messages[i]["role"] = CHATML_ROLE_USER
-                messages.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
+                text_inputs[i]["role"] = CHATML_ROLE_USER
+                text_inputs.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
 
     ret = call_celery_task(
         "llm.chat",
         pipeline=dict(
             model_id=model,
+            fallback_chat_template_from="meta-llama/Llama-2-7b-chat-hf",
         ),
         inputs=dict(
-            messages=messages,
+            text_inputs=text_inputs,
             max_new_tokens=max_tokens,
             stop_strings=stop,
             temperature=temperature,
@@ -742,12 +775,7 @@ def _run_self_hosted_chat(
             quantity=usage["completion_tokens"],
         )
 
-    return [
-        {
-            "role": CHATML_ROLE_ASSISTANT,
-            "content": ret["generated_text"],
-        }
-    ]
+    return ret["generated_text"]
 
 
 def _run_anthropic_chat(
diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py
index b93222d42..7c22b2c65 100644
--- a/scripts/init_llm_pricing.py
+++ b/scripts/init_llm_pricing.py
@@ -647,6 +647,26 @@ def run():
         notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
     )
 
+    llm_pricing_create(
+        model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct",
+        model_name=LargeLanguageModels.llama3_8b_cpt_sea_lion_v2_instruct.name,
+        unit_cost_input=5,
+        unit_cost_output=15,
+        unit_quantity=10**6,
+        provider=ModelProvider.aks,
+        notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
+    )
+
+    llm_pricing_create(
+        model_id="sarvamai/sarvam-2b-v0.5",
+        model_name=LargeLanguageModels.sarvam_2b.name,
+        unit_cost_input=5,
+        unit_cost_output=15,
+        unit_quantity=10**6,
+        provider=ModelProvider.aks,
+        notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds",
+    )
+
 
 def llm_pricing_create(
     model_id: str,