diff --git a/cover_agent/AgentCompletionABC.py b/cover_agent/AgentCompletionABC.py
new file mode 100644
index 00000000..50649535
--- /dev/null
+++ b/cover_agent/AgentCompletionABC.py
@@ -0,0 +1,57 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+class AgentCompletionABC(ABC):
+    """Abstract base class for AI-driven prompt handling."""
+
+    @abstractmethod
+    def generate_tests(
+        self, failed_tests: str, language: str, test_framework: str, coverage_report: str
+    ) -> Tuple[str, int, int, str]:
+        """
+        Generates additional unit tests to improve test coverage.
+
+        Returns:
+            Tuple[str, int, int, str]: AI-generated test cases, input token count, output token count, and generated prompt.
+        """
+        pass
+
+    @abstractmethod
+    def analyze_test_failure(self, stderr: str, stdout: str, processed_test_file: str) -> Tuple[str, int, int, str]:
+        """
+        Analyzes a test failure and returns insights.
+
+        Returns:
+            Tuple[str, int, int, str]: AI-generated analysis, input token count, output token count, and generated prompt.
+        """
+        pass
+
+    @abstractmethod
+    def analyze_test_insert_line(self, test_file: str) -> Tuple[str, int, int, str]:
+        """
+        Determines where to insert new test cases.
+
+        Returns:
+            Tuple[str, int, int, str]: Suggested insertion point, input token count, output token count, and generated prompt.
+        """
+        pass
+
+    @abstractmethod
+    def analyze_test_against_context(self, test_code: str, context: str) -> Tuple[str, int, int, str]:
+        """
+        Validates whether a test is appropriate for its corresponding source code.
+
+        Returns:
+            Tuple[str, int, int, str]: AI validation result, input token count, output token count, and generated prompt.
+        """
+        pass
+
+    @abstractmethod
+    def analyze_suite_test_headers_indentation(self, test_file: str) -> Tuple[str, int, int, str]:
+        """
+        Determines the indentation style used in test suite headers.
+
+        Returns:
+            Tuple[str, int, int, str]: Suggested indentation style, input token count, output token count, and generated prompt.
+        """
+        pass
diff --git a/cover_agent/DefaultAgentCompletion.py b/cover_agent/DefaultAgentCompletion.py
new file mode 100644
index 00000000..f695e0a1
--- /dev/null
+++ b/cover_agent/DefaultAgentCompletion.py
@@ -0,0 +1,61 @@
+from cover_agent.AgentCompletionABC import AgentCompletionABC
+from cover_agent.PromptBuilder import PromptBuilder
+from cover_agent.AICaller import AICaller
+
+class DefaultAgentCompletion(AgentCompletionABC):
+    """Default implementation using PromptBuilder and AICaller."""
+
+    def __init__(self, builder: PromptBuilder, caller: AICaller):
+        self.builder = builder
+        self.caller = caller
+
+    def generate_tests(self, failed_tests, language, test_framework, coverage_report):
+        """Generates additional unit tests to improve test coverage."""
+        prompt = self.builder.build_prompt_custom(
+            file="test_generation_prompt",
+            failed_tests_section=failed_tests,
+            language=language,
+            testing_framework=test_framework,
+            code_coverage_report=coverage_report,
+        )
+        response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+        return response, prompt_tokens, completion_tokens, prompt
+
+    def analyze_test_failure(self, stderr, stdout, processed_test_file):
+        """Analyzes the output of a failed test to determine the cause of failure."""
+        prompt = self.builder.build_prompt_custom(
+            file="analyze_test_run_failure",
+            stderr=stderr,
+            stdout=stdout,
+            processed_test_file=processed_test_file,
+        )
+        response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+        return response, prompt_tokens, completion_tokens, prompt
+
+    def analyze_test_insert_line(self, test_file):
+        """Determines where to insert new test cases."""
+        prompt = self.builder.build_prompt_custom(
+            file="analyze_suite_test_insert_line",
+            test_file=test_file,
+        )
+        response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+        return response, prompt_tokens, completion_tokens, prompt
+
+    def analyze_test_against_context(self, test_code, context):
+        """Validates whether a generated test is appropriate for its corresponding source code."""
+        prompt = self.builder.build_prompt_custom(
+            file="analyze_test_against_context",
+            test_code=test_code,
+            context=context,
+        )
+        response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+        return response, prompt_tokens, completion_tokens, prompt
+
+    def analyze_suite_test_headers_indentation(self, test_file):
+        """Determines the indentation style used in test suite headers."""
+        prompt = self.builder.build_prompt_custom(
+            file="analyze_suite_test_headers_indentation",
+            test_file=test_file,
+        )
+        response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+        return response, prompt_tokens, completion_tokens, prompt
diff --git a/cover_agent/PromptBuilder.py b/cover_agent/PromptBuilder.py
index 0a5c01a5..ebe424a7 100644
--- a/cover_agent/PromptBuilder.py
+++ b/cover_agent/PromptBuilder.py
@@ -161,42 +161,64 @@ def build_prompt(self) -> dict:
         # print(f"#### user_prompt:\n\n{user_prompt}")
         return {"system": system_prompt, "user": user_prompt}
 
