Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
emeryberger committed Mar 8, 2024
2 parents 9e2196a + a487a24 commit 73b91b1
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 108 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
{ name="Stephen Freund", email="[email protected]" },
]
dependencies = [
"llm-utils>=0.2.6",
"llm-utils>=0.2.8",
"openai>=1.6.1",
"rich>=13.7.0",
"ansicolors>=1.1.8",
Expand All @@ -22,6 +22,7 @@ dependencies = [
"litellm>=1.26.6",
"PyYAML>=6.0.1",
"ipyflow>=0.0.130",
"numpy>=1.26.3"
]
description = "AI-assisted debugging. Uses AI to answer 'why'."
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/chatdbg/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def run(self, prompt, client_print=print):
)
client_print()
client_print(f"[Cost: ~${cost:.2f} USD]")
return run.usage.total_tokens, cost, elapsed_time
return run.usage.total_tokens,run.usage.prompt_tokens, run.usage.completion_tokens, cost, elapsed_time
except OpenAIError as e:
client_print(f"*** OpenAI Error: {e}")
sys.exit(-1)
Expand Down
112 changes: 64 additions & 48 deletions src/chatdbg/chatdbg_lldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json

import llm_utils
import openai

from assistant.lite_assistant import LiteAssistant
import chatdbg_utils
Expand Down Expand Up @@ -234,7 +233,7 @@ def why(
sys.exit(1)

the_prompt = buildPrompt(debugger)
args, _ = chatdbg_utils.parse_known_args(command)
args, _ = chatdbg_utils.parse_known_args(command.split())
chatdbg_utils.explain(the_prompt[0], the_prompt[1], the_prompt[2], args)


Expand Down Expand Up @@ -389,7 +388,12 @@ def _instructions():
You are an assistant debugger.
The user is having an issue with their code, and you are trying to help them find the root cause.
They will provide a short summary of the issue and a question to be answered.
Call the `lldb` function to run lldb debugger commands on the stopped program.
Call the `get_code_surrounding` function to retrieve user code and give more context back to the user on their problem.
Call the `find_definition` function to retrieve the definition of a particular symbol.
You should call `find_definition` on every symbol that could be linked to the issue.
Don't hesitate to use as many function calls as needed to give the best possible answer.
Once you have identified the root cause of the problem, explain it and provide a way to fix the issue if you can.
"""
Expand Down Expand Up @@ -440,52 +444,6 @@ def get_code_surrounding(filename: str, lineno: int) -> str:
(lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
return llm_utils.number_group_of_lines(lines, first)

clangd = clangd_lsp_integration.clangd()

def find_definition(filename: str, lineno: int, character: int) -> str:
"""
{
"name": "find_definition",
"description": "Returns the definition for the symbol at the given source location.",
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The filename the code location is from."
},
"lineno": {
"type": "integer",
"description": "The line number where the symbol is present."
},
"character": {
"type": "integer",
"description": "The column number where the symbol is present."
}
},
"required": [ "filename", "lineno", "character" ]
}
}
"""
clangd.didOpen(filename, "c" if filename.endswith(".c") else "cpp")
definition = clangd.definition(filename, lineno, character)
clangd.didClose(filename)

if "result" not in definition or not definition["result"]:
return "No definition found."

path = clangd_lsp_integration.uri_to_path(definition["result"][0]["uri"])
start_lineno = definition["result"][0]["range"]["start"]["line"] + 1
end_lineno = definition["result"][0]["range"]["end"]["line"] + 1
(lines, first) = llm_utils.read_lines(path, start_lineno - 5, end_lineno + 5)
content = llm_utils.number_group_of_lines(lines, first)
line_string = (
f"line {start_lineno}"
if start_lineno == end_lineno
else f"lines {start_lineno}-{end_lineno}"
)
return f"""File '{path}' at {line_string}:\n```\n{content}\n```"""

assistant = LiteAssistant(
_instructions(),
model=args.llm,
Expand All @@ -500,6 +458,62 @@ def find_definition(filename: str, lineno: int, character: int) -> str:
print("[WARNING] clangd is not available.")
print("[WARNING] The `find_definition` function will not be made available.")
else:
clangd = clangd_lsp_integration.clangd()

def find_definition(filename: str, lineno: int, symbol: str) -> str:
"""
{
"name": "find_definition",
"description": "Returns the definition for the given symbol at the given source line number.",
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The filename the symbol is from."
},
"lineno": {
"type": "integer",
"description": "The line number where the symbol is present."
},
"symbol": {
"type": "string",
"description": "The symbol to lookup."
}
},
"required": [ "filename", "lineno", "symbol" ]
}
}
"""
# We just return the first match here. Maybe we should find all definitions.
with open(filename, "r") as file:
lines = file.readlines()
if lineno - 1 >= len(lines):
return "Symbol not found at that location!"
character = lines[lineno - 1].find(symbol)
if character == -1:
return "Symbol not found at that location!"
clangd.didOpen(filename, "c" if filename.endswith(".c") else "cpp")
definition = clangd.definition(filename, lineno, character + 1)
clangd.didClose(filename)

if "result" not in definition or not definition["result"]:
return "No definition found."

path = clangd_lsp_integration.uri_to_path(definition["result"][0]["uri"])
start_lineno = definition["result"][0]["range"]["start"]["line"] + 1
end_lineno = definition["result"][0]["range"]["end"]["line"] + 1
(lines, first) = llm_utils.read_lines(
path, start_lineno - 5, end_lineno + 5
)
content = llm_utils.number_group_of_lines(lines, first)
line_string = (
f"line {start_lineno}"
if start_lineno == end_lineno
else f"lines {start_lineno}-{end_lineno}"
)
return f"""File '{path}' at {line_string}:\n```\n{content}\n```"""

assistant.add_function(find_definition)

return assistant
Expand All @@ -517,6 +531,8 @@ def get_frame_summary() -> str:

summaries = []
for i, frame in enumerate(thread):
if not frame.GetDisplayFunctionName():
continue
name = frame.GetDisplayFunctionName().split("(")[0]
arguments = []
for j in range(
Expand Down
62 changes: 17 additions & 45 deletions src/chatdbg/chatdbg_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .ipdb_util.logging import ChatDBGLog, CopyingTextIOWrapper
from .ipdb_util.prompts import pdb_instructions
from .ipdb_util.text import *
from .ipdb_util.locals import *

_valid_models = [
"gpt-4-turbo-preview",
Expand Down Expand Up @@ -194,7 +195,7 @@ def onecmd(self, line: str) -> bool:
output = strip_color(hist_file.getvalue())
if line not in [ 'quit', 'EOF']:
self._log.user_command(line, output)
if line not in [ 'hist', 'test_prompt' ] and not self.was_chat:
if line not in [ 'hist', 'test_prompt', 'c', 'continue' ] and not self.was_chat:
self._history += [ (line, output) ]

def message(self, msg) -> None:
Expand Down Expand Up @@ -389,56 +390,27 @@ def print_stack_trace(self, context=None, locals=None):
pass


def _get_defined_locals_and_params(self, frame):

class SymbolFinder(ast.NodeVisitor):
def __init__(self):
self.defined_symbols = set()

def visit_Assign(self, node):
for target in node.targets:
if isinstance(target, ast.Name):
self.defined_symbols.add(target.id)
self.generic_visit(node)

def visit_For(self, node):
if isinstance(node.target, ast.Name):
self.defined_symbols.add(node.target.id)
self.generic_visit(node)

def visit_comprehension(self, node):
if isinstance(node.target, ast.Name):
self.defined_symbols.add(node.target.id)
self.generic_visit(node)


try:
source = textwrap.dedent(inspect.getsource(frame))
tree = ast.parse(source)

finder = SymbolFinder()
finder.visit(tree)

args, varargs, keywords, locals = inspect.getargvalues(frame)
parameter_symbols = set(args + [ varargs, keywords ])
parameter_symbols.discard(None)

return (finder.defined_symbols | parameter_symbols) & locals.keys()
except OSError as e:
# yipes -silent fail if getsource fails
return set()

def _print_locals(self, frame):
locals = frame.f_locals
defined_locals = self._get_defined_locals_and_params(frame)
in_global_scope = locals is frame.f_globals
defined_locals = extract_locals(frame)
# if in_global_scope and "In" in locals: # in notebook
# defined_locals = defined_locals | extract_nb_globals(locals)
if len(defined_locals) > 0:
if locals is frame.f_globals:
if in_global_scope:
print(f' Global variables:', file=self.stdout)
else:
print(f' Variables in this frame:', file=self.stdout)
for name in sorted(defined_locals):
value = locals[name]
print(f" {name}= {format_limited(value, limit=20)}", file=self.stdout)
prefix = f' {name}= '
rep = format_limited(value, limit=20).split('\n')
if len(rep) > 1:
rep = prefix + rep[0] + '\n' + textwrap.indent('\n'.join(rep[1:]),
prefix = ' ' * len(prefix))
else:
rep = prefix + rep[0]
print(rep, file=self.stdout)
print(file=self.stdout)

def _stack_prompt(self):
Expand Down Expand Up @@ -499,8 +471,8 @@ def client_print(line=""):
full_prompt = truncate_proportionally(full_prompt)

self._log.push_chat(arg, full_prompt)
tokens, cost, time = self._assistant.run(full_prompt, client_print)
self._log.pop_chat(tokens, cost, time)
total_tokens, prompt_tokens, completion_tokens, cost, time = self._assistant.run(full_prompt, client_print)
self._log.pop_chat(total_tokens, prompt_tokens, completion_tokens, cost, time)

def do_mark(self, arg):
marks = [ 'Full', 'Partial', 'Wrong', 'None', '?' ]
Expand Down
1 change: 0 additions & 1 deletion src/chatdbg/chatdbg_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import os
import textwrap
from typing import Any, List, Optional, Tuple

Expand Down
7 changes: 3 additions & 4 deletions src/chatdbg/clangd_lsp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,12 @@ def uri_to_path(uri):
return urllib.parse.unquote(path) # clangd seems to escape paths.


def is_available():
def is_available(executable="clangd"):
try:
clangd = subprocess.Popen(
["clangd", "--version"],
clangd = subprocess.run(
[executable, "--version"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
text=True,
)
return clangd.returncode == 0
except FileNotFoundError:
Expand Down
52 changes: 52 additions & 0 deletions src/chatdbg/ipdb_util/locals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import ast
import inspect
import textwrap

class SymbolFinder(ast.NodeVisitor):
def __init__(self):
self.defined_symbols = set()

def visit_Assign(self, node):
for target in node.targets:
if isinstance(target, ast.Name):
self.defined_symbols.add(target.id)
self.generic_visit(node)

def visit_For(self, node):
if isinstance(node.target, ast.Name):
self.defined_symbols.add(node.target.id)
self.generic_visit(node)

def visit_comprehension(self, node):
if isinstance(node.target, ast.Name):
self.defined_symbols.add(node.target.id)
self.generic_visit(node)

def extract_locals(frame):
try:
source = textwrap.dedent(inspect.getsource(frame))
tree = ast.parse(source)

finder = SymbolFinder()
finder.visit(tree)

args, varargs, keywords, locals = inspect.getargvalues(frame)
parameter_symbols = set(args + [ varargs, keywords ])
parameter_symbols.discard(None)

return (finder.defined_symbols | parameter_symbols) & locals.keys()
except:
# ipes
return set()

def extract_nb_globals(globals):
result = set()
for source in globals["In"]:
try:
tree = ast.parse(source)
finder = SymbolFinder()
finder.visit(tree)
result = result | (finder.defined_symbols & globals.keys())
except Exception as e:
pass
return result
6 changes: 4 additions & 2 deletions src/chatdbg/ipdb_util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def push_chat(self, line, full_prompt):
}
}

def pop_chat(self, tokens, cost, time):
def pop_chat(self, total_tokens, prompt_tokens, completion_tokens, cost, time):
self.chat_step['stats'] = {
'tokens' : tokens,
'tokens' : total_tokens,
'prompt' : prompt_tokens,
'completion' : completion_tokens,
'cost' : cost,
'time' : time
}
Expand Down
5 changes: 2 additions & 3 deletions src/chatdbg/ipdb_util/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
"""

_general_instructions=f"""\
The root cause of any error is likely due to a problem in the source code within
the {os.getcwd()} directory.
The root cause of any error is likely due to a problem in the source code from the user.
Explain why each variable contributing to the error has been set to the value that it has.
Keep your answers under 10 paragraphs.
Continue with your explanations until you reach the root cause of the error. Your answer may be as long as necessary.
End your answer with a section titled "##### Recommendation\\n" that contains one of:
* a fix if you have identified the root cause
Expand Down
Loading

0 comments on commit 73b91b1

Please sign in to comment.