Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix not getting python output when running as plugin; fix litellm con… #91

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions r2ai/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from .pipe import get_filename
from .tools import r2cmd, run_python, execute_binary
from .tools import r2cmd, run_python, execute_binary, schemas, print_tool_call
import json
import signal
from .spinner import spinner
Expand Down Expand Up @@ -40,7 +40,7 @@
"""

class ChatAuto:
def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ):
def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=60, stream=True, cb=None ):
self.logger = LOGGER
self.functions = {}
self.tools = []
Expand All @@ -63,9 +63,13 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
self.tool_choice = None
if tools:
for tool in tools:
f = utils.function_to_dict(tool)
self.tools.append({ "type": "function", "function": f })
self.functions[f['name']] = tool
if tool.__name__ in schemas:
schema = schemas[tool.__name__]
else:
schema = utils.function_to_dict(tool)

self.tools.append({ "type": "function", "function": schema })
self.functions[tool.__name__] = tool
self.tool_choice = tool_choice
self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
#self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
Expand Down Expand Up @@ -143,7 +147,9 @@ async def process_streaming_response(self, resp):
self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True })
self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done })
self.messages.append(current_message)
if len(current_message['tool_calls']) > 0:
if len(current_message['tool_calls']) == 0:
del current_message['tool_calls']
else:
await self.process_tool_calls(current_message['tool_calls'])
return current_message

Expand Down Expand Up @@ -247,8 +253,8 @@ async def get_completion(self):
async def achat(self, messages=None) -> str:
if messages:
self.messages = messages
self.logger.debug(self.messages)
response = await self.get_completion()
self.logger.debug(f'chat complete')
return response

def chat(self, **kwargs) -> str:
Expand All @@ -261,25 +267,12 @@ def cb(type, data):
sys.stdout.write(data['content'])
elif type == 'tool_call':
builtins.print()
if data['function']['name'] == 'r2cmd':
builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m')
elif data['function']['name'] == 'run_python':
builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
builtins.print(data['function']['arguments']['command'])
elif data['function']['name'] == 'execute_binary':
filename = get_filename()
stdin = data['function']['arguments']['stdin']
args = data['function']['arguments']['args']
cmd = filename
if len(args) > 0:
cmd += ' ' + ' '.join(args)
if stdin:
cmd += f' stdin={stdin}'
builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m')
print_tool_call(data)
elif type == 'tool_response':
if 'content' in data:
sys.stdout.write(data['content'])
sys.stdout.flush()
builtins.print()
# builtins.print(data['content'])
elif type == 'message' and data['done']:
builtins.print()
Expand Down Expand Up @@ -324,11 +317,4 @@ def chat(interpreter, **kwargs):
finally:
signal.signal(signal.SIGINT, original_handler)
spinner.stop()
try:
pending = asyncio.all_tasks(loop=loop)
for task in pending:
task.cancel()
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
litellm.in_memory_llm_clients_cache.clear()
1 change: 1 addition & 0 deletions r2ai/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import os
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
import sys
import builtins
import traceback
Expand Down
8 changes: 7 additions & 1 deletion r2ai/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .interpreter import Interpreter
from .pipe import have_rlang, r2lang, r2singleton
from r2ai import bubble, LOGGER

from .test import run_test
tab_init()

print_buffer = ""
Expand Down Expand Up @@ -217,6 +217,10 @@ def runline(ai, usertext):
print(help_message)
elif usertext.startswith("clear") or usertext.startswith("-k"):
print("\x1b[2J\x1b[0;0H\r")
if ai.messages:
ai.messages = []
if autoai and autoai.messages:
autoai.messages = []
elif usertext.startswith("-MM"):
print(models().strip())
elif usertext.startswith("-M"):
Expand Down Expand Up @@ -469,6 +473,8 @@ def runline(ai, usertext):
print("r2 is not available", file=sys.stderr)
else:
builtins.print(r2_cmd(usertext[1:]))
elif usertext.startswith("--test"):
run_test(usertext[7:])
elif usertext.startswith("-"):
print(f"Unknown flag '{usertext}'. See 'r2ai -h' for help", file=sys.stderr)
else:
Expand Down
37 changes: 37 additions & 0 deletions r2ai/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import builtins
from .tools import run_python, execute_binary, r2cmd
import subprocess
from .pipe import get_filename
import time
py_code = """
print('hello test')
"""

def run_test(args):
if not args or len(args) == 0:
res = run_python(py_code).strip()
print(f"run_python: {res}", len(res))
assert res == "hello test"
print("run_python: test passed")
r2cmd("o--;o /bin/ls")
res = execute_binary(args=["-d", "/etc"]).strip()
subp = subprocess.run(["/bin/ls", "-d", "/etc"], capture_output=True, text=True)
print("exec result", res)
print("subp result", subp.stdout)
assert ''.join(res).strip() == subp.stdout.strip()
print("execute_binary with args: test passed")
else:
cmd, *args = args.split(" ", 1)
if cmd == "get_filename":
builtins.print(get_filename())
elif cmd == "run_python":
builtins.print(f"--- args ---")
builtins.print(args)
builtins.print(f"--- end args ---")
builtins.print(f"--- result ---")
builtins.print(run_python(args[0]))
builtins.print(f"--- end result ---")
elif cmd == "r2cmd":
builtins.print(f"--- {args} ---")
builtins.print(r2cmd(args))
builtins.print("--- end ---")
145 changes: 119 additions & 26 deletions r2ai/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from r2ai.pipe import get_r2_inst
import json
import builtins
import base64
from .pipe import get_filename
from . import LOGGER
import time
import sys
from io import StringIO
import subprocess
import os
is_plugin = False
try:
import r2lang
is_plugin = True
except Exception:
is_plugin = False
pass

def r2cmd(command: str):
"""
Expand All @@ -17,8 +32,11 @@ def r2cmd(command: str):
The output of the r2 command
"""
r2 = get_r2_inst()
if command.startswith('r2 '):
return "You are already in r2!"
cmd = '{"cmd":' + json.dumps(command) + '}'
res = r2.cmd(cmd)

try:
res = json.loads(res)
if 'error' in res and res['error'] is True:
Expand All @@ -29,6 +47,10 @@ def r2cmd(command: str):

return res['res']
except json.JSONDecodeError:
if type(res) == str:
spl = res.strip().split('\n')
if spl[-1].startswith('{"res":""'):
res = '\n'.join(spl[:-1])
return res
except Exception as e:
# return { 'type': 'error', 'output': f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" }
Expand All @@ -49,34 +71,105 @@ def run_python(command: str):
The output of the python script
"""
r2 = get_r2_inst()
with open('r2ai_tmp.py', 'w') as f:
f.write(command)
r2 = get_r2_inst()
res = r2.cmd('#!python r2ai_tmp.py')
r2.cmd('rm r2ai_tmp.py')
return res
res = ""
is_plugin = False
python_path = sys.executable
try:
proc = subprocess.run([python_path, '-c', command],
capture_output=True,
text=True)
res = proc.stdout
if proc.stderr:
res += proc.stderr
except Exception as e:
res = str(e)

