diff --git a/rl/llm/config.py b/rl/llm/config.py
index 36887fe..150624e 100644
--- a/rl/llm/config.py
+++ b/rl/llm/config.py
@@ -17,7 +17,7 @@ class LLMConfig:
     context_window_tokens: int = 0
     max_new_tokens: int = 2048
     temperature: float = 0.0
-    frequency_penalty: float = 0.05  # Experiment with this
+    frequency_penalty: float = 0.2  # Experiment with this
     num_gpus: int | None = None
     visible_devices: str | None = None
diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index eac2334..3ff3a37 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -23,6 +23,7 @@
 import torch
 import tqdm.asyncio
 from anthropic import Anthropic
+from google.generativeai.types import HarmBlockThreshold, HarmCategory
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from transformers import AutoTokenizer, PreTrainedTokenizer
@@ -271,6 +272,12 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
                 "response_mime_type": "text/plain",
             },
             system_instruction=system_message,
+            safety_settings={  # disable Gemini's per-category safety blocking
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+            },
         )
         chat_session = model.start_chat(history=prev_messages)
         # Can't include the last message in the history, because
@@ -545,6 +552,7 @@ def _get_vllm_kwargs(llm_config):
         "disable_log_stats": True,
         "dtype": "auto",
         "gpu_memory_utilization": 0.9,
+        "enable_prefix_caching": True,  # reuse KV cache across shared prompt prefixes
         "enable_lora": llm_config.lora_name_or_path is not None,
         "max_lora_rank": 32,
     }