Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement: Support for HuggingFace Self-Served TGI Serving with Updated Dependencies and Improved AI Handler #472

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import litellm
import openai
from litellm import acompletion
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from openai import APIError, RateLimitError
from retry import retry
from tenacity import TryAgain

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

Expand All @@ -24,7 +26,7 @@ def __init__(self):
Raises a ValueError if the OpenAI key is missing.
"""
self.azure = False

self.api_base = None
if get_settings().get("OPENAI.KEY", None):
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
Expand Down Expand Up @@ -53,8 +55,9 @@ def __init__(self):
litellm.replicate_key = get_settings().replicate.key
if get_settings().get("HUGGINGFACE.KEY", None):
litellm.huggingface_key = get_settings().huggingface.key
if get_settings().get("HUGGINGFACE.API_BASE", None):
litellm.api_base = get_settings().huggingface.api_base
if get_settings().get("HUGGINGFACE.API_BASE", None):
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
litellm.vertex_project = get_settings().vertexai.vertex_project
litellm.vertex_location = get_settings().get(
Expand All @@ -68,22 +71,29 @@ def deployment_id(self):
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
@retry(
exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES,
delay=2,
backoff=2,
jitter=(1, 3),
)
Comment on lines +74 to +80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: It's a good practice to avoid using magic numbers directly in the code. The numbers used in the retry decorator in the chat_completion method (like 2, 1, 3) could be replaced with well-named constants to improve code readability and maintainability. [best practice]

Suggested change
@retry(
exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES,
delay=2,
backoff=2,
jitter=(1, 3),
)
DELAY = 2
BACKOFF = 2
JITTER = (1, 3)
@retry(
exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES,
delay=DELAY,
backoff=BACKOFF,
jitter=JITTER,
)

async def chat_completion(
self, model: str, system: str, user: str, temperature: float = 0.2
):
"""
Performs a chat completion using the OpenAI ChatCompletion API.
Retries in case of API errors or timeouts.

Args:
model (str): The model to use for chat completion.
temperature (float): The temperature parameter for chat completion.
system (str): The system message for chat completion.
user (str): The user message for chat completion.

Returns:
tuple: A tuple containing the response and finish reason from the API.

Raises:
TryAgain: If the API response is empty or there are no choices in the response.
APIError: If there is an error during OpenAI inference.
Expand All @@ -105,22 +115,29 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
deployment_id=deployment_id,
messages=messages,
temperature=temperature,
force_timeout=get_settings().config.ai_timeout
force_timeout=get_settings().config.ai_timeout,
api_base=self.api_base,
)
except (APIError, Timeout, TryAgain) as e:
except (APIError, TryAgain) as e:
mrT23 marked this conversation as resolved.
Show resolved Hide resolved
get_logger().error("Error during OpenAI inference: ", e)
raise
except (RateLimitError) as e:
except RateLimitError as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
except Exception as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
mrT23 marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines +139 to 141
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Consider using a more specific exception instead of the generic Exception in the chat_completion method. This will make the error handling more precise and prevent catching and handling unexpected exceptions that might need different handling logic. [best practice]

Suggested change
except Exception as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
except SpecificException as e: # Replace SpecificException with the specific exception you want to catch
get_logger().error("Specific error during OpenAI inference: ", e)
raise TryAgain from e

if response is None or len(response["choices"]) == 0:
raise TryAgain
resp = response["choices"][0]['message']['content']
resp = response["choices"][0]["message"]["content"]
finish_reason = response["choices"][0]["finish_reason"]
mrT23 marked this conversation as resolved.
Show resolved Hide resolved
usage = response.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
get_logger().info(
mrT23 marked this conversation as resolved.
Show resolved Hide resolved
"AI response",
response=resp,
messages=messages,
finish_reason=finish_reason,
model=model,
usage=usage,
)
return resp, finish_reason
2 changes: 1 addition & 1 deletion pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_user_labels(current_labels: List[str] = None):

def get_max_tokens(model):
settings = get_settings()
max_tokens_model = MAX_TOKENS[model]
max_tokens_model = MAX_TOKENS.get(model, 4000)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Consider adding a default value for MAX_TOKENS in the get_max_tokens function. This will make the function more robust in case the MAX_TOKENS dictionary doesn't contain the model as a key. [enhancement]

Suggested change
max_tokens_model = MAX_TOKENS.get(model, 4000)
DEFAULT_MAX_TOKENS = 4000
max_tokens_model = MAX_TOKENS.get(model, DEFAULT_MAX_TOKENS)

if settings.config.max_model_tokens:
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
# get_logger().debug(f"limiting max tokens to {max_tokens_model}")
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dynaconf==3.1.12
fastapi==0.99.0
PyGithub==1.59.*
retry==0.9.2
openai==0.27.8
openai==1.3.5
Jinja2==3.1.2
tiktoken==0.4.0
uvicorn==0.22.0
Expand All @@ -13,7 +13,6 @@ atlassian-python-api==3.39.0
GitPython==3.1.32
PyYAML==6.0
starlette-context==0.3.6
litellm==0.12.5
boto3==1.28.25
google-cloud-storage==2.10.0
ujson==5.8.0
Expand All @@ -23,3 +22,4 @@ pinecone-client
pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
loguru==0.7.2
google-cloud-aiplatform==1.35.0
litellm @ git+https://github.com/Codium-ai/litellm.git
Loading