# if is_plugin:
# base64cmd = base64.b64encode(command.encode('utf-8')).decode('utf-8')
# res += r2cmd(f'#!python -e base64:{base64cmd} > .r2ai_tmp.log')
# res += r2cmd('cat .r2ai_tmp.log')
# r2cmd('rm .r2ai_tmp.log')
# else:
# with open('r2ai_tmp.py', 'w') as f:
# f.write(command)
# r2 = get_r2_inst()
# res += r2cmd('#!python r2ai_tmp.py > .r2ai_tmp.log')
# time.sleep(0.1)
# res += r2cmd('!cat .r2ai_tmp.log')
# LOGGER.debug(f'run_python: {res}')
# # r2cmd('rm r2ai_tmp.py')
# # r2cmd('rm .r2ai_tmp.log')
return res


schemas = {
"execute_binary": {
"name": "execute_binary",
"description": "Execute a binary with the given arguments and stdin",
"parameters": {
"type": "object",
"properties": {
"args": {
"description": "The arguments to pass to the binary. Do not include the file name.",
"type": "array",
"items": {
"type": "string"
}
},
"stdin": {
"type": "string"
}
}
}
}
}

def execute_binary(args: list[str] = [], stdin: str = ""):
"""
Execute a binary with the given arguments and stdin
filename = get_filename()
if filename:
if os.path.isabs(filename):
abs_path = os.path.abspath(filename)
if os.path.exists(abs_path):
filename = abs_path
else:
cwd_path = os.path.join(os.getcwd(), filename)
if os.path.exists(cwd_path):
filename = cwd_path
try:
cmd = [filename] + args
proc = subprocess.run(cmd, input=stdin, capture_output=True, text=True)
res = proc.stdout
if proc.stderr:
res += proc.stderr
return res
except Exception as e:
return str(e)
return ""
# r2 = get_r2_inst()
# if stdin:
# r2.cmd(f'dor stdin={json.dumps(stdin)}')
# if len(args) > 0:
# r2.cmd(f"ood {' '.join(args)}")
# else:
# r2.cmd("ood")
# res = r2cmd("dc")
# return res

Parameters
----------
args: list[str]
The arguments to pass to the binary
stdin: str
The stdin to pass to the binary

Returns
-------
str
The output of the binary
"""

