diff --git a/rl/llm/engines/client.py b/rl/llm/engines/client.py
index 64728cd..15fbcb0 100644
--- a/rl/llm/engines/client.py
+++ b/rl/llm/engines/client.py
@@ -255,22 +255,25 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
         original_prompt = copy.deepcopy(prompt)
-        system_prompt = None
+        extra_kwargs = {}
         if prompt[0]["role"] == "system":
-            system_prompt = prompt[0]["content"]
+            extra_kwargs["system"] = prompt[0]["content"]
             prompt = prompt[1:]
         if self.llm_config.max_new_tokens is None:
+            global _WARNED_MAX_TOKENS
             if not _WARNED_MAX_TOKENS:
                 LOGGER.warning(
                     "Anthropic requires a max_tokens value. Using 4096 by default. "
                     "You can override this by setting max_new_tokens in the LLMConfig."
                 )
+                _WARNED_MAX_TOKENS = True
+
         message = self.client.messages.create(
             model=self.llm_config.model_name_or_path,
-            system=system_prompt,
             messages=prompt,
             max_tokens=self.llm_config.max_new_tokens or 4096,
+            **extra_kwargs,
         )
         return InferenceOutput(
             prompt=original_prompt,