-    def build_prompt_custom(self, file) -> dict:
+    def build_prompt_custom(
+        self,
+        file: str,
+        source_file_name: str = None,
+        test_file_name: str = None,
+        source_file_numbered: str = None,
+        test_file_numbered: str = None,
+        source_file: str = None,
+        test_file: str = None,
+        code_coverage_report: str = None,
+        additional_includes_section: str = None,
+        failed_tests_section: str = None,
+        additional_instructions_text: str = None,
+        language: str = None,
+        max_tests: int = MAX_TESTS_PER_RUN,
+        testing_framework: str = None,
+        stdout: str = None,
+        stderr: str = None,
+        processed_test_file: str = None,
+    ) -> dict:
         """
         Builds a custom prompt by replacing placeholders with actual content from files and settings.
-        
+
         Parameters:
             file (str): The file to retrieve settings for building the prompt.
-        
+            Other parameters override instance attributes if provided.
+
         Returns:
             dict: A dictionary containing the system and user prompts.
         """
+
+        # Use provided values or fallback to instance attributes
         variables = {
-            "source_file_name": self.source_file_name_rel,
-            "test_file_name": self.test_file_name_rel,
-            "source_file_numbered": self.source_file_numbered,
-            "test_file_numbered": self.test_file_numbered,
-            "source_file": self.source_file,
-            "test_file": self.test_file,
-            "code_coverage_report": self.code_coverage_report,
-            "additional_includes_section": self.included_files,
-            "failed_tests_section": self.failed_test_runs,
-            "additional_instructions_text": self.additional_instructions,
-            "language": self.language,
-            "max_tests": MAX_TESTS_PER_RUN,
-            "testing_framework": self.testing_framework,
-            "stdout": self.stdout_from_run,
-            "stderr": self.stderr_from_run,
-            "processed_test_file": self.processed_test_file,
+            "source_file_name": source_file_name or self.source_file_name_rel,
+            "test_file_name": test_file_name or self.test_file_name_rel,
+            "source_file_numbered": source_file_numbered or self.source_file_numbered,
+            "test_file_numbered": test_file_numbered or self.test_file_numbered,
+            "source_file": source_file or self.source_file,
+            "test_file": test_file or self.test_file,
+            "code_coverage_report": code_coverage_report or self.code_coverage_report,
+            "additional_includes_section": additional_includes_section or self.included_files,
+            "failed_tests_section": failed_tests_section or self.failed_test_runs,
+            "additional_instructions_text": additional_instructions_text or self.additional_instructions,
+            "language": language or self.language,
+            "max_tests": max_tests,
+            "testing_framework": testing_framework or self.testing_framework,
+            "stdout": stdout or self.stdout_from_run,
+            "stderr": stderr or self.stderr_from_run,
+            "processed_test_file": processed_test_file or self.processed_test_file,
         }
+
         environment = Environment(undefined=StrictUndefined)
         try:
             settings = get_settings().get(file)
-            if settings is None or not hasattr(settings, "system") or not hasattr(
-                settings, "user"
-            ):
+            if settings is None or not hasattr(settings, "system") or not hasattr(settings, "user"):
                 logging.error(f"Could not find settings for prompt file: {file}")
                 return {"system": "", "user": ""}
+
             system_prompt = environment.from_string(settings.system).render(variables)
             user_prompt = environment.from_string(settings.user).render(variables)
         except Exception as e:
diff --git a/cover_agent/UnitTestGenerator.py b/cover_agent/UnitTestGenerator.py
index ab0dc93b..c2013715 100644
--- a/cover_agent/UnitTestGenerator.py
+++ b/cover_agent/UnitTestGenerator.py
@@ -9,6 +9,7 @@
 from cover_agent.CustomLogger import CustomLogger
 from cover_agent.FilePreprocessor import FilePreprocessor
 from cover_agent.PromptBuilder import PromptBuilder
