-
Notifications
You must be signed in to change notification settings - Fork 376
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LiteLLMAICaller class for AI model interactions, integrate settings with config loader, and update dependencies
- Loading branch information
Showing
7 changed files
with
1,021 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import tiktoken | ||
from litellm import completion | ||
import openai | ||
import litellm | ||
|
||
from settings.config_loader import get_settings | ||
|
||
|
||
class LiteLLMAICaller:
    """Thin wrapper around litellm's `completion` API.

    On construction it reads provider credentials and endpoints from the
    application settings (OpenAI/Azure, Anthropic, Cohere, Groq, Replicate,
    HuggingFace, Ollama, Vertex AI, AWS Bedrock) and pushes them into the
    corresponding `litellm` module-level configuration attributes.
    """

    def __init__(self, model):
        """Configure litellm credentials from settings and set up tokenization.

        Parameters:
            model (str): The litellm model identifier to use for completions.
        Raises:
            ValueError: If the tiktoken encoding cannot be loaded.
        """
        self.azure = False
        self.aws_bedrock_client = None
        self.api_base = None
        self.repetition_penalty = None
        # Hoist the repeated get_settings() calls into one local lookup.
        settings = get_settings()
        if settings.get("OPENAI_API_KEY", None):
            openai.api_key = settings.OPENAI_API_KEY
            litellm.openai_key = settings.OPENAI_API_KEY
        if settings.get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = settings.litellm.drop_params
        if settings.get("OPENAI.ORG", None):
            litellm.organization = settings.openai.org
        if settings.get("OPENAI.API_TYPE", None):
            if settings.openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = settings.openai.key
        if settings.get("OPENAI.API_VERSION", None):
            litellm.api_version = settings.openai.api_version
        if settings.get("OPENAI.API_BASE", None):
            litellm.api_base = settings.openai.api_base
        if settings.get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = settings.anthropic.key
        if settings.get("COHERE.KEY", None):
            litellm.cohere_key = settings.cohere.key
        if settings.get("GROQ.KEY", None):
            # NOTE(review): Groq deliberately sets the generic litellm.api_key
            # (litellm routes "groq/..." models through it) — confirm intended.
            litellm.api_key = settings.groq.key
        if settings.get("REPLICATE.KEY", None):
            litellm.replicate_key = settings.replicate.key
        if settings.get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = settings.huggingface.key
        if settings.get("HUGGINGFACE.API_BASE", None) and 'huggingface' in settings.config.model:
            litellm.api_base = settings.huggingface.api_base
            self.api_base = settings.huggingface.api_base
        if settings.get("OLLAMA.API_BASE", None):
            litellm.api_base = settings.ollama.api_base
            self.api_base = settings.ollama.api_base
        # BUG FIX: the original checked the misspelled key
        # "HUGGINGFACE.REPITITION_PENALTY" but then read
        # settings.huggingface.repetition_penalty, so the feature failed with
        # either spelling. Accept both spellings for backward compatibility
        # and use whichever value is actually defined.
        penalty = (settings.get("HUGGINGFACE.REPETITION_PENALTY", None)
                   or settings.get("HUGGINGFACE.REPITITION_PENALTY", None))
        if penalty:
            self.repetition_penalty = float(penalty)
        if settings.get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = settings.vertexai.vertex_project
            litellm.vertex_location = settings.get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        if settings.get("AWS.BEDROCK_REGION", None):
            # BUG FIX: boto3 was used without ever being imported, so any
            # Bedrock configuration raised NameError. Import lazily here so
            # boto3 is only required when Bedrock is actually configured.
            import boto3
            litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
            litellm.AmazonAnthropicClaude3Config.max_tokens = 2000
            self.aws_bedrock_client = boto3.client(
                service_name="bedrock-runtime",
                region_name=settings.aws.bedrock_region,
            )

        self.model = model

        # cl100k_base is used as a generic token counter regardless of the
        # configured model; counts are therefore approximate for non-OpenAI
        # models.
        try:
            self.encoding = tiktoken.get_encoding("cl100k_base")
        except Exception as e:
            raise ValueError(f"Failed to get encoding: {e}")

    def call_model(self, prompt, max_tokens=4096):
        """
        Send a single-turn user prompt to the configured model and return
        the generated text. (This call is synchronous, not streamed.)

        Parameters:
            prompt (str): The prompt to send to the model.
            max_tokens (int, optional): The maximum number of tokens to
                generate. Defaults to 4096.
        Returns:
            str: The text generated by the model, stripped of surrounding
                whitespace.
        """

        model = self.model
        if self.azure:
            # litellm routes Azure deployments via the "azure/" prefix.
            model = 'azure/' + model
        messages = [{"role": "user", "content": prompt}]
        kwargs = {
            "model": model,
            "messages": messages,
            "temperature": 0,
            # BUG FIX: max_tokens was accepted but silently ignored; forward
            # it so the documented generation cap actually applies.
            "max_tokens": max_tokens,
            "force_timeout": 180,
            "api_base": self.api_base,
        }
        if self.aws_bedrock_client:
            kwargs["aws_bedrock_client"] = self.aws_bedrock_client
        if self.repetition_penalty:
            kwargs["repetition_penalty"] = self.repetition_penalty

        response = completion(**kwargs)
        full_text = response.choices[0].model_extra['message']['content']
        return full_text.strip()

    def count_tokens(self, text):
        """
        Counts the number of tokens in the given text using the model's encoding.

        Parameters:
            text (str): The text to encode.
        Returns:
            int: The number of tokens.
        Raises:
            ValueError: If the text cannot be encoded.
        """
        try:
            return len(self.encoding.encode(text))
        except Exception as e:
            raise ValueError(f"Error encoding text: {e}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.