From 0552dfaf2e402eb0f298b3dfa4d3e5c6413dd574 Mon Sep 17 00:00:00 2001 From: Faiz Surani Date: Sun, 22 Sep 2024 10:56:50 -0700 Subject: [PATCH] Fix max_tokens issue --- rl/llm/engines/client.py | 11 ++++++++++- rl/llm/engines/config.py | 4 ++-- rl/llm/train_llm.py | 4 +++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/rl/llm/engines/client.py b/rl/llm/engines/client.py index c8367ed..64728cd 100644 --- a/rl/llm/engines/client.py +++ b/rl/llm/engines/client.py @@ -222,6 +222,9 @@ def _convert_openai_to_gemini( ) +_WARNED_MAX_TOKENS = False + + @register_engine("anthropic", required_modules=("anthropic",)) class AnthropicEngine(ClientEngine): BASE_URL = "https://api.anthropic.com/v1" @@ -257,11 +260,17 @@ def generate(self, prompt: ChatInput) -> InferenceOutput: system_prompt = prompt[0]["content"] prompt = prompt[1:] + if self.llm_config.max_new_tokens is None: + if not _WARNED_MAX_TOKENS: + LOGGER.warning( + "Anthropic requires a max_tokens value. Using 4096 by default. " + "You can override this by setting max_new_tokens in the LLMConfig." + ) message = self.client.messages.create( model=self.llm_config.model_name_or_path, system=system_prompt, messages=prompt, - max_tokens=self.llm_config.max_new_tokens, + max_tokens=self.llm_config.max_new_tokens or 4096, ) return InferenceOutput( prompt=original_prompt, diff --git a/rl/llm/engines/config.py b/rl/llm/engines/config.py index 8bb2cf8..2f78f20 100644 --- a/rl/llm/engines/config.py +++ b/rl/llm/engines/config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict from strenum import StrEnum import rl.utils.io @@ -13,7 +13,7 @@ class LLMConfig(BaseModel): context_window_tokens: int | None = None max_new_tokens: int | None = None temperature: float = 0.0 - frequency_penalty: float = Field(0.2, description="Experiment with this") + frequency_penalty: float = 0.2 # Experiment with this json_output: bool = False return_logprobs: bool = False diff --git a/rl/llm/train_llm.py b/rl/llm/train_llm.py index 24623b3..79639af 100644 --- a/rl/llm/train_llm.py +++ b/rl/llm/train_llm.py @@ -237,7 +237,9 @@ def _get_default_output_dir(name: str) -> Path: return _DEFAULT_BASE_OUTPUT_DIR / name -def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset: +def get_dataset( + train_data_path: Path, val_data_path: Path | None = None +) -> datasets.Dataset: df = pd.read_json(train_data_path, lines=True) # Removing the metadata because it causes weird problems when loading the dataset. df = df.drop(columns=["metadata"])