Fix not getting python output when running as plugin; Fix litellm connection reset on openai calls; switch run_python and execute_binary to just use python subprocess, seems to mess up r2pipe
dnakov committed Nov 9, 2024
1 parent 59f891f commit 50a659c
Showing 7 changed files with 189 additions and 63 deletions.
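The core change, detailed in r2ai/tools.py below, is to stop routing Python execution through r2pipe's #!python handler, which swallowed script output when r2ai runs as an r2 plugin, and to use a plain subprocess instead. A minimal sketch of the pattern (illustration only; run_snippet is a hypothetical name, not the committed API):

import subprocess
import sys

def run_snippet(code: str) -> str:
    # A fresh interpreter captures stdout/stderr directly instead of
    # passing them through r2pipe.
    proc = subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True)
    return proc.stdout + proc.stderr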
46 changes: 16 additions & 30 deletions r2ai/auto.py
@@ -7,7 +7,7 @@
from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from .pipe import get_filename
-from .tools import r2cmd, run_python, execute_binary
+from .tools import r2cmd, run_python, execute_binary, schemas, print_tool_call
import json
import signal
from .spinner import spinner
@@ -40,7 +40,7 @@
"""

class ChatAuto:
-def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ):
+def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=60, stream=True, cb=None ):
self.logger = LOGGER
self.functions = {}
self.tools = []
@@ -63,9 +63,13 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
self.tool_choice = None
if tools:
for tool in tools:
-f = utils.function_to_dict(tool)
-self.tools.append({ "type": "function", "function": f })
-self.functions[f['name']] = tool
+if tool.__name__ in schemas:
+    schema = schemas[tool.__name__]
+else:
+    schema = utils.function_to_dict(tool)
+
+self.tools.append({ "type": "function", "function": schema })
+self.functions[tool.__name__] = tool
self.tool_choice = tool_choice
self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
#self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
@@ -143,7 +147,9 @@ async def process_streaming_response(self, resp):
self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True })
self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done })
self.messages.append(current_message)
-if len(current_message['tool_calls']) > 0:
+if len(current_message['tool_calls']) == 0:
+    del current_message['tool_calls']
+else:
await self.process_tool_calls(current_message['tool_calls'])
return current_message

@@ -247,8 +253,8 @@ async def get_completion(self):
async def achat(self, messages=None) -> str:
if messages:
self.messages = messages
-self.logger.debug(self.messages)
response = await self.get_completion()
+self.logger.debug(f'chat complete')
return response

def chat(self, **kwargs) -> str:
@@ -261,25 +267,12 @@ def cb(type, data):
sys.stdout.write(data['content'])
elif type == 'tool_call':
builtins.print()
-if data['function']['name'] == 'r2cmd':
-    builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m')
-elif data['function']['name'] == 'run_python':
-    builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
-    builtins.print(data['function']['arguments']['command'])
-elif data['function']['name'] == 'execute_binary':
-    filename = get_filename()
-    stdin = data['function']['arguments']['stdin']
-    args = data['function']['arguments']['args']
-    cmd = filename
-    if len(args) > 0:
-        cmd += ' ' + ' '.join(args)
-    if stdin:
-        cmd += f' stdin={stdin}'
-    builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m')
+print_tool_call(data)
elif type == 'tool_response':
if 'content' in data:
sys.stdout.write(data['content'])
sys.stdout.flush()
builtins.print()
# builtins.print(data['content'])
elif type == 'message' and data['done']:
builtins.print()
@@ -324,11 +317,4 @@ def chat(interpreter, **kwargs):
finally:
signal.signal(signal.SIGINT, original_handler)
spinner.stop()
-try:
-    pending = asyncio.all_tasks(loop=loop)
-    for task in pending:
-        task.cancel()
-    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-    loop.run_until_complete(loop.shutdown_asyncgens())
-finally:
-    loop.close()
+litellm.in_memory_llm_clients_cache.clear()
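Two things change in ChatAuto's setup. The default timeout drops from None to 60 seconds, presumably so stalled OpenAI connections fail and retry instead of hanging, and tool registration now prefers a hand-written schema over litellm's utils.function_to_dict, which builds the schema by introspecting the function's docstring. The registration logic amounts to this (a condensed sketch of the loop above, not new code):

# For each tool, prefer an explicit schema from tools.schemas; fall back
# to docstring introspection only when no hand-written schema exists.
schema = schemas.get(tool.__name__) or utils.function_to_dict(tool)
self.tools.append({"type": "function", "function": schema})
self.functions[tool.__name__] = tool

This matters for execute_binary, whose docstring is removed in this commit (see r2ai/tools.py), so introspection would no longer produce a usable schema for it.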
1 change: 1 addition & 0 deletions r2ai/main.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import os
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
import sys
import builtins
import traceback
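The environment variable is set before anything else can import transformers, so the flag is already present whenever the library checks it. The placement constraint in isolation (a sketch; only the os.environ line is from this commit):

import os

# Set first, ahead of every other import, so transformers sees the flag
# regardless of when it reads the variable.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

import transformers  # imported afterwards; advisory warnings stay silent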
8 changes: 7 additions & 1 deletion r2ai/repl.py
@@ -14,7 +14,7 @@
from .interpreter import Interpreter
from .pipe import have_rlang, r2lang, r2singleton
from r2ai import bubble, LOGGER
-
+from .test import run_test
tab_init()

print_buffer = ""
@@ -217,6 +217,10 @@ def runline(ai, usertext):
print(help_message)
elif usertext.startswith("clear") or usertext.startswith("-k"):
print("\x1b[2J\x1b[0;0H\r")
+if ai.messages:
+    ai.messages = []
+if autoai and autoai.messages:
+    autoai.messages = []
elif usertext.startswith("-MM"):
print(models().strip())
elif usertext.startswith("-M"):
@@ -469,6 +473,8 @@ def runline(ai, usertext):
print("r2 is not available", file=sys.stderr)
else:
builtins.print(r2_cmd(usertext[1:]))
+elif usertext.startswith("--test"):
+    run_test()
elif usertext.startswith("-"):
print(f"Unknown flag '{usertext}'. See 'r2ai -h' for help", file=sys.stderr)
else:
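The new --test branch wires the smoke tests from r2ai/test.py (next file) into the REPL. run_test also takes an argument string for dispatching a single tool, so it can be driven directly from Python; hypothetical invocations, assuming an r2ai session:

from r2ai.test import run_test

run_test("get_filename")            # print the currently opened file
run_test("run_python print(2+2)")   # route one snippet through run_python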
37 changes: 37 additions & 0 deletions r2ai/test.py
@@ -0,0 +1,37 @@
+import builtins
+from .tools import run_python, execute_binary, r2cmd
+import subprocess
+from .pipe import get_filename
+import time
+py_code = """
+print('hello test')
+"""
+
+def run_test(args):
+    if not args or len(args) == 0:
+        res = run_python(py_code).strip()
+        print(f"run_python: {res}", len(res))
+        assert res == "hello test"
+        print("run_python: test passed")
+        r2cmd("o--;o /bin/ls")
+        res = execute_binary(args=["-d", "/etc"]).strip()
+        subp = subprocess.run(["/bin/ls", "-d", "/etc"], capture_output=True, text=True)
+        print("exec result", res)
+        print("subp result", subp.stdout)
+        assert ''.join(res).strip() == subp.stdout.strip()
+        print("execute_binary with args: test passed")
+    else:
+        cmd, *args = args.split(" ", 1)
+        if cmd == "get_filename":
+            builtins.print(get_filename())
+        elif cmd == "run_python":
+            builtins.print(f"--- args ---")
+            builtins.print(args)
+            builtins.print(f"--- end args ---")
+            builtins.print(f"--- result ---")
+            builtins.print(run_python(args[0]))
+            builtins.print(f"--- end result ---")
+        elif cmd == "r2cmd":
+            builtins.print(f"--- {args} ---")
+            builtins.print(r2cmd(args))
+            builtins.print("--- end ---")
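The no-argument path uses a simple oracle: after reopening /bin/ls in r2, execute_binary's output for "-d /etc" must match a direct subprocess run of the same command. The same pattern in isolation (a sketch; check_against_subprocess is a hypothetical helper):

import subprocess

def check_against_subprocess(tool_output: str, argv: list[str]) -> bool:
    # The tool should reproduce exactly what an OS-level run prints.
    expected = subprocess.run(argv, capture_output=True, text=True)
    return tool_output.strip() == expected.stdout.strip()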
145 changes: 119 additions & 26 deletions r2ai/tools.py
@@ -1,6 +1,21 @@
from r2ai.pipe import get_r2_inst
import json
import builtins
import base64
+from .pipe import get_filename
+from . import LOGGER
+import time
+import sys
+from io import StringIO
+import subprocess
+import os
+is_plugin = False
+try:
+    import r2lang
+    is_plugin = True
+except Exception:
+    is_plugin = False
+    pass

def r2cmd(command: str):
"""
@@ -17,8 +32,11 @@ def r2cmd(command: str):
The output of the r2 command
"""
r2 = get_r2_inst()
+if command.startswith('r2 '):
+    return "You are already in r2!"
cmd = '{"cmd":' + json.dumps(command) + '}'
res = r2.cmd(cmd)

try:
res = json.loads(res)
if 'error' in res and res['error'] is True:
@@ -29,6 +47,10 @@

return res['res']
except json.JSONDecodeError:
+if type(res) == str:
+    spl = res.strip().split('\n')
+    if spl[-1].startswith('{"res":""'):
+        res = '\n'.join(spl[:-1])
return res
except Exception as e:
# return { 'type': 'error', 'output': f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" }
@@ -49,34 +71,105 @@ def run_python(command: str):
The output of the python script
"""
r2 = get_r2_inst()
-with open('r2ai_tmp.py', 'w') as f:
-    f.write(command)
-r2 = get_r2_inst()
-res = r2.cmd('#!python r2ai_tmp.py')
-r2.cmd('rm r2ai_tmp.py')
-return res
+res = ""
+is_plugin = False
+python_path = sys.executable
+try:
+    proc = subprocess.run([python_path, '-c', command],
+                          capture_output=True,
+                          text=True)
+    res = proc.stdout
+    if proc.stderr:
+        res += proc.stderr
+except Exception as e:
+    res = str(e)
+
+# if is_plugin:
+#     base64cmd = base64.b64encode(command.encode('utf-8')).decode('utf-8')
+#     res += r2cmd(f'#!python -e base64:{base64cmd} > .r2ai_tmp.log')
+#     res += r2cmd('cat .r2ai_tmp.log')
+#     r2cmd('rm .r2ai_tmp.log')
+# else:
+#     with open('r2ai_tmp.py', 'w') as f:
+#         f.write(command)
+#     r2 = get_r2_inst()
+#     res += r2cmd('#!python r2ai_tmp.py > .r2ai_tmp.log')
+#     time.sleep(0.1)
+#     res += r2cmd('!cat .r2ai_tmp.log')
+#     LOGGER.debug(f'run_python: {res}')
+#     # r2cmd('rm r2ai_tmp.py')
+#     # r2cmd('rm .r2ai_tmp.log')
+return res


+schemas = {
+    "execute_binary": {
+        "name": "execute_binary",
+        "description": "Execute a binary with the given arguments and stdin",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "args": {
+                    "description": "The arguments to pass to the binary. Do not include the file name.",
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
+                },
+                "stdin": {
+                    "type": "string"
+                }
+            }
+        }
+    }
+}

def execute_binary(args: list[str] = [], stdin: str = ""):
-    """
-    Execute a binary with the given arguments and stdin
-
-    Parameters
-    ----------
-    args: list[str]
-        The arguments to pass to the binary
-    stdin: str
-        The stdin to pass to the binary
-
-    Returns
-    -------
-    str
-        The output of the binary
-    """
-    r2 = get_r2_inst()
-    if len(args) > 0:
-        r2.cmd(f"dor {' '.join(args)}")
-    if stdin:
-        r2.cmd(f'dor stdin="{stdin}"')
-    r2.cmd("ood")
-    return r2cmd("dc")
+    filename = get_filename()
+    if filename:
+        if os.path.isabs(filename):
+            abs_path = os.path.abspath(filename)
+            if os.path.exists(abs_path):
+                filename = abs_path
+        else:
+            cwd_path = os.path.join(os.getcwd(), filename)
+            if os.path.exists(cwd_path):
+                filename = cwd_path
+        try:
+            cmd = [filename] + args
+            proc = subprocess.run(cmd, input=stdin, capture_output=True, text=True)
+            res = proc.stdout
+            if proc.stderr:
+                res += proc.stderr
+            return res
+        except Exception as e:
+            return str(e)
+    return ""
+    # r2 = get_r2_inst()
+    # if stdin:
+    #     r2.cmd(f'dor stdin={json.dumps(stdin)}')
+    # if len(args) > 0:
+    #     r2.cmd(f"ood {' '.join(args)}")
+    # else:
+    #     r2.cmd("ood")
+    # res = r2cmd("dc")
+    # return res
+def print_tool_call(msg):
+    if msg['function']['name'] == 'r2cmd':
+        builtins.print('\x1b[1;32m> \x1b[4m' + msg['function']['arguments']['command'] + '\x1b[0m')
+    elif msg['function']['name'] == 'run_python':
+        builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
+        builtins.print(msg['function']['arguments']['command'])
+    elif msg['function']['name'] == 'execute_binary':
+        filename = get_filename() or 'bin'
+        stdin = msg['function']['arguments']['stdin'] if 'stdin' in msg['function']['arguments'] else None
+        args = msg['function']['arguments']['args'] if 'args' in msg['function']['arguments'] else []
+        cmd = filename
+        if args and len(args) > 0:
+            cmd += ' ' + ' '.join(args)
+        if stdin and len(stdin) > 0:
+            cmd += f' stdin={stdin}'
+        builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m')
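With these changes both tools shell out directly: run_python spawns sys.executable and execute_binary runs the opened file itself, rather than driving r2's debugger with ood/dc (the old path survives only as comments). A quick interactive check (hypothetical usage; assumes a file is open so get_filename() returns a path):

from r2ai.tools import run_python, execute_binary

print(run_python("import sys; print(sys.version)"))  # fresh interpreter, output captured
print(execute_binary(args=["--help"]))               # runs the opened binary directly

The trade-off is that execute_binary no longer runs under r2's debugger, but stdout/stderr capture becomes reliable, which is what the commit message is after.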
9 changes: 6 additions & 3 deletions r2ai/ui/app.py
@@ -22,6 +22,9 @@
from .chat import chat
import asyncio
import json
+import re
+ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')


class ModelConfigDialog(SystemModalScreen):
def __init__(self, keys: list[str]) -> None:
@@ -197,15 +200,15 @@ def on_message(self, type: str, message: any) -> None:
self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}")
elif message['function']['name'] == 'execute_binary':
args = message['function']['arguments']
-output = get_filename()
+output = get_filename() or "bin"
if 'args' in args and len(args['args']) > 0:
-    output += f" {args['args'].join(' ')}\n"
+    output += f" {' '.join(args['args'])}\n"
if 'stdin' in args and len(args['stdin']) > 0:
output += f" stdin={args['stdin']}\n"

self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {output}")
elif type == 'tool_response':
-self.add_message(message["id"], "Tool Response", message['content'])
+self.add_message(message["id"], "Tool Response", ANSI_REGEX.sub('', message['content']))

async def send_message(self) -> None:
input_widget = self.query_one("#chat-input", Input)
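Tool responses arrive carrying the ANSI color codes that print_tool_call and r2 emit, and the Textual UI would render them as literal escape text, so they are stripped before display. A quick check of the pattern (illustrative):

import re

ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

colored = '\x1b[1;32m> \x1b[4mpdf\x1b[0m'
assert ANSI_REGEX.sub('', colored) == '> pdf'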
6 changes: 3 additions & 3 deletions r2ai/ui/chat.py
@@ -15,11 +15,11 @@ def signal_handler(signum, frame):
async def chat(ai, message, cb):
model = ai.model.replace(":", "/")
tools = [r2cmd, run_python, execute_binary]
-messages = ai.messages + [{"role": "user", "content": message}]
+ai.messages.append({"role": "user", "content": message})
tool_choice = 'auto'
if not is_litellm_model(model) and ai and not ai.llama_instance:
ai.llama_instance = new_get_hf_llm(ai, model, int(ai.env["llm.window"]))

-chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
+chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=ai.messages, tool_choice=tool_choice, cb=cb)

return await chat_auto.achat()
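Appending the user turn to ai.messages in place, and handing ChatAuto that same list, lets history accumulate across calls instead of being rebuilt and discarded each turn; the clear/-k handling added in repl.py resets these shared lists. The behavioral difference in miniature (illustrative sketch):

history = []

def old_turn(msg):
    # A copy: replies the chat loop appends afterwards are lost.
    return history + [{"role": "user", "content": msg}]

def new_turn(msg):
    # The shared list: replies persist into the next turn.
    history.append({"role": "user", "content": msg})
    return history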
