-
Notifications
You must be signed in to change notification settings - Fork 366
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor functional components of main.py to CoverAgent class. (#66)
* Refactored functional component of main.py to CoverAgent #19. * Added format make command. * Incremented version.
- Loading branch information
1 parent
592d84b
commit c4e60f3
Showing
7 changed files
with
197 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
# Generated reports | ||
.coverage | ||
coverage.xml | ||
cobertura.xml | ||
testLog.xml | ||
|
||
# Caches | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import shutil | ||
from cover_agent.CustomLogger import CustomLogger | ||
from cover_agent.ReportGenerator import ReportGenerator | ||
from cover_agent.UnitTestGenerator import UnitTestGenerator | ||
|
||
class CoverAgent: | ||
def __init__(self, args): | ||
self.args = args | ||
self.logger = CustomLogger.get_logger(__name__) | ||
|
||
self._validate_paths() | ||
self._duplicate_test_file() | ||
|
||
self.test_gen = UnitTestGenerator( | ||
source_file_path=args.source_file_path, | ||
test_file_path=args.test_file_output_path, | ||
code_coverage_report_path=args.code_coverage_report_path, | ||
test_command=args.test_command, | ||
test_command_dir=args.test_command_dir, | ||
included_files=args.included_files, | ||
coverage_type=args.coverage_type, | ||
desired_coverage=args.desired_coverage, | ||
additional_instructions=args.additional_instructions, | ||
llm_model=args.model, | ||
api_base=args.api_base, | ||
) | ||
|
||
def _validate_paths(self): | ||
if not os.path.isfile(self.args.source_file_path): | ||
raise FileNotFoundError(f"Source file not found at {self.args.source_file_path}") | ||
if not os.path.isfile(self.args.test_file_path): | ||
raise FileNotFoundError(f"Test file not found at {self.args.test_file_path}") | ||
|
||
def _duplicate_test_file(self): | ||
if self.args.test_file_output_path != "": | ||
shutil.copy(self.args.test_file_path, self.args.test_file_output_path) | ||
else: | ||
self.args.test_file_output_path = self.args.test_file_path | ||
|
||
def run(self): | ||
if not self.args.prompt_only: | ||
iteration_count = 0 | ||
test_results_list = [] | ||
|
||
self.test_gen.initial_test_suite_analysis() | ||
|
||
while ( | ||
self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100) | ||
and iteration_count < self.args.max_iterations | ||
): | ||
self.logger.info( | ||
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%" | ||
) | ||
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%") | ||
|
||
generated_tests_dict = self.test_gen.generate_tests(max_tokens=4096) | ||
|
||
for generated_test in generated_tests_dict.get('new_tests', []): | ||
test_result = self.test_gen.validate_test(generated_test, generated_tests_dict) | ||
test_results_list.append(test_result) | ||
|
||
iteration_count += 1 | ||
|
||
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100): | ||
self.test_gen.run_coverage() | ||
|
||
if self.test_gen.current_coverage >= (self.test_gen.desired_coverage / 100): | ||
self.logger.info( | ||
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations.") | ||
elif iteration_count == self.args.max_iterations: | ||
self.logger.info( | ||
f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%" | ||
) | ||
|
||
ReportGenerator.generate_report(test_results_list, self.args.report_filepath) | ||
else: | ||
self.logger.info( | ||
f"Prompt only option requested. Skipping call to LLM. Prompt can be found at: {self.args.prompt_only}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.1.38 | ||
0.1.39 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import os | ||
import argparse | ||
from unittest.mock import patch, MagicMock | ||
import pytest | ||
from cover_agent.CoverAgent import CoverAgent | ||
from cover_agent.main import parse_args | ||
|
||
class TestCoverAgent: | ||
def test_parse_args(self): | ||
with patch( | ||
"sys.argv", | ||
[ | ||
"program.py", | ||
"--source-file-path", | ||
"test_source.py", | ||
"--test-file-path", | ||
"test_file.py", | ||
"--code-coverage-report-path", | ||
"coverage_report.xml", | ||
"--test-command", | ||
"pytest", | ||
"--max-iterations", | ||
"10", | ||
], | ||
): | ||
args = parse_args() | ||
assert args.source_file_path == "test_source.py" | ||
assert args.test_file_path == "test_file.py" | ||
assert args.code_coverage_report_path == "coverage_report.xml" | ||
assert args.test_command == "pytest" | ||
assert args.test_command_dir == os.getcwd() | ||
assert args.included_files is None | ||
assert args.coverage_type == "cobertura" | ||
assert args.report_filepath == "test_results.html" | ||
assert args.desired_coverage == 90 | ||
assert args.max_iterations == 10 | ||
|
||
@patch("cover_agent.CoverAgent.UnitTestGenerator") | ||
@patch("cover_agent.CoverAgent.ReportGenerator") | ||
@patch("cover_agent.CoverAgent.os.path.isfile") | ||
def test_agent_source_file_not_found( | ||
self, mock_isfile, mock_report_generator, mock_unit_cover_agent | ||
): | ||
args = argparse.Namespace( | ||
source_file_path="test_source.py", | ||
test_file_path="test_file.py", | ||
code_coverage_report_path="coverage_report.xml", | ||
test_command="pytest", | ||
test_command_dir=os.getcwd(), | ||
included_files=None, | ||
coverage_type="cobertura", | ||
report_filepath="test_results.html", | ||
desired_coverage=90, | ||
max_iterations=10, | ||
) | ||
parse_args = lambda: args | ||
mock_isfile.return_value = False | ||
|
||
with patch("cover_agent.main.parse_args", parse_args): | ||
with pytest.raises(FileNotFoundError) as exc_info: | ||
agent = CoverAgent(args) | ||
|
||
assert ( | ||
str(exc_info.value) == f"Source file not found at {args.source_file_path}" | ||
) | ||
|
||
mock_unit_cover_agent.assert_not_called() | ||
mock_report_generator.generate_report.assert_not_called() | ||
|
||
@patch("cover_agent.CoverAgent.os.path.exists") | ||
@patch("cover_agent.CoverAgent.os.path.isfile") | ||
@patch("cover_agent.CoverAgent.UnitTestGenerator") | ||
def test_agent_test_file_not_found( | ||
self, mock_unit_cover_agent, mock_isfile, mock_exists | ||
): | ||
args = argparse.Namespace( | ||
source_file_path="test_source.py", | ||
test_file_path="test_file.py", | ||
code_coverage_report_path="coverage_report.xml", | ||
test_command="pytest", | ||
test_command_dir=os.getcwd(), | ||
included_files=None, | ||
coverage_type="cobertura", | ||
report_filepath="test_results.html", | ||
desired_coverage=90, | ||
max_iterations=10, | ||
prompt_only=False, | ||
) | ||
parse_args = lambda: args | ||
mock_isfile.side_effect = [True, False] | ||
mock_exists.return_value = True | ||
|
||
with patch("cover_agent.main.parse_args", parse_args): | ||
with pytest.raises(FileNotFoundError) as exc_info: | ||
agent = CoverAgent(args) | ||
|
||
assert str(exc_info.value) == f"Test file not found at {args.test_file_path}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters