Skip to content

Commit

Permalink
Fix max_tokens issue
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Sep 22, 2024
1 parent 2bb0a02 commit 0552dfa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
11 changes: 10 additions & 1 deletion rl/llm/engines/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def _convert_openai_to_gemini(
)


_WARNED_MAX_TOKENS = False


@register_engine("anthropic", required_modules=("anthropic",))
class AnthropicEngine(ClientEngine):
BASE_URL = "https://api.anthropic.com/v1"
Expand Down Expand Up @@ -257,11 +260,17 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
system_prompt = prompt[0]["content"]
prompt = prompt[1:]

if self.llm_config.max_new_tokens is None:
if not _WARNED_MAX_TOKENS:
LOGGER.warning(
"Anthropic requires a max_tokens value. Using 4096 by default. "
"You can override this by setting max_new_tokens in the LLMConfig."
)
message = self.client.messages.create(
model=self.llm_config.model_name_or_path,
system=system_prompt,
messages=prompt,
max_tokens=self.llm_config.max_new_tokens,
max_tokens=self.llm_config.max_new_tokens or 4096,
)
return InferenceOutput(
prompt=original_prompt,
Expand Down
4 changes: 2 additions & 2 deletions rl/llm/engines/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from strenum import StrEnum

import rl.utils.io
Expand All @@ -13,7 +13,7 @@ class LLMConfig(BaseModel):
context_window_tokens: int | None = None
max_new_tokens: int | None = None
temperature: float = 0.0
frequency_penalty: float = Field(0.2, description="Experiment with this")
frequency_penalty: float = 0.2 # Experiment with this

json_output: bool = False
return_logprobs: bool = False
Expand Down
4 changes: 3 additions & 1 deletion rl/llm/train_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def _get_default_output_dir(name: str) -> Path:
return _DEFAULT_BASE_OUTPUT_DIR / name


def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset:
def get_dataset(
train_data_path: Path, val_data_path: Path | None = None
) -> datasets.Dataset:
df = pd.read_json(train_data_path, lines=True)
# Removing the metadata because it causes weird problems when loading the dataset.
df = df.drop(columns=["metadata"])
Expand Down

0 comments on commit 0552dfa

Please sign in to comment.