diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 529ed9379..2ce84f9b8 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -1049,6 +1049,10 @@ def run_openai_chat( for entry in messages: if entry["role"] == CHATML_ROLE_SYSTEM: entry["role"] = CHATML_ROLE_USER + max_tokens = NOT_GIVEN + else: + max_tokens = max_completion_tokens + max_completion_tokens = NOT_GIVEN if avoid_repetition: frequency_penalty = 0.1 @@ -1063,6 +1067,7 @@ def run_openai_chat( _get_chat_completions_create( model=model_str, messages=messages, + max_tokens=max_tokens, max_completion_tokens=max_completion_tokens, stop=stop or NOT_GIVEN, n=num_outputs,