Skip to content

Commit

Permalink
Add LiteLLMAICaller class for AI model interactions, integrate settin…
Browse files Browse the repository at this point in the history
…gs with config loader, and update dependencies
  • Loading branch information
mrT23 committed May 21, 2024
1 parent 6778dd5 commit ceffd76
Show file tree
Hide file tree
Showing 7 changed files with 1,021 additions and 112 deletions.
115 changes: 115 additions & 0 deletions cover_agent/LiteLLMAICaller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import tiktoken
from litellm import completion
import openai
import litellm

from settings.config_loader import get_settings


class LiteLLMAICaller:
def __init__(self, model):
self.azure = False
self.aws_bedrock_client = None
self.api_base = None
self.repetition_penalty = None
if get_settings().get("OPENAI_API_KEY", None):
openai.api_key = get_settings().OPENAI_API_KEY
litellm.openai_key = get_settings().OPENAI_API_KEY
if get_settings().get("LITELLM.DROP_PARAMS", None):
litellm.drop_params = get_settings().litellm.drop_params
if get_settings().get("OPENAI.ORG", None):
litellm.organization = get_settings().openai.org
if get_settings().get("OPENAI.API_TYPE", None):
if get_settings().openai.api_type == "azure":
self.azure = True
litellm.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None):
litellm.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
litellm.api_base = get_settings().openai.api_base
if get_settings().get("ANTHROPIC.KEY", None):
litellm.anthropic_key = get_settings().anthropic.key
if get_settings().get("COHERE.KEY", None):
litellm.cohere_key = get_settings().cohere.key
if get_settings().get("GROQ.KEY", None):
litellm.api_key = get_settings().groq.key
if get_settings().get("REPLICATE.KEY", None):
litellm.replicate_key = get_settings().replicate.key
if get_settings().get("HUGGINGFACE.KEY", None):
litellm.huggingface_key = get_settings().huggingface.key
if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("OLLAMA.API_BASE", None):
litellm.api_base = get_settings().ollama.api_base
self.api_base = get_settings().ollama.api_base
if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):
self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
litellm.vertex_project = get_settings().vertexai.vertex_project
litellm.vertex_location = get_settings().get(
"VERTEXAI.VERTEX_LOCATION", None
)
if get_settings().get("AWS.BEDROCK_REGION", None):
litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
litellm.AmazonAnthropicClaude3Config.max_tokens = 2000
self.aws_bedrock_client = boto3.client(
service_name="bedrock-runtime",
region_name=get_settings().aws.bedrock_region,
)

self.model = model

# Initialize the encoding for the model
try:
self.encoding = tiktoken.get_encoding("cl100k_base")
except Exception as e:
raise ValueError(f"Failed to get encoding: {e}")

def call_model(self, prompt, max_tokens=4096):
"""
Calls the OpenAI API with streaming enabled to get completions for a given prompt
and streams the output to the console in real time, while also accumulating the response to return.
Parameters:
prompt (str): The prompt to send to the model.
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4096.
Returns:
str: The text generated by the model.
"""

model = self.model
if self.azure:
model = 'azure/' + model
messages = [{"role": "user", "content": prompt}]
kwargs = {
"model": model,
"messages": messages,
"temperature": 0,
"force_timeout": 180,
"api_base": self.api_base,
}
if self.aws_bedrock_client:
kwargs["aws_bedrock_client"] = self.aws_bedrock_client
if self.repetition_penalty:
kwargs["repetition_penalty"] = self.repetition_penalty

response = completion(**kwargs)
full_text = response.choices[0].model_extra['message']['content']
return full_text.strip()

def count_tokens(self, text):
"""
Counts the number of tokens in the given text using the model's encoding.
Parameters:
text (str): The text to encode.
Returns:
int: The number of tokens.
"""
try:
return len(self.encoding.encode(text))
except Exception as e:
raise ValueError(f"Error encoding text: {e}")
2 changes: 1 addition & 1 deletion cover_agent/UnitTestGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cover_agent.CoverageProcessor import CoverageProcessor
from cover_agent.CustomLogger import CustomLogger
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.AICaller import AICaller
from cover_agent.LiteLLMAICaller import LiteLLMAICaller as AICaller
from cover_agent.FilePreprocessor import FilePreprocessor


Expand Down
7 changes: 6 additions & 1 deletion cover_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cover_agent.ReportGenerator import ReportGenerator
from cover_agent.UnitTestGenerator import UnitTestGenerator
from cover_agent.version import __version__
from settings.config_loader import get_settings


def parse_args():
Expand Down Expand Up @@ -140,8 +141,12 @@ def main():
logger.info(f"Desired Coverage: {test_gen.desired_coverage}%")

# Generate tests by making a call to the LLM
if get_settings().get("config.model", None):
model = get_settings().get("config.model")
else:
model = args.openai_model
generated_tests = test_gen.generate_tests(
LLM_model=args.openai_model, max_tokens=4096
LLM_model=model, max_tokens=4096
)

# Write test_gen.prompt to a debug markdown file
Expand Down
Loading

0 comments on commit ceffd76

Please sign in to comment.