From 0ca8923fedb4157ed592b44734121633468b2c46 Mon Sep 17 00:00:00 2001
From: daniel nakov
Date: Tue, 7 May 2024 22:06:54 -0400
Subject: [PATCH] Add support for Google Gemini API

---
 r2ai/auto.py        | 62 ++++++++++++++++++++++++++++++++++++---------
 r2ai/interpreter.py | 33 ++++++++++++++++++++++++
 r2ai/main.py        |  1 +
 r2ai/models.py      |  5 +++-
 4 files changed, 88 insertions(+), 13 deletions(-)

diff --git a/r2ai/auto.py b/r2ai/auto.py
index 914e50c9..ffc07c30 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -86,6 +86,25 @@ def get_functionary_tokenizer(repo_id):
         functionary_tokenizer = AutoTokenizer.from_pretrained(repo_id, legacy=True)
     return functionary_tokenizer
 
+def r2cmd(command: str):
+    """runs commands in radare2. You can run it multiple times or chain commands with pipes/semicolons. You can also use r2 interpreters to run scripts using the `#`, '#!', etc. commands. The output could be long, so try to use filters if possible or limit. This is your preferred tool"""
+    builtins.print('\x1b[1;32mRunning \x1b[4m' + command + '\x1b[0m')
+    res = r2lang.cmd(command)
+    builtins.print(res)
+    return res
+
+def run_python(command: str):
+    """runs a python script and returns the results"""
+    with open('r2ai_tmp.py', 'w') as f:
+        f.write(command)
+    builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m')
+    builtins.print(command)
+    r2lang.cmd('#!python r2ai_tmp.py > $tmp')
+    res = r2lang.cmd('cat $tmp')
+    r2lang.cmd('rm r2ai_tmp.py')
+    builtins.print('\x1b[1;32mResult\x1b[0m\n' + res)
+    return res
+
 def process_tool_calls(interpreter, tool_calls):
     interpreter.messages.append({ "content": None, "tool_calls": tool_calls, "role": "assistant" })
     for tool_call in tool_calls:
@@ -101,18 +120,9 @@ def process_tool_calls(interpreter, tool_calls):
             if type(args) is str:
                 args = { "command": args }
             if "command" in args:
-                builtins.print('\x1b[1;32mRunning \x1b[4m' + args["command"] + '\x1b[0m')
-                res = r2lang.cmd(args["command"])
-                builtins.print(res)
+                res = r2cmd(args["command"])
         elif tool_call["function"]["name"] == "run_python":
-            with open('r2ai_tmp.py', 'w') as f:
-                f.write(args["command"])
-            builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m')
-            builtins.print(args["command"])
-            r2lang.cmd('#!python r2ai_tmp.py > $tmp')
-            res = r2lang.cmd('cat $tmp')
-            r2lang.cmd('rm r2ai_tmp.py')
-            builtins.print('\x1b[1;32mResult\x1b[0m\n' + res)
+            res = run_python(args["command"])
         if (not res or len(res) == 0) and interpreter.model.startswith('meetkai/'):
             res = "OK done"
         interpreter.messages.append({"role": "tool", "content": ANSI_REGEX.sub('', res), "name": tool_call["function"]["name"], "tool_call_id": tool_call["id"] if "id" in tool_call else None})
@@ -197,7 +207,7 @@ def chat(interpreter):
     lastmsg = interpreter.messages[-1]["content"]
     chat_context = context_from_msg (lastmsg)
     #print("#### CONTEXT BEGIN")
-    print(chat_context) # DEBUG
+    #print(chat_context) # DEBUG
     #print("#### CONTEXT END")
     if chat_context != "":
         interpreter.messages.insert(0,{"role": "user", "content": chat_context})
@@ -277,6 +287,34 @@ def chat(interpreter):
             temperature=float(interpreter.env["llm.temperature"]),
         )
         process_streaming_response(interpreter, [response])
+    elif interpreter.model.startswith("google"):
+        if not interpreter.google_client:
+            try:
+                import google.generativeai as google
+                google.configure(api_key=os.environ['GOOGLE_API_KEY'])
+            except ImportError:
+                print("pip install -U google-generativeai", file=sys.stderr)
+                return
+
+            interpreter.google_client = google.GenerativeModel(interpreter.model[7:])
+        if not interpreter.google_chat:
+            interpreter.google_chat = interpreter.google_client.start_chat(
+                enable_automatic_function_calling=True
+            )
+
+        response = interpreter.google_chat.send_message(
+            interpreter.messages[-1]["content"],
+            generation_config={
+                "max_output_tokens": int(interpreter.env["llm.maxtokens"]),
+                "temperature": float(interpreter.env["llm.temperature"])
+            },
+            safety_settings=[{
+                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                "threshold": "BLOCK_NONE"
+            }],
+            tools=[r2cmd, run_python]
+        )
+        print(response.text)
     else:
         chat_format = interpreter.llama_instance.chat_format
         is_functionary = interpreter.model.startswith("meetkai/")
diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py
index f4f9593f..b78076e3 100644
--- a/r2ai/interpreter.py
+++ b/r2ai/interpreter.py
@@ -7,6 +7,8 @@
 from .voice import tts
 from .const import R2AI_HOMEDIR
 from . import auto
+import os
+
 try:
     from openai import OpenAI
     have_openai = True
@@ -28,6 +30,14 @@
     have_groq = False
     pass
 
+try:
+    import google.generativeai as google
+    google.configure(api_key=os.environ['GOOGLE_API_KEY'])
+    have_google = True
+except Exception as e:
+    have_google = False
+    pass
+
 import re
 import os
 import traceback
@@ -502,6 +512,8 @@ def __init__(self):
         self.openai_client = None
         self.anthropic_client = None
         self.groq_client = None
+        self.google_client = None
+        self.google_chat = None
         self.api_base = None # Will set it to whatever OpenAI wants
         self.system_message = ""
         self.env["debug"] = "false"
@@ -912,6 +924,27 @@ def respond(self):
                     temperature=float(self.env["llm.temperature"]),
                     messages=self.messages
                 )
+                if self.env["chat.reply"] == "true":
+                    self.messages.append({"role": "assistant", "content": completion.content})
+                print(completion.content)
+        elif self.model.startswith('google:'):
+            if have_google:
+                if not self.google_client:
+                    self.google_client = google.GenerativeModel(self.model[7:])
+                if not self.google_chat:
+                    self.google_chat = self.google_client.start_chat()
+
+                completion = self.google_chat.send_message(
+                    self.messages[-1]["content"],
+                    generation_config={
+                        "max_output_tokens": maxtokens,
+                        "temperature": float(self.env["llm.temperature"])
+                    }
+                )
+                if self.env["chat.reply"] == "true":
+                    self.messages.append({"role": "assistant", "content": completion.text})
+                print(completion.text)
+            return
         else:
             # non-openai aka local-llama model
             if self.llama_instance == None:
diff --git a/r2ai/main.py b/r2ai/main.py
index b14a3a2d..3d87d7ee 100755
--- a/r2ai/main.py
+++ b/r2ai/main.py
@@ -51,6 +51,7 @@ def __main__():
     have_r2pipe = True
 except:
     pass
+
 if not have_rlang and not have_r2pipe and sys.argv[0] != 'main.py' and os.path.exists("venv/bin/python"):
     os.system("venv/bin/python main.py")
     sys.exit(0)
diff --git a/r2ai/models.py b/r2ai/models.py
index 3e922583..01435c5e 100644
--- a/r2ai/models.py
+++ b/r2ai/models.py
@@ -60,6 +60,9 @@ def models():
 -m groq:gemma-7b-it
 -m groq:llama2-70b-4096
 -m groq:mixtral-8x7b-32768
+Google:
+-m google:gemini-1.0-pro
+-m google:gemini-1.5-pro-latest
 GPT4:
 -m NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
 -m TheBloke/Chronos-70B-v2-GGUF
@@ -472,7 +475,7 @@ def enough_disk_space(size, path) -> bool:
     return False
 
 def new_get_hf_llm(repo_id, debug_mode, context_window):
-    if repo_id.startswith("openai:") or repo_id.startswith("anthropic:") or repo_id.startswith("groq:"):
+    if repo_id.startswith("openai:") or repo_id.startswith("anthropic:") or repo_id.startswith("groq:") or repo_id.startswith("google:"):
         return repo_id
     if not os.path.exists(repo_id):
         return get_hf_llm(repo_id, debug_mode, context_window)
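
Note (not part of the patch): both new code paths above drive the google-generativeai package through the same configure() / GenerativeModel() / start_chat() / send_message() calls. The following is a rough, standalone sketch of that flow outside r2ai; the model name, prompt, generation settings, and the stubbed r2cmd tool are illustrative assumptions rather than values taken from the patch.

import os
import google.generativeai as google

def r2cmd(command: str):
    """Stub tool for illustration only; in r2ai this runs the command inside radare2."""
    return "stub output for: " + command

# Same setup as the patch: API key from the environment, model name after the "google:" prefix.
google.configure(api_key=os.environ["GOOGLE_API_KEY"])
client = google.GenerativeModel("gemini-1.5-pro-latest")  # example model name
chat = client.start_chat(enable_automatic_function_calling=True)
response = chat.send_message(
    "Use the r2cmd tool to run the command 'i' and summarize the output",
    generation_config={"max_output_tokens": 1024, "temperature": 0.2},
    tools=[r2cmd],
)
print(response.text)

With the patch applied, the backend is selected with, for example, -m google:gemini-1.5-pro-latest (see the r2ai/models.py hunk) and expects GOOGLE_API_KEY to be set in the environment.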