Add LiteLLMAICaller class for AI model interactions, integrate settin… #20

Closed
wants to merge 1 commit
115 changes: 115 additions & 0 deletions cover_agent/LiteLLMAICaller.py
@@ -0,0 +1,115 @@
import boto3  # required for the AWS Bedrock client created in __init__
import litellm
import openai
import tiktoken
from litellm import completion

from settings.config_loader import get_settings


class LiteLLMAICaller:
    def __init__(self, model):
        self.azure = False
        self.aws_bedrock_client = None
        self.api_base = None
        self.repetition_penalty = None

        # Propagate provider credentials, endpoints, and tuning options from
        # the settings object into litellm's module-level configuration.
        if get_settings().get("OPENAI_API_KEY", None):
            openai.api_key = get_settings().OPENAI_API_KEY
            litellm.openai_key = get_settings().OPENAI_API_KEY
        if get_settings().get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = get_settings().litellm.drop_params
        if get_settings().get("OPENAI.ORG", None):
            litellm.organization = get_settings().openai.org
        if get_settings().get("OPENAI.API_TYPE", None):
            if get_settings().openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = get_settings().openai.key
        if get_settings().get("OPENAI.API_VERSION", None):
            litellm.api_version = get_settings().openai.api_version
        if get_settings().get("OPENAI.API_BASE", None):
            litellm.api_base = get_settings().openai.api_base
        if get_settings().get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = get_settings().anthropic.key
        if get_settings().get("COHERE.KEY", None):
            litellm.cohere_key = get_settings().cohere.key
        if get_settings().get("GROQ.KEY", None):
            litellm.api_key = get_settings().groq.key
        if get_settings().get("REPLICATE.KEY", None):
            litellm.replicate_key = get_settings().replicate.key
        if get_settings().get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = get_settings().huggingface.key
        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
            litellm.api_base = get_settings().huggingface.api_base
            self.api_base = get_settings().huggingface.api_base
        if get_settings().get("OLLAMA.API_BASE", None):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = get_settings().vertexai.vertex_project
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        if get_settings().get("AWS.BEDROCK_REGION", None):
            litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
            litellm.AmazonAnthropicClaude3Config.max_tokens = 2000
            self.aws_bedrock_client = boto3.client(
                service_name="bedrock-runtime",
                region_name=get_settings().aws.bedrock_region,
            )

        self.model = model

        # Initialize the tokenizer encoding (cl100k_base is used regardless of model)
        try:
            self.encoding = tiktoken.get_encoding("cl100k_base")
        except Exception as e:
            raise ValueError(f"Failed to get encoding: {e}")

    def call_model(self, prompt, max_tokens=4096):
        """
        Calls the configured LLM through litellm to get a completion for the
        given prompt and returns the generated text.

        Parameters:
            prompt (str): The prompt to send to the model.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4096.

        Returns:
            str: The text generated by the model.
        """
        model = self.model
        if self.azure:
            # Prefix the model name so litellm routes the call to Azure OpenAI
            model = 'azure/' + model
        messages = [{"role": "user", "content": prompt}]
        kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0,
            "force_timeout": 180,
            "api_base": self.api_base,
        }
        if self.aws_bedrock_client:
            kwargs["aws_bedrock_client"] = self.aws_bedrock_client
        if self.repetition_penalty:
            kwargs["repetition_penalty"] = self.repetition_penalty

        response = completion(**kwargs)
        full_text = response.choices[0].message.content
        return full_text.strip()

    def count_tokens(self, text):
        """
        Counts the number of tokens in the given text using the model's encoding.

        Parameters:
            text (str): The text to encode.

        Returns:
            int: The number of tokens.
        """
        try:
            return len(self.encoding.encode(text))
        except Exception as e:
            raise ValueError(f"Error encoding text: {e}")
2 changes: 1 addition & 1 deletion cover_agent/UnitTestGenerator.py
@@ -5,7 +5,7 @@
 from cover_agent.CoverageProcessor import CoverageProcessor
 from cover_agent.CustomLogger import CustomLogger
 from cover_agent.PromptBuilder import PromptBuilder
-from cover_agent.AICaller import AICaller
+from cover_agent.LiteLLMAICaller import LiteLLMAICaller as AICaller
 from cover_agent.FilePreprocessor import FilePreprocessor
 
 
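Because the new class is imported under the old name, the rest of UnitTestGenerator keeps working without further changes. A hypothetical call site, unchanged by this PR ("gpt-4o" is an illustrative placeholder):

# Existing call sites still refer to AICaller; they now get a LiteLLMAICaller.
ai_caller = AICaller(model="gpt-4o")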
7 changes: 6 additions & 1 deletion cover_agent/main.py
@@ -4,6 +4,7 @@
 from cover_agent.ReportGenerator import ReportGenerator
 from cover_agent.UnitTestGenerator import UnitTestGenerator
 from cover_agent.version import __version__
+from settings.config_loader import get_settings
 
 
 def parse_args():
@@ -140,8 +141,12 @@ def main():
logger.info(f"Desired Coverage: {test_gen.desired_coverage}%")

# Generate tests by making a call to the LLM
if get_settings().get("config.model", None):
model = get_settings().get("config.model")
else:
model = args.openai_model
generated_tests = test_gen.generate_tests(
LLM_model=args.openai_model, max_tokens=4096
LLM_model=model, max_tokens=4096
)

# Write test_gen.prompt to a debug markdown file
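The net effect in main(): a model set in the settings under config.model now takes precedence over the --openai-model CLI argument. A minimal sketch of that resolution order, assuming a Dynaconf-style settings object (resolve_model is a hypothetical helper, not part of the PR):

from settings.config_loader import get_settings

def resolve_model(cli_model):
    # Settings win over the CLI flag when both are set.
    configured = get_settings().get("config.model", None)
    return configured if configured else cli_model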