Fix anthropic issues
ProbablyFaiz committed Sep 22, 2024
1 parent 0552dfa · commit 00ddd3e
Showing 1 changed file with 6 additions and 3 deletions.
rl/llm/engines/client.py — 6 additions & 3 deletions
@@ -255,22 +255,25 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
 
         original_prompt = copy.deepcopy(prompt)
 
-        system_prompt = None
+        extra_kwargs = {}
         if prompt[0]["role"] == "system":
-            system_prompt = prompt[0]["content"]
+            extra_kwargs["system"] = prompt[0]["content"]
             prompt = prompt[1:]
 
         if self.llm_config.max_new_tokens is None:
+            global _WARNED_MAX_TOKENS
+            if not _WARNED_MAX_TOKENS:
                 LOGGER.warning(
                     "Anthropic requires a max_tokens value. Using 4096 by default. "
                     "You can override this by setting max_new_tokens in the LLMConfig."
                 )
+                _WARNED_MAX_TOKENS = True
 
         message = self.client.messages.create(
             model=self.llm_config.model_name_or_path,
-            system=system_prompt,
             messages=prompt,
             max_tokens=self.llm_config.max_new_tokens or 4096,
+            **extra_kwargs,
         )
         return InferenceOutput(
             prompt=original_prompt,
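
The substantive change: the Anthropic Messages API takes the system prompt as a top-level `system` parameter, and the SDK does not accept an explicit `system=None`, so the commit passes the parameter only when a system message is actually present, by unpacking `**extra_kwargs`. A minimal standalone sketch of the same pattern — client setup, model name, and prompt are illustrative assumptions, not the repo's code:

    import anthropic

    # Sketch of the pattern the commit adopts: pass `system` only when present.
    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

    prompt = [
        {"role": "system", "content": "Answer in one word."},
        {"role": "user", "content": "What color is the sky?"},
    ]

    extra_kwargs = {}
    if prompt[0]["role"] == "system":
        # Anthropic takes the system prompt as a top-level parameter rather
        # than as a chat message, so hoist it out of the message list.
        extra_kwargs["system"] = prompt[0]["content"]
        prompt = prompt[1:]

    message = client.messages.create(
        model="claude-3-5-sonnet-20240620",  # illustrative model choice
        messages=prompt,
        max_tokens=4096,  # Anthropic requires max_tokens; 4096 mirrors the commit's default
        **extra_kwargs,  # when empty, `system` is omitted entirely rather than sent as None
    )
    print(message.content[0].text)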

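The diff also guards the missing-`max_tokens` warning behind a module-level flag so it fires at most once per process instead of on every call. A minimal sketch of that warn-once pattern, with `resolve_max_tokens` as a hypothetical helper name — the repo does this inline in `generate`:

    import logging

    LOGGER = logging.getLogger(__name__)

    # Module-level flag: the process warns at most once, then stays quiet.
    _WARNED_MAX_TOKENS = False


    def resolve_max_tokens(max_new_tokens: int | None) -> int:
        """Hypothetical helper mirroring the commit's warn-once default."""
        global _WARNED_MAX_TOKENS
        if max_new_tokens is None and not _WARNED_MAX_TOKENS:
            LOGGER.warning(
                "Anthropic requires a max_tokens value. Using 4096 by default. "
                "You can override this by setting max_new_tokens in the LLMConfig."
            )
            _WARNED_MAX_TOKENS = True
        return max_new_tokens or 4096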