Skip to content

Commit

Permalink
Merge pull request #36 from allegro/unsafe-prompt-handling
Browse files Browse the repository at this point in the history
Handle empty responses from VertexAI properly
  • Loading branch information
megatron6000 authored Aug 30, 2024
2 parents 28a785d + 58f8099 commit 8f43d92
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 2 additions & 1 deletion allms/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from allms.domain.input_data import InputData
from allms.domain.prompt_dto import SummaryOutputClass, KeywordsOutputClass
from allms.domain.response import ResponseData
from allms.models.vertexai_base import GCPInvalidRequestError
from allms.utils.long_text_processing_utils import get_max_allowed_number_of_tokens
from allms.utils.response_parsing_utils import ResponseParser

Expand Down Expand Up @@ -260,7 +261,7 @@ async def _predict_example(
model_response = None
error_message = f"{IODataConstants.ERROR_MESSAGE_STR}: {invalid_request_error}"

except (InvalidArgument, ValueError, TimeoutError, openai.error.Timeout) as other_error:
except (InvalidArgument, ValueError, TimeoutError, openai.error.Timeout, GCPInvalidRequestError) as other_error:
model_response = None
logger.info(f"Error for id {input_data.id} has occurred. Message: {other_error} ")
error_message = f"{type(other_error).__name__}: {other_error}"
Expand Down
7 changes: 7 additions & 0 deletions allms/models/vertexai_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from allms.constants.vertex_ai import VertexModelConstants


class GCPInvalidRequestError(Exception):
pass


class CustomVertexAI(VertexAI):
async def _agenerate(
self,
Expand All @@ -31,6 +35,9 @@ def was_response_blocked(generation: Generation) -> bool:
**kwargs
)

if not all(result.generations):
raise GCPInvalidRequestError("The response is empty. It may have been blocked due to content filtering.")

return LLMResult(
generations=(
chain(result.generations)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "allms"
version = "1.0.8"
version = "1.0.9"
description = ""
authors = ["Allegro Opensource <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 8f43d92

Please sign in to comment.