r2 = get_r2_inst()
if len(args) > 0:
r2.cmd(f"dor {' '.join(args)}")
if stdin:
r2.cmd(f'dor stdin="{stdin}"')
r2.cmd("ood")
return r2cmd("dc")
def print_tool_call(msg):
if msg['function']['name'] == 'r2cmd':
builtins.print('\x1b[1;32m> \x1b[4m' + msg['function']['arguments']['command'] + '\x1b[0m')
elif msg['function']['name'] == 'run_python':
builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
builtins.print(msg['function']['arguments']['command'])
elif msg['function']['name'] == 'execute_binary':
filename = get_filename() or 'bin'
stdin = msg['function']['arguments']['stdin'] if 'stdin' in msg['function']['arguments'] else None
args = msg['function']['arguments']['args'] if 'args' in msg['function']['arguments'] else []
cmd = filename
if args and len(args) > 0:
cmd += ' ' + ' '.join(args)
if stdin and len(stdin) > 0:
cmd += f' stdin={stdin}'
builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m')
9 changes: 6 additions & 3 deletions r2ai/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from .chat import chat
import asyncio
import json
import re
ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')


class ModelConfigDialog(SystemModalScreen):
def __init__(self, keys: list[str]) -> None:
Expand Down Expand Up @@ -197,15 +200,15 @@ def on_message(self, type: str, message: any) -> None:
self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}")
elif message['function']['name'] == 'execute_binary':
args = message['function']['arguments']
output = get_filename()
output = get_filename() or "bin"
if 'args' in args and len(args['args']) > 0:
output += f" {args['args'].join(' ')}\n"
output += f" {' '.join(args['args'])}\n"
if 'stdin' in args and len(args['stdin']) > 0:
output += f" stdin={args['stdin']}\n"

self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {output}")
elif type == 'tool_response':
self.add_message(message["id"], "Tool Response", message['content'])
self.add_message(message["id"], "Tool Response", ANSI_REGEX.sub('', message['content']))

async def send_message(self) -> None:
input_widget = self.query_one("#chat-input", Input)
Expand Down
6 changes: 3 additions & 3 deletions r2ai/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def signal_handler(signum, frame):
async def chat(ai, message, cb):
model = ai.model.replace(":", "/")
tools = [r2cmd, run_python, execute_binary]
messages = ai.messages + [{"role": "user", "content": message}]
ai.messages.append({"role": "user", "content": message})
tool_choice = 'auto'
if not is_litellm_model(model) and ai and not ai.llama_instance:
ai.llama_instance = new_get_hf_llm(ai, model, int(ai.env["llm.window"]))

chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=ai.messages, tool_choice=tool_choice, cb=cb)

return await chat_auto.achat()