Initial commit of ABC concept.
EmbeddedDevops1 committed Jan 31, 2025
1 parent a9dfa1a commit 2293ea9
Showing 5 changed files with 181 additions and 82 deletions.
57 changes: 57 additions & 0 deletions cover_agent/AgentCompletionABC.py
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from typing import Tuple

class AgentCompletionABC(ABC):
"""Abstract base class for AI-driven prompt handling."""

@abstractmethod
def generate_tests(
self, failed_tests: str, language: str, test_framework: str, coverage_report: str
) -> Tuple[str, int, int, str]:
"""
Generates additional unit tests to improve test coverage.
Returns:
Tuple[str, int, int, str]: AI-generated test cases, input token count, output token count, and generated prompt.
"""
pass

@abstractmethod
def analyze_test_failure(self, stderr: str, stdout: str, processed_test_file: str) -> Tuple[str, int, int, str]:
"""
Analyzes a test failure and returns insights.
Returns:
Tuple[str, int, int, str]: AI-generated analysis, input token count, output token count, and generated prompt.
"""
pass

@abstractmethod
def analyze_test_insert_line(self, test_file: str) -> Tuple[str, int, int, str]:
"""
Determines where to insert new test cases.
Returns:
Tuple[str, int, int, str]: Suggested insertion point, input token count, output token count, and generated prompt.
"""
pass

@abstractmethod
def analyze_test_against_context(self, test_code: str, context: str) -> Tuple[str, int, int, str]:
"""
Validates whether a test is appropriate for its corresponding source code.
Returns:
Tuple[str, int, int, str]: AI validation result, input token count, output token count, and generated prompt.
"""
pass

@abstractmethod
def analyze_suite_test_headers_indentation(self, test_file: str) -> Tuple[str, int, int, str]:
"""
Determines the indentation style used in test suite headers.
Returns:
Tuple[str, int, int, str]: Suggested indentation style, input token count, output token count, and generated prompt.
"""
pass
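
For reference, here is a minimal sketch (not part of this commit) of how an alternative backend could satisfy this interface; the EchoAgentCompletion name and its canned return values are purely illustrative. Because all five abstract methods are implemented, the class can be instantiated; omitting any of them would raise a TypeError.

from typing import Tuple

from cover_agent.AgentCompletionABC import AgentCompletionABC


class EchoAgentCompletion(AgentCompletionABC):
    """Hypothetical stub backend that returns canned answers instead of calling an LLM."""

    def generate_tests(self, failed_tests, language, test_framework, coverage_report) -> Tuple[str, int, int, str]:
        prompt = f"generate {language}/{test_framework} tests"
        return "# no new tests", 0, 0, prompt

    def analyze_test_failure(self, stderr, stdout, processed_test_file) -> Tuple[str, int, int, str]:
        return f"stderr was: {stderr}", 0, 0, "analyze failure"

    def analyze_test_insert_line(self, test_file) -> Tuple[str, int, int, str]:
        return "1", 0, 0, "insert line"

    def analyze_test_against_context(self, test_code, context) -> Tuple[str, int, int, str]:
        return "yes", 0, 0, "against context"

    def analyze_suite_test_headers_indentation(self, test_file) -> Tuple[str, int, int, str]:
        return "4", 0, 0, "indentation"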
61 changes: 61 additions & 0 deletions cover_agent/DefaultAgentCompletion.py
@@ -0,0 +1,61 @@
from cover_agent.AgentCompletionABC import AgentCompletionABC
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.AICaller import AICaller

class DefaultAgentCompletion(AgentCompletionABC):
"""Default implementation using PromptBuilder and AICaller."""

def __init__(self, builder: PromptBuilder, caller: AICaller):
self.builder = builder
self.caller = caller

def generate_tests(self, failed_tests, language, test_framework, coverage_report):
"""Generates additional unit tests to improve test coverage."""
prompt = self.builder.build_prompt_custom(
file="test_generation_prompt",
failed_tests_section=failed_tests,
language=language,
testing_framework=test_framework,
code_coverage_report=coverage_report,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
return response, prompt_tokens, completion_tokens, prompt

def analyze_test_failure(self, stderr, stdout, processed_test_file):
"""Analyzes the output of a failed test to determine the cause of failure."""
prompt = self.builder.build_prompt_custom(
file="analyze_test_run_failure",
stderr=stderr,
stdout=stdout,
processed_test_file=processed_test_file,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
return response, prompt_tokens, completion_tokens, prompt

def analyze_test_insert_line(self, test_file):
"""Determines where to insert new test cases."""
prompt = self.builder.build_prompt_custom(
file="analyze_suite_test_insert_line",
test_file=test_file,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
return response, prompt_tokens, completion_tokens, prompt

def analyze_test_against_context(self, test_code, context):
"""Validates whether a generated test is appropriate for its corresponding source code."""
prompt = self.builder.build_prompt_custom(
file="analyze_test_against_context",
test_code=test_code,
context=context,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
return response, prompt_tokens, completion_tokens, prompt

def analyze_suite_test_headers_indentation(self, test_file):
"""Determines the indentation style used in test suite headers."""
prompt = self.builder.build_prompt_custom(
file="analyze_suite_test_headers_indentation",
test_file=test_file,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
return response, prompt_tokens, completion_tokens, prompt
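
A rough wiring sketch, assuming the constructor arguments visible in the UnitTestGenerator diff below; the file paths, model name, and empty-string values are placeholders, and the final call would invoke whatever LLM the AICaller is configured for. Every method on the agent returns the same shape: (response, prompt_tokens, completion_tokens, prompt).

from cover_agent.AICaller import AICaller
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion

# Placeholder values throughout -- adjust to a real project layout and model.
caller = AICaller(model="gpt-4o", api_base="")
builder = PromptBuilder(
    source_file_path="app.py",
    test_file_path="test_app.py",
    code_coverage_report="",
    included_files="",
    additional_instructions="",
    failed_test_runs="",
    language="python",
    testing_framework="pytest",
    project_root=".",
)
agent = DefaultAgentCompletion(builder=builder, caller=caller)

# Each call returns (response, prompt_tokens, completion_tokens, prompt).
response, prompt_tokens, completion_tokens, prompt = agent.analyze_test_insert_line(
    test_file="1 def test_placeholder():\n2     assert True\n"
)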
66 changes: 44 additions & 22 deletions cover_agent/PromptBuilder.py
@@ -161,42 +161,64 @@ def build_prompt(self) -> dict:
# print(f"#### user_prompt:\n\n{user_prompt}")
return {"system": system_prompt, "user": user_prompt}

def build_prompt_custom(self, file) -> dict:
def build_prompt_custom(
self,
file: str,
source_file_name: str = None,
test_file_name: str = None,
source_file_numbered: str = None,
test_file_numbered: str = None,
source_file: str = None,
test_file: str = None,
code_coverage_report: str = None,
additional_includes_section: str = None,
failed_tests_section: str = None,
additional_instructions_text: str = None,
language: str = None,
max_tests: int = MAX_TESTS_PER_RUN,
testing_framework: str = None,
stdout: str = None,
stderr: str = None,
processed_test_file: str = None,
) -> dict:
"""
Builds a custom prompt by replacing placeholders with actual content from files and settings.
Parameters:
file (str): The file to retrieve settings for building the prompt.
Other parameters override instance attributes if provided.
Returns:
dict: A dictionary containing the system and user prompts.
"""

# Use provided values or fallback to instance attributes
variables = {
"source_file_name": self.source_file_name_rel,
"test_file_name": self.test_file_name_rel,
"source_file_numbered": self.source_file_numbered,
"test_file_numbered": self.test_file_numbered,
"source_file": self.source_file,
"test_file": self.test_file,
"code_coverage_report": self.code_coverage_report,
"additional_includes_section": self.included_files,
"failed_tests_section": self.failed_test_runs,
"additional_instructions_text": self.additional_instructions,
"language": self.language,
"max_tests": MAX_TESTS_PER_RUN,
"testing_framework": self.testing_framework,
"stdout": self.stdout_from_run,
"stderr": self.stderr_from_run,
"processed_test_file": self.processed_test_file,
"source_file_name": source_file_name or self.source_file_name_rel,
"test_file_name": test_file_name or self.test_file_name_rel,
"source_file_numbered": source_file_numbered or self.source_file_numbered,
"test_file_numbered": test_file_numbered or self.test_file_numbered,
"source_file": source_file or self.source_file,
"test_file": test_file or self.test_file,
"code_coverage_report": code_coverage_report or self.code_coverage_report,
"additional_includes_section": additional_includes_section or self.included_files,
"failed_tests_section": failed_tests_section or self.failed_test_runs,
"additional_instructions_text": additional_instructions_text or self.additional_instructions,
"language": language or self.language,
"max_tests": max_tests,
"testing_framework": testing_framework or self.testing_framework,
"stdout": stdout or self.stdout_from_run,
"stderr": stderr or self.stderr_from_run,
"processed_test_file": processed_test_file or self.processed_test_file,
}

environment = Environment(undefined=StrictUndefined)
try:
settings = get_settings().get(file)
if settings is None or not hasattr(settings, "system") or not hasattr(
settings, "user"
):
if settings is None or not hasattr(settings, "system") or not hasattr(settings, "user"):
logging.error(f"Could not find settings for prompt file: {file}")
return {"system": "", "user": ""}

system_prompt = environment.from_string(settings.system).render(variables)
user_prompt = environment.from_string(settings.user).render(variables)
except Exception as e:
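Continuing with a builder like the one in the earlier sketch, a small illustration of the override behavior added above: keyword arguments that are passed take precedence, while anything left as None falls back to the attribute captured when the PromptBuilder was constructed. The template name is taken from DefaultAgentCompletion; the test-file content is a placeholder.

# `builder` is a PromptBuilder instance like the one in the earlier sketch.
prompt = builder.build_prompt_custom(
    file="analyze_suite_test_insert_line",  # prompt template looked up via get_settings()
    test_file="1 import pytest\n",          # overrides the builder's stored test_file
)
# prompt is {"system": ..., "user": ...}; language, max_tests, and the other
# fields still come from the builder instance because they were not passed.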
57 changes: 19 additions & 38 deletions cover_agent/UnitTestGenerator.py
@@ -9,6 +9,7 @@
from cover_agent.CustomLogger import CustomLogger
from cover_agent.FilePreprocessor import FilePreprocessor
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion
from cover_agent.Runner import Runner
from cover_agent.settings.config_loader import get_settings
from cover_agent.settings.token_handling import clip_tokens, TokenEncoder
@@ -70,6 +71,20 @@ def __init__(

# Objects to instantiate
self.ai_caller = AICaller(model=llm_model, api_base=api_base)
self.prompt_builder = PromptBuilder(
source_file_path=self.source_file_path,
test_file_path=self.test_file_path,
code_coverage_report="",
included_files=self.included_files,
additional_instructions=self.additional_instructions,
failed_test_runs="",
language="",
testing_framework="",
project_root=self.project_root,
)
self.agent_completion = DefaultAgentCompletion(
builder=self.prompt_builder, caller=self.ai_caller
)

# Get the logger instance from CustomLogger
self.logger = CustomLogger.get_logger(__name__)
@@ -185,42 +200,6 @@ def check_for_failed_test_runs(self, failed_test_runs):

return failed_test_runs_value

def build_prompt(self, failed_test_runs, language, testing_framework, code_coverage_report) -> dict:
"""
Builds a prompt using the provided information to be used for generating tests.
This method processes the failed test runs using the check_for_failed_test_runs method and then calls the PromptBuilder class to construct the prompt.
The prompt includes details such as the source file path, test file path, code coverage report, included files,
additional instructions, failed test runs, and the programming language being used.
Args:
failed_test_runs (list): A list of dictionaries containing information about failed test runs.
language (str): The programming language being used.
testing_framework (str): The testing framework being used.
code_coverage_report (str): The code coverage report.
Returns:
dict: The generated prompt to be used for test generation.
"""

# Check for failed test runs
failed_test_runs_value = self.check_for_failed_test_runs(failed_test_runs)

# Call PromptBuilder to build the prompt
self.prompt_builder = PromptBuilder(
source_file_path=self.source_file_path,
test_file_path=self.test_file_path,
code_coverage_report=code_coverage_report,
included_files=self.included_files,
additional_instructions=self.additional_instructions,
failed_test_runs=failed_test_runs_value,
language=language,
testing_framework=testing_framework,
project_root=self.project_root,
)

return self.prompt_builder.build_prompt()

def generate_tests(self, failed_test_runs, language, testing_framework, code_coverage_report):
"""
Generate tests using the AI model based on the constructed prompt.
@@ -239,8 +218,10 @@ def generate_tests(self, failed_test_runs, language, testing_framework, code_cov
Raises:
Exception: If there is an error during test generation, such as a parsing error while processing the AI model response.
"""
self.prompt = self.build_prompt(failed_test_runs, language, testing_framework, code_coverage_report)
response, prompt_token_count, response_token_count = self.ai_caller.call_model(prompt=self.prompt)
failed_test_runs_value = self.check_for_failed_test_runs(failed_test_runs)
response, prompt_token_count, response_token_count, self.prompt = self.agent_completion.generate_tests(
failed_test_runs_value, language, testing_framework, code_coverage_report
)

self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
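To summarize the refactor in this file, a hedged sketch of the new call shape: the rendered prompt now comes back from the agent completion as a fourth return value instead of being assembled by UnitTestGenerator.build_prompt (argument values below are placeholders).

# `agent_completion` is a DefaultAgentCompletion like the one wired above.
# Old flow: prompt = self.build_prompt(...); response, in_tok, out_tok = self.ai_caller.call_model(prompt)
# New flow: the prompt is returned alongside the response and token counts.
response, prompt_tokens, completion_tokens, prompt = agent_completion.generate_tests(
    failed_tests="",          # output of check_for_failed_test_runs()
    language="python",
    test_framework="pytest",
    coverage_report="",
)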
22 changes: 0 additions & 22 deletions tests/test_UnitTestGenerator.py
@@ -43,28 +43,6 @@ def test_get_code_language_no_extension(self):
language = generator.get_code_language("filename")
assert language == "unknown"

def test_build_prompt_with_failed_tests(self):
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
generator = UnitTestGenerator(
source_file_path=temp_source_file.name,
test_file_path="test_test.py",
code_coverage_report_path="coverage.xml",
test_command="pytest",
llm_model="gpt-3"
)
failed_test_runs = [
{
"code": {"test_code": "def test_example(): assert False"},
"error_message": "AssertionError"
}
]
language = "python"
test_framework = "pytest"
code_coverage_report = ""
prompt = generator.build_prompt(failed_test_runs, language, test_framework, code_coverage_report)
assert "Failed Test:" in prompt['user']


def test_generate_tests_invalid_yaml(self):
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
generator = UnitTestGenerator(
