Enhancement: Support for HuggingFace Self-Served TGI Serving with Updated Dependencies and Improved AI Handler #472
First file in this diff:
@@ -3,8 +3,10 @@
 import litellm
 import openai
 from litellm import acompletion
-from openai.error import APIError, RateLimitError, Timeout, TryAgain
+from openai import APIError, RateLimitError
 from retry import retry
+from tenacity import TryAgain

 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
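A note on the dependency update (not part of the diff): in openai>=1.0 the error classes are exposed at the package top level rather than under openai.error, and TryAgain is no longer shipped by openai at all, which is why it now comes from tenacity. A minimal import check under that assumption:

# Assumes openai>=1.0 and tenacity are installed; purely illustrative.
from openai import APIError, RateLimitError  # formerly openai.error.APIError / RateLimitError
from tenacity import TryAgain                # retry-signal exception replacing openai.error.TryAgain

print(APIError.__name__, RateLimitError.__name__, TryAgain.__name__)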
@@ -24,7 +26,7 @@ def __init__(self):
         Raises a ValueError if the OpenAI key is missing.
         """
         self.azure = False
-
+        self.api_base = None
         if get_settings().get("OPENAI.KEY", None):
             openai.api_key = get_settings().openai.key
             litellm.openai_key = get_settings().openai.key
@@ -53,8 +55,9 @@ def __init__(self):
             litellm.replicate_key = get_settings().replicate.key
         if get_settings().get("HUGGINGFACE.KEY", None):
             litellm.huggingface_key = get_settings().huggingface.key
         if get_settings().get("HUGGINGFACE.API_BASE", None):
             litellm.api_base = get_settings().huggingface.api_base
+            self.api_base = get_settings().huggingface.api_base
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
             litellm.vertex_project = get_settings().vertexai.vertex_project
             litellm.vertex_location = get_settings().get(
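Not part of the diff: a hedged sketch of how these settings would be exercised against a self-served TGI endpoint. Only the HUGGINGFACE.KEY, HUGGINGFACE.API_BASE, and OPENAI.KEY setting names come from the code above; the token values, endpoint URL, and the AiHandler module/class names are assumptions for illustration.

# Illustrative only - values and the AiHandler import path are assumed.
from pr_agent.config_loader import get_settings
from pr_agent.algo.ai_handler import AiHandler  # assumed module/class name

get_settings().set("OPENAI.KEY", "sk-dummy")                           # __init__ expects an OpenAI key per its docstring
get_settings().set("HUGGINGFACE.KEY", "hf_dummy_token")                # any non-empty token
get_settings().set("HUGGINGFACE.API_BASE", "http://localhost:8080")    # self-served TGI URL

handler = AiHandler()
print(handler.api_base)  # -> "http://localhost:8080", later passed to acompletion(api_base=...)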
@@ -68,22 +71,29 @@ def deployment_id(self):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

-    @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         """
         Performs a chat completion using the OpenAI ChatCompletion API.
         Retries in case of API errors or timeouts.

         Args:
             model (str): The model to use for chat completion.
             temperature (float): The temperature parameter for chat completion.
             system (str): The system message for chat completion.
             user (str): The user message for chat completion.

         Returns:
             tuple: A tuple containing the response and finish reason from the API.

         Raises:
             TryAgain: If the API response is empty or there are no choices in the response.
             APIError: If there is an error during OpenAI inference.
@@ -105,22 +115,29 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
                 deployment_id=deployment_id,
                 messages=messages,
                 temperature=temperature,
-                force_timeout=get_settings().config.ai_timeout
+                force_timeout=get_settings().config.ai_timeout,
+                api_base=self.api_base,
             )
-        except (APIError, Timeout, TryAgain) as e:
+        except (APIError, TryAgain) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
-        except (RateLimitError) as e:
+        except RateLimitError as e:
             get_logger().error("Rate limit error during OpenAI inference: ", e)
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise TryAgain from e
Review comment on lines +139 to 141 (resolved): Suggestion: Consider using a more specific exception instead of the generic one.
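The reviewer's suggested change is collapsed in this view; one possible reading of the suggestion, using an invented exception name, might look like this:

# Hypothetical sketch only - the exception name is invented for illustration.
class AiInferenceError(Exception):
    """Raised when an unexpected error occurs during model inference."""

# One place this could apply is the catch-all branch above, which currently
# re-raises the generic TryAgain:
#     except Exception as e:
#         get_logger().error("Unknown error during OpenAI inference: ", e)
#         raise AiInferenceError("unexpected inference failure") from e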
         if response is None or len(response["choices"]) == 0:
             raise TryAgain
-        resp = response["choices"][0]['message']['content']
+        resp = response["choices"][0]["message"]["content"]
         finish_reason = response["choices"][0]["finish_reason"]
         usage = response.get("usage")
-        get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                          model=model, usage=usage)
+        get_logger().info(
+            "AI response",
+            response=resp,
+            messages=messages,
+            finish_reason=finish_reason,
+            model=model,
+            usage=usage,
+        )
         return resp, finish_reason
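Not part of the diff: a hedged sketch of how a caller would exercise the updated handler end to end. The module path and class name AiHandler are assumptions, as is the model name; litellm routes the call to OpenAI or, when HUGGINGFACE.API_BASE is set, to the self-served TGI endpoint via api_base.

import asyncio

from pr_agent.algo.ai_handler import AiHandler  # assumed module/class name


async def main():
    # Assumes OPENAI.KEY (or another provider) is already configured in the settings.
    handler = AiHandler()
    resp, finish_reason = await handler.chat_completion(
        model="gpt-4",
        system="You are a helpful code reviewer.",
        user="Summarize the changes in this pull request.",
        temperature=0.2,
    )
    print(finish_reason, resp)


if __name__ == "__main__":
    asyncio.run(main())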
Second file in this diff:
@@ -370,7 +370,7 @@ def get_user_labels(current_labels: List[str] = None):

 def get_max_tokens(model):
     settings = get_settings()
-    max_tokens_model = MAX_TOKENS[model]
+    max_tokens_model = MAX_TOKENS.get(model, 4000)
Review comment: Suggestion: Consider adding a default value for the MAX_TOKENS lookup so that models missing from the table do not raise a KeyError.
     if settings.config.max_model_tokens:
         max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
     # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
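Not part of the diff: a small sketch of the new fallback behaviour. The model names and token counts below are illustrative, not the project's actual MAX_TOKENS table.

# Illustrative stand-in for the real MAX_TOKENS mapping.
MAX_TOKENS = {"gpt-4": 8000, "gpt-3.5-turbo-16k": 16000}

def get_max_tokens_demo(model: str) -> int:
    # Unknown models (e.g. a self-hosted TGI model) now fall back to 4000
    # instead of raising KeyError.
    return MAX_TOKENS.get(model, 4000)

assert get_max_tokens_demo("gpt-4") == 8000
assert get_max_tokens_demo("my-self-served-tgi-model") == 4000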
Review comment: Suggestion: It's a good practice to avoid using magic numbers directly in the code. The numbers used in the retry decorator in the chat_completion method (like 2, 1, 3) could be replaced with well-named constants to improve code readability and maintainability. [best practice]
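A hedged sketch of what this suggestion might look like in practice; the constant names are invented, OPENAI_RETRIES already exists in the handler module (its value is assumed here), and the function body is elided.

from openai import APIError, RateLimitError
from retry import retry
from tenacity import TryAgain

OPENAI_RETRIES = 5            # already defined in the handler module; value assumed here
RETRY_DELAY_SECONDS = 2       # invented names for the previously bare numbers
RETRY_BACKOFF_FACTOR = 2
RETRY_JITTER_RANGE = (1, 3)

@retry(
    exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
    tries=OPENAI_RETRIES,
    delay=RETRY_DELAY_SECONDS,
    backoff=RETRY_BACKOFF_FACTOR,
    jitter=RETRY_JITTER_RANGE,
)
async def chat_completion(model: str, system: str, user: str, temperature: float = 0.2):
    ...  # body as in the handler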