From db642845a2215927b80b21e82f0e12752039ea87 Mon Sep 17 00:00:00 2001 From: Daniel Nakov Date: Thu, 7 Nov 2024 19:10:28 +0000 Subject: [PATCH] Add new execute_binary command for auto; some UI bug fixes --- r2ai/auto.py | 16 ++++++++++++++-- r2ai/tools.py | 27 +++++++++++++++++++++++++++ r2ai/ui/app.py | 14 ++++++++++++-- r2ai/ui/chat.py | 4 ++-- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/r2ai/auto.py b/r2ai/auto.py index dbd417a..639b528 100644 --- a/r2ai/auto.py +++ b/r2ai/auto.py @@ -6,7 +6,8 @@ import litellm from litellm import _should_retry, acompletion, utils, ModelResponse import asyncio -from .tools import r2cmd, run_python +from .pipe import get_filename +from .tools import r2cmd, run_python, execute_binary import json import signal from .spinner import spinner @@ -259,11 +260,22 @@ def cb(type, data): if 'content' in data: sys.stdout.write(data['content']) elif type == 'tool_call': + builtins.print() if data['function']['name'] == 'r2cmd': builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m') elif data['function']['name'] == 'run_python': builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m') builtins.print(data['function']['arguments']['command']) + elif data['function']['name'] == 'execute_binary': + filename = get_filename() + stdin = data['function']['arguments']['stdin'] + args = data['function']['arguments']['args'] + cmd = filename + if len(args) > 0: + cmd += ' ' + ' '.join(args) + if stdin: + cmd += f' stdin={stdin}' + builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m') elif type == 'tool_response': if 'content' in data: sys.stdout.write(data['content']) @@ -277,7 +289,7 @@ def signal_handler(signum, frame): def chat(interpreter, **kwargs): model = interpreter.model.replace(":", "/") - tools = [r2cmd, run_python] + tools = [r2cmd, run_python, execute_binary] messages = interpreter.messages tool_choice = 'auto' diff --git a/r2ai/tools.py b/r2ai/tools.py index fabf958..6d89854 100644 --- a/r2ai/tools.py +++ b/r2ai/tools.py @@ -28,6 +28,8 @@ def r2cmd(command: str): return log_messages return res['res'] + except json.JSONDecodeError: + return res except Exception as e: # return { 'type': 'error', 'output': f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" } return f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" @@ -53,3 +55,28 @@ def run_python(command: str): res = r2.cmd('#!python r2ai_tmp.py') r2.cmd('rm r2ai_tmp.py') return res + +def execute_binary(args: list[str] = [], stdin: str = ""): + """ + Execute a binary with the given arguments and stdin + + Parameters + ---------- + args: list[str] + The arguments to pass to the binary + stdin: str + The stdin to pass to the binary + + Returns + ------- + str + The output of the binary + """ + + r2 = get_r2_inst() + if len(args) > 0: + r2.cmd(f"dor {' '.join(args)}") + if stdin: + r2.cmd(f'dor stdin="{stdin}"') + r2.cmd("ood") + return r2cmd("dc") diff --git a/r2ai/ui/app.py b/r2ai/ui/app.py index f7bd267..01ac4ec 100644 --- a/r2ai/ui/app.py +++ b/r2ai/ui/app.py @@ -193,7 +193,17 @@ def on_message(self, type: str, message: any) -> None: except NoMatches: existing = self.add_message(message["id"], "AI", message["content"]) elif type == 'tool_call': - self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}") + if 'command' in message['function']['arguments']: + self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}") + elif message['function']['name'] == 'execute_binary': + args = message['function']['arguments'] + output = get_filename() + if 'args' in args and len(args['args']) > 0: + output += f" {args['args'].join(' ')}\n" + if 'stdin' in args and len(args['stdin']) > 0: + output += f" stdin={args['stdin']}\n" + + self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {output}") elif type == 'tool_response': self.add_message(message["id"], "Tool Response", message['content']) @@ -218,7 +228,7 @@ async def validate_model(self) -> None: if not model: await self.select_model() if is_litellm_model(model): - model = self.ai.model + model = self.ai.model.replace(':', '/') keys = validate_environment(model) if keys['keys_in_environment'] is False: await self.push_screen_wait(ModelConfigDialog(keys['missing_keys'])) diff --git a/r2ai/ui/chat.py b/r2ai/ui/chat.py index 9e823fa..9b88299 100644 --- a/r2ai/ui/chat.py +++ b/r2ai/ui/chat.py @@ -3,7 +3,7 @@ import json import signal from r2ai.pipe import get_r2_inst -from r2ai.tools import run_python, r2cmd +from r2ai.tools import run_python, r2cmd, execute_binary from r2ai.repl import r2ai_singleton from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO from r2ai.interpreter import is_litellm_model @@ -14,7 +14,7 @@ def signal_handler(signum, frame): async def chat(ai, message, cb): model = ai.model.replace(":", "/") - tools = [r2cmd, run_python] + tools = [r2cmd, run_python, execute_binary] messages = ai.messages + [{"role": "user", "content": message}] tool_choice = 'auto' if not is_litellm_model(model) and ai and not ai.llama_instance: