Skip to content

Commit

Permalink
Merge pull request #6 from EmbeddedDevops1/5-indentation-requirement-…
Browse files Browse the repository at this point in the history
…for-python-classes

Preprocessor Script for Indents
  • Loading branch information
coditamar authored May 20, 2024
2 parents 21ee55c + e70e59e commit 4bcddcb
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 25 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ poetry run cover-agent \
--coverage-type "cobertura" \
--desired-coverage 70 \
--max-iterations 1 \
--openai-model "gpt-4o" \
--additional-instructions "Since I am using a test class each line of code (including the first line), In your response, will need to be prepended with 4 whitespaces. This is extremely important to check to make sure every line returned contains that 4 whitespace indent otherwise my code will not run."
--openai-model "gpt-4o"
```

Note: If you are using Poetry then use the `poetry run python -m cover-agent` command instead of the `cover-agent` run command.
Expand Down
49 changes: 49 additions & 0 deletions cover_agent/FilePreprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import ast
import textwrap


class FilePreprocessor:
def __init__(self, path_to_file):
self.path_to_file = path_to_file

# List of rules/action key pair.
# Add your new rule and how to process the text (function) here
self.rules = [(self._is_python_file, self._process_if_python)]

def process_file(self, text: str) -> str:
"""
Process the text based on the internal rules.
"""
for condition, action in self.rules:
if condition():
return action(text)
return text # Return the text unchanged if no rules apply

def _is_python_file(self) -> bool:
"""
Rule to check if the file is a Python file.
"""
return self.path_to_file.endswith(".py")

def _process_if_python(self, text: str) -> str:
"""
Action to process Python files by checking for class definitions and indenting if found.
"""
if self._contains_class_definition():
return textwrap.indent(text, " ")
return text

def _contains_class_definition(self) -> bool:
"""
Check if the file contains a Python class definition using the ast module.
"""
try:
with open(self.path_to_file, "r") as file:
content = file.read()
parsed_ast = ast.parse(content)
for node in ast.walk(parsed_ast):
if isinstance(node, ast.ClassDef):
return True
except SyntaxError as e:
print(f"Syntax error when parsing the file: {e}")
return False
22 changes: 19 additions & 3 deletions cover_agent/PromptBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
{failed_test_runs}
```
"""


class PromptBuilder:

def __init__(
Expand Down Expand Up @@ -60,9 +62,23 @@ def __init__(
self.code_coverage_report = code_coverage_report

# Conditionally fill in optional sections
self.included_files = ADDITIONAL_INCLUDES_TEXT.format(included_files=included_files) if included_files else included_files
self.additional_instructions = ADDITIONAL_INSTRUCTIONS_TEXT.format(additional_instructions=additional_instructions) if additional_instructions else additional_instructions
self.failed_test_runs = FAILED_TESTS_TEXT.format(failed_test_runs=failed_test_runs) if failed_test_runs else failed_test_runs
self.included_files = (
ADDITIONAL_INCLUDES_TEXT.format(included_files=included_files)
if included_files
else included_files
)
self.additional_instructions = (
ADDITIONAL_INSTRUCTIONS_TEXT.format(
additional_instructions=additional_instructions
)
if additional_instructions
else additional_instructions
)
self.failed_test_runs = (
FAILED_TESTS_TEXT.format(failed_test_runs=failed_test_runs)
if failed_test_runs
else failed_test_runs
)

def _read_file(self, file_path):
"""
Expand Down
26 changes: 20 additions & 6 deletions cover_agent/UnitTestGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cover_agent.CustomLogger import CustomLogger
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.AICaller import AICaller
from cover_agent.FilePreprocessor import FilePreprocessor


class UnitTestGenerator:
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
# Get the logger instance from CustomLogger
self.logger = CustomLogger.get_logger(__name__)

# States to maintain within this class
self.preprocessor = FilePreprocessor(self.test_file_path)
self.failed_test_runs = []

# Run coverage and build the prompt
Expand Down Expand Up @@ -140,7 +143,9 @@ def build_prompt(self):
if not self.failed_test_runs:
failed_test_runs_value = ""
else:
failed_test_runs_value = json.dumps(self.failed_test_runs).replace("\\n", "\n")
failed_test_runs_value = json.dumps(self.failed_test_runs).replace(
"\\n", "\n"
)

# Call PromptBuilder to build the prompt
prompt = PromptBuilder(
Expand Down Expand Up @@ -172,7 +177,7 @@ def generate_tests(self, LLM_model="gpt-4o", max_tokens=4096, dry_run=False):
# We want to remove them and split up the tests into a list of tests
response = ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens)

# Split the response into a list of tests and strip off the trailing whitespaces
# Split the response into a list of tests and strip off the trailing whitespaces
# (as we sometimes anticipate indentations in the returned code from the LLM)
tests = response.split("```")
return [test.rstrip() for test in tests if test.rstrip()]
Expand All @@ -191,13 +196,16 @@ def validate_test(self, generated_test: str):
dict: A dictionary containing the test result status, reason for failure (if any),
stdout, stderr, exit code, and the test itself.
"""
# Step 0: Run the test through the preprocessor rule set
processed_test = self.preprocessor.process_file(generated_test)

# Step 1: Append the generated test to the test file and save the original content
with open(self.test_file_path, "r+") as test_file:
original_content = test_file.read() # Store original content
test_file.write(
"\n"
+ ("\n" if not original_content.endswith("\n") else "")
+ generated_test
+ processed_test
+ "\n"
) # Append the new test at the end

Expand All @@ -223,7 +231,9 @@ def validate_test(self, generated_test: str):
"stdout": stdout,
"test": generated_test,
}
self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list
self.failed_test_runs.append(
fail_details["test"]
) # Append failure details to the list
return fail_details

# If test passed, check for coverage increase
Expand Down Expand Up @@ -253,7 +263,9 @@ def validate_test(self, generated_test: str):
"stdout": stdout,
"test": generated_test,
}
self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list
self.failed_test_runs.append(
fail_details["test"]
) # Append failure details to the list
return fail_details
except Exception as e:
# Handle errors gracefully
Expand All @@ -269,7 +281,9 @@ def validate_test(self, generated_test: str):
"stdout": stdout,
"test": generated_test,
}
self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list
self.failed_test_runs.append(
fail_details["test"]
) # Append failure details to the list
return fail_details

# If everything passed and coverage increased, update current coverage and log success
Expand Down
12 changes: 8 additions & 4 deletions cover_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_args():
"--included-files",
default=None,
nargs="*",
help="List of files to include in the coverage. For example, \"--included-files library1.c library2.c.\" Default: %(default)s.",
help='List of files to include in the coverage. For example, "--included-files library1.c library2.c." Default: %(default)s.',
)
parser.add_argument(
"--coverage-type",
Expand Down Expand Up @@ -134,9 +134,11 @@ def main():
and iteration_count < args.max_iterations
):
# Provide coverage feedback to user
logger.info(f"Current Coverage: {round(test_gen.current_coverage * 100, 2)}%")
logger.info(
f"Current Coverage: {round(test_gen.current_coverage * 100, 2)}%"
)
logger.info(f"Desired Coverage: {test_gen.desired_coverage}%")

# Generate tests by making a call to the LLM
generated_tests = test_gen.generate_tests(
LLM_model=args.openai_model, max_tokens=4096
Expand All @@ -154,7 +156,9 @@ def main():
iteration_count += 1

if iteration_count == args.max_iterations:
logger.info("Reached maximum iteration limit without achieving desired coverage.")
logger.info(
"Reached maximum iteration limit without achieving desired coverage."
)

# Dump the test results to a report
ReportGenerator.generate_report(test_results_list, "test_results.html")
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.31
0.1.32
53 changes: 53 additions & 0 deletions tests/test_FilePreprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import tempfile
import textwrap
from cover_agent.FilePreprocessor import FilePreprocessor


class TestFilePreprocessor:
# Test for a C file
def test_c_file(self):
with tempfile.NamedTemporaryFile(delete=False, suffix=".c") as tmp:
preprocessor = FilePreprocessor(tmp.name)
input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt."
processed_text = preprocessor.process_file(input_text)
assert (
processed_text == input_text
), "C file processing should not alter the text."

# Test for a Python file with only a function
def test_py_file_with_function_only(self):
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp:
tmp.write(b"def function():\n pass\n")
tmp.close()
preprocessor = FilePreprocessor(tmp.name)
input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt."
processed_text = preprocessor.process_file(input_text)
assert (
processed_text == input_text
), "Python file without class should not alter the text."

# Test for a Python file with a comment that looks like a class definition
def test_py_file_with_commented_class(self):
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp:
tmp.write(b"# class myPythonFile:\n pass\n")
tmp.close()
preprocessor = FilePreprocessor(tmp.name)
input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt."
processed_text = preprocessor.process_file(input_text)
assert (
processed_text == input_text
), "Commented class definition should not trigger processing."

# Test for a Python file with an actual class definition
def test_py_file_with_class(self):
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp:
tmp.write(b"class MyClass:\n def method(self):\n pass\n")
tmp.close()
preprocessor = FilePreprocessor(tmp.name)
input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt."
processed_text = preprocessor.process_file(input_text)
expected_output = textwrap.indent(input_text, " ")
assert (
processed_text == expected_output
), "Python file with class should indent the text."
25 changes: 16 additions & 9 deletions tests/test_PromptBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import patch, mock_open
from cover_agent.PromptBuilder import PromptBuilder


class TestPromptBuilder:
@pytest.fixture(autouse=True)
def setup_method(self, monkeypatch):
Expand All @@ -11,7 +12,10 @@ def setup_method(self, monkeypatch):

def test_initialization_reads_file_contents(self):
builder = PromptBuilder(
"cover_agent/prompt_template.md", "source_path", "test_path", "dummy content"
"cover_agent/prompt_template.md",
"source_path",
"test_path",
"dummy content",
)
assert builder.prompt_template == "dummy content"
assert builder.source_file == "dummy content"
Expand All @@ -27,7 +31,7 @@ def test_build_prompt_replaces_placeholders_correctly(self):
"coverage_report",
"Included Files Content",
"Additional Instructions Content",
"Failed Test Runs Content"
"Failed Test Runs Content",
)
builder.prompt_template = "Template: {source_file}, Test: {test_file}, Coverage: {code_coverage_report}, Includes: {additional_includes_section}, Instructions: {additional_instructions_text}, Failed Tests: {failed_tests_section}"
builder.source_file = "Source Content"
Expand All @@ -48,7 +52,10 @@ def mock_open_raise(*args, **kwargs):
monkeypatch.setattr("builtins.open", mock_open_raise)

builder = PromptBuilder(
"cover_agent/prompt_template.md", "source_path", "test_path", "coverage_report"
"cover_agent/prompt_template.md",
"source_path",
"test_path",
"coverage_report",
)
assert "Error reading cover_agent/prompt_template.md" in builder.prompt_template
assert "Error reading source_path" in builder.source_file
Expand All @@ -62,7 +69,7 @@ def test_empty_included_files_section_not_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
included_files="Included Files Content"
included_files="Included Files Content",
)
# Directly read the real file content for the prompt template
with open("cover_agent/prompt_template.md", "r") as f:
Expand All @@ -83,7 +90,7 @@ def test_non_empty_included_files_section_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
included_files="Included Files Content"
included_files="Included Files Content",
)

# Directly read the real file content for the prompt template
Expand All @@ -106,7 +113,7 @@ def test_empty_additional_instructions_section_not_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
additional_instructions=""
additional_instructions="",
)
# Directly read the real file content for the prompt template
with open("cover_agent/prompt_template.md", "r") as f:
Expand All @@ -126,7 +133,7 @@ def test_empty_failed_test_runs_section_not_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
failed_test_runs=""
failed_test_runs="",
)
# Directly read the real file content for the prompt template
with open("cover_agent/prompt_template.md", "r") as f:
Expand All @@ -146,7 +153,7 @@ def test_non_empty_additional_instructions_section_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
additional_instructions="Additional Instructions Content"
additional_instructions="Additional Instructions Content",
)
# Directly read the real file content for the prompt template
with open("cover_agent/prompt_template.md", "r") as f:
Expand All @@ -167,7 +174,7 @@ def test_non_empty_failed_test_runs_section_in_prompt(self, monkeypatch):
source_file_path="source_path",
test_file_path="test_path",
code_coverage_report="coverage_report",
failed_test_runs="Failed Test Runs Content"
failed_test_runs="Failed Test Runs Content",
)
# Directly read the real file content for the prompt template
with open("cover_agent/prompt_template.md", "r") as f:
Expand Down

0 comments on commit 4bcddcb

Please sign in to comment.