+from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion
 from cover_agent.Runner import Runner
 from cover_agent.settings.config_loader import get_settings
 from cover_agent.settings.token_handling import clip_tokens, TokenEncoder
@@ -70,6 +71,20 @@ def __init__(
 
         # Objects to instantiate
         self.ai_caller = AICaller(model=llm_model, api_base=api_base)
+        self.prompt_builder = PromptBuilder(
+            source_file_path=self.source_file_path,
+            test_file_path=self.test_file_path,
+            code_coverage_report="",
+            included_files=self.included_files,
+            additional_instructions=self.additional_instructions,
+            failed_test_runs="",
+            language="",
+            testing_framework="",
+            project_root=self.project_root,
+        )
+        self.agent_completion = DefaultAgentCompletion(
+            builder=self.prompt_builder, caller=self.ai_caller
+        )
 
         # Get the logger instance from CustomLogger
         self.logger = CustomLogger.get_logger(__name__)
@@ -185,42 +200,6 @@ def check_for_failed_test_runs(self, failed_test_runs):
 
         return failed_test_runs_value
 
-    def build_prompt(self, failed_test_runs, language, testing_framework, code_coverage_report) -> dict:
-        """
-        Builds a prompt using the provided information to be used for generating tests.
-
-        This method processes the failed test runs using the check_for_failed_test_runs method and then calls the PromptBuilder class to construct the prompt.
-        The prompt includes details such as the source file path, test file path, code coverage report, included files,
-        additional instructions, failed test runs, and the programming language being used.
-
-        Args:
-            failed_test_runs (list): A list of dictionaries containing information about failed test runs.
-            language (str): The programming language being used.
-            testing_framework (str): The testing framework being used.
-            code_coverage_report (str): The code coverage report.
-
-        Returns:
-            dict: The generated prompt to be used for test generation.
-        """
-
-        # Check for failed test runs
-        failed_test_runs_value = self.check_for_failed_test_runs(failed_test_runs)
-
-        # Call PromptBuilder to build the prompt
-        self.prompt_builder = PromptBuilder(
-            source_file_path=self.source_file_path,
-            test_file_path=self.test_file_path,
-            code_coverage_report=code_coverage_report,
-            included_files=self.included_files,
-            additional_instructions=self.additional_instructions,
-            failed_test_runs=failed_test_runs_value,
-            language=language,
-            testing_framework=testing_framework,
-            project_root=self.project_root,
-        )
-
-        return self.prompt_builder.build_prompt()
-
     def generate_tests(self, failed_test_runs, language, testing_framework, code_coverage_report):
         """
         Generate tests using the AI model based on the constructed prompt.
@@ -239,8 +218,10 @@ def generate_tests(self, failed_test_runs, language, testing_framework, code_cov
         Raises:
             Exception: If there is an error during test generation, such as a parsing error while processing the AI model response.
         """
-        self.prompt = self.build_prompt(failed_test_runs, language, testing_framework, code_coverage_report)
-        response, prompt_token_count, response_token_count = self.ai_caller.call_model(prompt=self.prompt)
+        failed_test_runs_value = self.check_for_failed_test_runs(failed_test_runs)
+        response, prompt_token_count, response_token_count, self.prompt = self.agent_completion.generate_tests(
+            failed_test_runs_value, language, testing_framework, code_coverage_report
+        )
 
         self.total_input_token_count += prompt_token_count
         self.total_output_token_count += response_token_count
diff --git a/tests/test_UnitTestGenerator.py b/tests/test_UnitTestGenerator.py
index c220d327..e8555915 100644
--- a/tests/test_UnitTestGenerator.py
+++ b/tests/test_UnitTestGenerator.py
@@ -43,28 +43,6 @@ def test_get_code_language_no_extension(self):
         language = generator.get_code_language("filename")
         assert language == "unknown"
 
-    def test_build_prompt_with_failed_tests(self):
-        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
-            generator = UnitTestGenerator(
-                source_file_path=temp_source_file.name,
-                test_file_path="test_test.py",
-                code_coverage_report_path="coverage.xml",
-                test_command="pytest",
-                llm_model="gpt-3"
-            )
-            failed_test_runs = [
-                {
-                    "code": {"test_code": "def test_example(): assert False"},
-                    "error_message": "AssertionError"
-                }
-            ]
-            language = "python"
-            test_framework = "pytest"
-            code_coverage_report = ""
-            prompt = generator.build_prompt(failed_test_runs, language, test_framework, code_coverage_report)
-            assert "Failed Test:" in prompt['user']
-
-
     def test_generate_tests_invalid_yaml(self):
         with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
             generator = UnitTestGenerator(