Add Gemini as an Adapter #419
base: dev
Conversation
Walkthrough

The pull request introduces significant enhancements to the embedding and language model infrastructure. Key modifications include improved configurability of the embedding engine and the addition of a Gemini adapter.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant EmbeddingEngine
    participant LLMProvider
    participant VectorDatabase
    Client->>EmbeddingEngine: Initialize with provider
    EmbeddingEngine->>LLMProvider: Validate configuration
    LLMProvider-->>EmbeddingEngine: Return config details
    Client->>EmbeddingEngine: Embed text
    EmbeddingEngine->>LLMProvider: Request embeddings
    LLMProvider-->>EmbeddingEngine: Return embeddings
    EmbeddingEngine->>VectorDatabase: Store embeddings
```
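The diagrammed flow, expressed as a minimal Python sketch. This is an illustration only: every name here is a placeholder standing in for cognee's actual API, and the method signatures are assumptions.

```python
# Hypothetical sketch of the flow in the sequence diagram above;
# all names and signatures are placeholders, not cognee's actual API.
async def embed_and_store(engine, vector_db, text: str) -> None:
    # Initialization step: the engine validates its provider configuration.
    await engine.validate_configuration()

    # Embedding step: the engine forwards the request to the LLM provider.
    vectors = await engine.embed_text([text])

    # Storage step: the resulting vectors are persisted in the vector database.
    await vector_db.store("documents", vectors)
```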
Actionable comments posted: 3
🧹 Nitpick comments (6)
cognee/infrastructure/llm/gemini/adapter.py (2)

115-120: Avoid catching broad `Exception`; catch specific exceptions instead

Catching the broad `Exception` can mask unexpected errors and make debugging difficult. It's advisable to catch specific exceptions that may occur during the `acompletion` call, or to re-raise the exception after logging. Consider updating the exception handling within the retry loop:

```diff
 for attempt in range(3):
     try:
         response = await acompletion(
             model=f"gemini/{self.model}",
             messages=messages,
             api_key=self.api_key,
             max_tokens=self.MAX_TOKENS,
             temperature=0.1,
             response_format={
                 "type": "json_object",
                 "schema": response_schema
             }
         )
         if response.choices and response.choices[0].message.content:
             content = response.choices[0].message.content
             return response_model.model_validate_json(content)
-    except Exception as e:
+    except (SpecificExceptionType1, SpecificExceptionType2) as e:
         if attempt == 2:  # Last attempt
             raise
         logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
         continue
```

Replace `SpecificExceptionType1` and `SpecificExceptionType2` with the actual exceptions you expect.
127-129: Avoid catching broad `Exception` in top-level exception handling

At the top-level exception handling, catching a broad `Exception` can hide other errors. It's better to catch specific exceptions, or to re-raise the exception after logging, to avoid masking unexpected issues. Consider refining the exception handling:

```diff
 except JSONSchemaValidationError as e:
     logger.error(f"Schema validation failed: {str(e)}")
     logger.debug(f"Raw response: {e.raw_response}")
     raise ValueError(f"Response failed schema validation: {str(e)}")
-except Exception as e:
+except SpecificExceptionType as e:
     logger.error(f"Error in structured output generation: {str(e)}")
     raise ValueError(f"Failed to generate structured output: {str(e)}")
+except Exception as e:
+    logger.error(f"Unexpected error: {str(e)}")
+    raise
```

Replace `SpecificExceptionType` with the actual exception you expect. This approach allows unexpected exceptions to be re-raised after logging.

cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1)
16-19: Consider adding type hints for the engine_args dictionary

While the code is functionally correct, adding type hints would improve code maintainability and IDE support.

```diff
-    engine_args = {
+    engine_args: dict[str, str | int] = {
```

cognee/infrastructure/llm/get_llm_client.py (2)
15-15: Fix enum value formatting

The enum value formatting is inconsistent with the other values; there should be spaces around the equals sign.

```diff
-    GEMINI="gemini"
+    GEMINI = "gemini"
```
64-72: Consider adding additional configuration options for Gemini

The Gemini adapter initialization is simpler compared to other providers (e.g., OpenAI), which support additional parameters like endpoint, api_version, and streaming. Consider whether these options should also be supported for consistency.

```diff
 return GeminiAdapter(
     api_key=llm_config.llm_api_key,
-    model=llm_config.llm_model
+    model=llm_config.llm_model,
+    endpoint=llm_config.llm_endpoint,
+    api_version=llm_config.llm_api_version,
+    streaming=llm_config.llm_streaming
 )
```

cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)
266-266: Consider adding logging for temporary storage cleanup

Adding logging would help track when temporary storage is cleaned up.

```diff
+import logging
+logging.info(f"Cleaning up temporary storage at {self.url}")
 LocalStorage.remove_all(self.url)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)

- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py (3 hunks)
- cognee/infrastructure/databases/vector/embeddings/config.py (1 hunks)
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1 hunks)
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1 hunks)
- cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
- cognee/infrastructure/llm/get_llm_client.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py

78-78: Local variable `api_base` is assigned to but never used. Remove assignment to unused variable `api_base`. (F841)
🔇 Additional comments (3)
cognee/infrastructure/databases/vector/embeddings/config.py (1)

7-8: Configuration defaults set to `None` for flexibility

Setting `embedding_model` and `embedding_dimensions` to `None` allows these values to be determined based on the provider configurations, enhancing flexibility.

cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (2)
11-13: LGTM! Good separation of provider-specific configurations.

The introduction of provider-specific configurations improves modularity and makes it easier to add new providers like Gemini.

27-30: Verify provider-specific configurations.

The code assumes that provider-specific configurations are properly defined in LiteLLMEmbeddingEngine.PROVIDER_CONFIGS.
✅ Verification successful
Provider-specific configurations are properly defined
The PROVIDER_CONFIGS dictionary in LiteLLMEmbeddingEngine correctly defines configurations for both OpenAI and Gemini providers, including model names, dimensions, and API base URLs.
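For orientation only, a minimal sketch of what such a mapping could look like, assuming the keys implied above (model names, dimensions, API base URLs). The concrete key names and values here are illustrative assumptions, not copied from the PR:

```python
# Hypothetical illustration of a provider config mapping; the real
# PROVIDER_CONFIGS in LiteLLMEmbeddingEngine may use different keys and values.
PROVIDER_CONFIGS = {
    "openai": {
        "model": "text-embedding-3-large",  # assumed model name
        "dimensions": 3072,                 # assumed dimension count
        "api_base": "https://api.openai.com/v1",
    },
    "gemini": {
        "model": "text-embedding-004",      # assumed model name
        "dimensions": 768,                  # assumed dimension count
        "api_base": "https://generativelanguage.googleapis.com/v1beta",
    },
}
```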
🏁 Scripts executed

The following scripts were executed for the analysis:

```bash
#!/bin/bash
# Description: Check if Gemini provider configs are properly defined
ast-grep --pattern 'PROVIDER_CONFIGS = { $$$ "gemini": { $$$ } $$$ }'
```

Length of output: 83

```bash
#!/bin/bash
# Find the LiteLLMEmbeddingEngine class definition
rg "class LiteLLMEmbeddingEngine" -l
# Then examine its content for PROVIDER_CONFIGS
rg "PROVIDER_CONFIGS\s*=" -A 10
```

Length of output: 1350
```python
if self.provider == "gemini":
    model_name = f"gemini/{self.model}"
    # For Gemini, we need to ensure we're using their specific endpoint format
    api_base = f"{self.endpoint}/models/{self.model}:embedContent"
else:
    model_name = self.model
    api_base = self.endpoint
```
Fix the usage of `api_base` to ensure the correct API endpoint is used.

The variable `api_base` is assigned but not used in the `litellm.aembedding` call. This may lead to incorrect API endpoints being used, especially for the 'gemini' provider. Update the function call to use the `api_base` variable so the correct endpoint is utilized. Apply this diff to fix the issue:
```diff
 response = await litellm.aembedding(
     model=model_name,
     input=text,
     api_key=self.api_key,
-    api_base=self.endpoint,
+    api_base=api_base,
     api_version=self.api_version,
 )
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if self.provider == "gemini":
    model_name = f"gemini/{self.model}"
    # For Gemini, we need to ensure we're using their specific endpoint format
    api_base = f"{self.endpoint}/models/{self.model}:embedContent"
else:
    model_name = self.model
    api_base = self.endpoint

response = await litellm.aembedding(
    model=model_name,
    input=text,
    api_key=self.api_key,
    api_base=api_base,
    api_version=self.api_version,
)
```
🧰 Tools

🪛 Ruff (0.8.2): 78-78: Local variable `api_base` is assigned to but never used. Remove assignment to unused variable `api_base`. (F841)
```python
connection = await self.get_connection()
collection_names = await connection.table_names()

# Delete all collections
for collection_name in collection_names:
    collection = await connection.open_table(collection_name)
    # Delete all records in the collection
    await collection.delete("id IS NOT NULL")
    # Drop the collection itself
    await connection.drop_table(collection_name)
```
🛠️ Refactor suggestion
Add error handling and transaction support to the prune method.
The prune method could be more robust with:
- Error handling for failed deletions
- Transaction support to ensure atomicity
- Progress tracking for large databases
```diff
 async def prune(self):
     connection = await self.get_connection()
     collection_names = await connection.table_names()
+    failed_collections = []
     # Delete all collections
-    for collection_name in collection_names:
-        collection = await connection.open_table(collection_name)
-        # Delete all records in the collection
-        await collection.delete("id IS NOT NULL")
-        # Drop the collection itself
-        await connection.drop_table(collection_name)
+    try:
+        for collection_name in collection_names:
+            try:
+                collection = await connection.open_table(collection_name)
+                # Delete all records in the collection
+                await collection.delete("id IS NOT NULL")
+                # Drop the collection itself
+                await connection.drop_table(collection_name)
+            except Exception as e:
+                failed_collections.append((collection_name, str(e)))
+
+        if failed_collections:
+            raise Exception(f"Failed to delete collections: {failed_collections}")
+    except Exception as e:
+        raise Exception(f"Database pruning failed: {str(e)}")
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
connection = await self.get_connection()
collection_names = await connection.table_names()
failed_collections = []

# Delete all collections
try:
    for collection_name in collection_names:
        try:
            collection = await connection.open_table(collection_name)
            # Delete all records in the collection
            await collection.delete("id IS NOT NULL")
            # Drop the collection itself
            await connection.drop_table(collection_name)
        except Exception as e:
            failed_collections.append((collection_name, str(e)))

    if failed_collections:
        raise Exception(f"Failed to delete collections: {failed_collections}")
except Exception as e:
    raise Exception(f"Database pruning failed: {str(e)}")
```
```diff
 connection = await self.get_connection()
 collection_names = await connection.table_names()

 # Delete all collections
 for collection_name in collection_names:
     collection = await connection.open_table(collection_name)
     # Delete all records in the collection
     await collection.delete("id IS NOT NULL")
     # Drop the collection itself
     await connection.drop_table(collection_name)

 # Clean up the database if it was set up as temporary
 if self.url.startswith("/"):
-    LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
+    LocalStorage.remove_all(self.url)
```
💡 Codebase verification
The prune operation has potential concurrency issues:
- Uses a shared connection instance without synchronization
- No transaction isolation for the multi-step deletion process
- Risk of race conditions if other operations access collections during pruning
Add documentation warning that:
- Prune should not be called while other operations are in progress
- Consider implementing application-level locks before calling prune (see the sketch after this list)
- Best used during maintenance windows or system initialization
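A minimal sketch of such an application-level lock, assuming Python's `asyncio.Lock`; the wrapper class and attribute names here are hypothetical and not part of this PR:

```python
import asyncio

class SerializedPruneAdapter:
    """Hypothetical wrapper that serializes prune calls on an adapter."""

    def __init__(self, adapter):
        self.adapter = adapter
        self._prune_lock = asyncio.Lock()  # guards the multi-step deletion

    async def prune(self):
        # Only one prune runs at a time; concurrent callers wait here,
        # which avoids racing the open_table/delete/drop_table sequence.
        async with self._prune_lock:
            await self.adapter.prune()
```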
🔗 Analysis chain
Verify the impact of pruning on concurrent operations.
The prune operation might affect concurrent database operations. Consider adding documentation about potential impacts and best practices for when to use this method.
🏁 Scripts executed

The following scripts were executed for the analysis:

```bash
#!/bin/bash
# Description: Check for concurrent database operations in the codebase
rg -A 5 "open_table|create_table|delete|merge_insert"
```

Length of output: 47016

```bash
#!/bin/bash
# Search for connection management, pooling, and transaction patterns
rg -A 5 "connection_pool|transaction|lock|concurrent|semaphore|mutex" cognee/infrastructure/databases/vector/lancedb/

# Look for any LanceDB configuration or setup code
ast-grep --pattern 'class LanceDBAdapter {
  $$$
  __init__($$$) {
    $$$
  }
  $$$
}'
```

Length of output: 209

```bash
#!/bin/bash
# Search for connection handling in LanceDBAdapter
rg -A 10 "get_connection|connect|Connection" cognee/infrastructure/databases/vector/lancedb/

# Look for any async connection management
ast-grep --pattern 'async def get_connection($$$) {
  $$$
}'
```

Length of output: 13644
Actionable comments posted: 3
♻️ Duplicate comments (1)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)

254-261: ⚠️ Potential issue: Add error handling and transaction support.

The database operations need proper error handling and transaction support to ensure data consistency.

```diff
 async def prune(self):
     connection = await self.get_connection()
     collection_names = await connection.table_names()
+    failed_collections = []

-    for collection_name in collection_names:
-        collection = await connection.open_table(collection_name)
-        await collection.delete("id IS NOT NULL")
-        await connection.drop_table(collection_name)
+    try:
+        for collection_name in collection_names:
+            try:
+                collection = await connection.open_table(collection_name)
+                # Delete all records in the collection
+                await collection.delete("id IS NOT NULL")
+                # Drop the collection itself
+                await connection.drop_table(collection_name)
+            except Exception as e:
+                failed_collections.append((collection_name, str(e)))
+                logging.error(f"Failed to delete collection {collection_name}: {str(e)}")
+
+        if failed_collections:
+            raise Exception(f"Failed to delete collections: {failed_collections}")
+    except Exception as e:
+        raise Exception(f"Database pruning failed: {str(e)}")
```
🧹 Nitpick comments (5)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)

263-264: Add error handling for local storage cleanup.

The local storage cleanup needs error handling to catch potential file system errors.

```diff
-    logging.info(f"Cleaning up temporary storage at {self.url}")
-    LocalStorage.remove_all(self.url)
+    try:
+        logging.info(f"Cleaning up temporary storage at {self.url}")
+        LocalStorage.remove_all(self.url)
+    except Exception as e:
+        logging.error(f"Failed to clean up temporary storage at {self.url}: {str(e)}")
+        # Don't raise as this is a cleanup operation
```

cognee/infrastructure/llm/gemini/adapter.py (4)
11-26: Add class and method documentation.

Please add docstrings to improve code documentation:

- a class-level docstring explaining the purpose and usage of `GeminiAdapter`
- a constructor docstring describing the parameters and their purposes

```diff
 class GeminiAdapter:
+    """Adapter for interacting with Google's Gemini API to generate structured outputs.
+
+    This adapter provides functionality to generate structured JSON responses using
+    the Gemini language model while ensuring schema validation and error handling.
+    """
     MAX_TOKENS = 8192

     def __init__(self,
                  api_key: str,
                  model: str,
                  endpoint: Optional[str] = None,
                  api_version: Optional[str] = None,
                  streaming: bool = False
                  ) -> None:
+        """Initialize the GeminiAdapter.
+
+        Args:
+            api_key (str): The API key for authentication with Gemini API
+            model (str): The specific Gemini model to use
+            endpoint (Optional[str]): Custom API endpoint if different from default
+            api_version (Optional[str]): API version to use
+            streaming (bool): Whether to enable streaming responses
+        """
```
34-67: Extract the response schema as a class constant.

The response schema should be defined as a class constant to improve reusability and maintainability. This would also make it easier to test and modify the schema independently.

```diff
 class GeminiAdapter:
     MAX_TOKENS = 8192
+    RESPONSE_SCHEMA = {
+        "type": "object",
+        "properties": {
+            "summary": {"type": "string"},
+            "description": {"type": "string"},
+            "nodes": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {"type": "string"},
+                        "type": {"type": "string"},
+                        "description": {"type": "string"},
+                        "id": {"type": "string"},
+                        "label": {"type": "string"}
+                    },
+                    "required": ["name", "type", "description", "id", "label"]
+                }
+            },
+            "edges": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "source_node_id": {"type": "string"},
+                        "target_node_id": {"type": "string"},
+                        "relationship_name": {"type": "string"}
+                    },
+                    "required": ["source_node_id", "target_node_id", "relationship_name"]
+                }
+            }
+        },
+        "required": ["summary", "description", "nodes", "edges"]
+    }
```
105-127: Make retry configuration more flexible.

The retry count and temperature are hardcoded. Consider making these configurable:

- add retry configuration to the constructor
- allow temperature to be customized per call

```diff
 def __init__(self,
              api_key: str,
              model: str,
              endpoint: Optional[str] = None,
              api_version: Optional[str] = None,
-             streaming: bool = False
+             streaming: bool = False,
+             max_retries: int = 3,
+             default_temperature: float = 0.1
              ) -> None:
     self.api_key = api_key
     self.model = model
     self.endpoint = endpoint
     self.api_version = api_version
     self.streaming = streaming
+    self.max_retries = max_retries
+    self.default_temperature = default_temperature

 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
-    response_model: Type[BaseModel]
+    response_model: Type[BaseModel],
+    temperature: Optional[float] = None
 ) -> BaseModel:
```
27-32: Add method documentation.

The `acreate_structured_output` method needs proper documentation explaining its parameters, return type, and possible exceptions.

```diff
 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
     response_model: Type[BaseModel]
 ) -> BaseModel:
+    """Generate structured output from text input using Gemini.
+
+    Args:
+        text_input (str): The input text to process
+        system_prompt (str): The system prompt to guide the model
+        response_model (Type[BaseModel]): Pydantic model for response validation
+
+    Returns:
+        BaseModel: Validated structured output matching the response_model
+
+    Raises:
+        ValueError: If schema validation fails or API errors occur
+    """
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1 hunks)
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (2 hunks)
- cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
- cognee/infrastructure/llm/get_llm_client.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- cognee/infrastructure/llm/get_llm_client.py
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
🧰 Additional context used
🪛 Ruff (0.8.2)
cognee/infrastructure/llm/gemini/adapter.py

7-7: Redefinition of unused `acompletion` from line 5. Remove definition: `acompletion`. (F811)

7-7: Redefinition of unused `JSONSchemaValidationError` from line 5. Remove definition: `JSONSchemaValidationError`. (F811)
🔇 Additional comments (3)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (2)

5-5: LGTM!

The logging import is correctly placed with other standard library imports.

254-264: ⚠️ Add protection against concurrent operations.

The prune operation needs protection against concurrent access to prevent data corruption. Run this script to check for concurrent database operations:

```bash
#!/bin/bash
# Description: Check for concurrent database operations
# Look for connection pooling or locking mechanisms
rg -A 5 "connection_pool|lock|concurrent|semaphore|mutex" cognee/infrastructure/databases/vector/lancedb/

# Check for other database operations that might conflict
ast-grep --pattern 'async def $method($$$) {
  $$$
  connection = await self.get_connection()
  $$$
}'
```

cognee/infrastructure/llm/gemini/adapter.py (1)
1-140: Overall implementation looks good!

The `GeminiAdapter` implementation is well-structured, with proper error handling and validation. The suggested improvements will make it more maintainable and configurable, but the current implementation is already functional and robust.

🧰 Tools

🪛 Ruff (0.8.2): 7-7: Redefinition of unused `acompletion` from line 5. Remove definition: `acompletion`. (F811); 7-7: Redefinition of unused `JSONSchemaValidationError` from line 5. Remove definition: `JSONSchemaValidationError`. (F811)
```diff
 connection = await self.get_connection()
 collection_names = await connection.table_names()

 for collection_name in collection_names:
     collection = await connection.open_table(collection_name)
     await collection.delete("id IS NOT NULL")
     await connection.drop_table(collection_name)

 if self.url.startswith("/"):
-    LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
+    logging.info(f"Cleaning up temporary storage at {self.url}")
+    LocalStorage.remove_all(self.url)
```
🛠️ Refactor suggestion
Add documentation about usage guidelines.
Please add a docstring with:
- Warning about concurrent operations
- Best practices for when to use this method
- Impact on running operations
```diff
 async def prune(self):
+    """Prunes all collections and their data from the database.
+
+    WARNING: This is a destructive operation that:
+    - Deletes all collections and their data
+    - Cannot be undone
+    - Should not be called while other operations are in progress
+
+    Best practices:
+    - Use during maintenance windows
+    - Ensure no concurrent operations
+    - Consider implementing application-level locks
+
+    Raises:
+        Exception: If any collection deletion fails
+    """
     connection = await self.get_connection()
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
"""Prunes all collections and their data from the database.

WARNING: This is a destructive operation that:
- Deletes all collections and their data
- Cannot be undone
- Should not be called while other operations are in progress

Best practices:
- Use during maintenance windows
- Ensure no concurrent operations
- Consider implementing application-level locks

Raises:
    Exception: If any collection deletion fails
"""
connection = await self.get_connection()
collection_names = await connection.table_names()

for collection_name in collection_names:
    collection = await connection.open_table(collection_name)
    await collection.delete("id IS NOT NULL")
    await connection.drop_table(collection_name)

if self.url.startswith("/"):
    logging.info(f"Cleaning up temporary storage at {self.url}")
    LocalStorage.remove_all(self.url)
```
```python
from litellm import acompletion, JSONSchemaValidationError
import litellm
from litellm import acompletion, JSONSchemaValidationError
```
Remove duplicate imports.

The following imports are duplicated: `acompletion` and `JSONSchemaValidationError`.

```diff
 from litellm import acompletion, JSONSchemaValidationError
 import litellm
-from litellm import acompletion, JSONSchemaValidationError
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
from litellm import acompletion, JSONSchemaValidationError
import litellm
```
🧰 Tools

🪛 Ruff (0.8.2): 7-7: Redefinition of unused `acompletion` from line 5. Remove definition: `acompletion`. (F811); 7-7: Redefinition of unused `JSONSchemaValidationError` from line 5. Remove definition: `JSONSchemaValidationError`. (F811)
```python
except (litellm.exceptions.OpenAIError, litellm.exceptions.BadRequestError) as e:
    if attempt == 2:
        raise
    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
    continue
```
🛠️ Refactor suggestion
Improve error handling in the retry loop.

The current error handling in the retry loop catches both `OpenAIError` and `BadRequestError`, but only logs a warning. Consider:

- differentiating between retryable and non-retryable errors
- adding exponential backoff between retries
```diff
-except (litellm.exceptions.OpenAIError, litellm.exceptions.BadRequestError) as e:
+except litellm.exceptions.OpenAIError as e:
     if attempt == 2:
         raise
-    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
+    backoff_time = (2 ** attempt) * 0.1  # exponential backoff: 0.1s, 0.2s, 0.4s
+    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {backoff_time}s")
+    await asyncio.sleep(backoff_time)
     continue
+except litellm.exceptions.BadRequestError as e:
+    # Don't retry on bad requests as they're likely due to invalid input
+    logger.error(f"Bad request error: {str(e)}")
+    raise ValueError(f"Invalid request: {str(e)}")
```

Don't forget to add the import:

```diff
+import asyncio
```
Committable suggestion skipped: line range outside the PR's diff.
Actionable comments posted: 1
🧹 Nitpick comments (3)
cognee/infrastructure/llm/gemini/adapter.py (3)

29-34: Add method documentation.

Please add a docstring to document the method's purpose, parameters, return type, and possible exceptions.

```diff
 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
     response_model: Type[BaseModel]
 ) -> BaseModel:
+    """Generate structured output from text input using Gemini.
+
+    Args:
+        text_input (str): The input text to process
+        system_prompt (str): The system prompt to guide the model
+        response_model (Type[BaseModel]): Pydantic model for response validation
+
+    Returns:
+        BaseModel: Validated structured output
+
+    Raises:
+        ValueError: If the response fails validation or generation fails
+    """
```
114-114: Make temperature configurable.

The temperature value is hardcoded to 0.1. Consider making it configurable through the constructor or a method parameter for more flexibility.

```diff
+def __init__(self,
+             api_key: str,
+             model: str,
+             endpoint: Optional[str] = None,
+             api_version: Optional[str] = None,
+             streaming: bool = False,
+             temperature: float = 0.1
+             ) -> None:
+    self.temperature = temperature
```

Then update the `acompletion` call:

```diff
-    temperature=0.1,
+    temperature=self.temperature,
```
128-128: Adjust the backoff time multiplier.

The current backoff multiplier of 1 second might be too aggressive. Consider using a smaller multiplier, such as 0.1 seconds, for faster retries.

```diff
-    backoff_time = (2 ** attempt) * 1
+    backoff_time = (2 ** attempt) * 0.1  # 0.1s, 0.2s, 0.4s
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)

- cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
🔇 Additional comments (2)
cognee/infrastructure/llm/gemini/adapter.py (2)

5-7: Remove duplicate imports.

The following imports are duplicated: `acompletion` and `JSONSchemaValidationError`.

```diff
 from litellm import acompletion, JSONSchemaValidationError
 import litellm
-from litellm import acompletion, JSONSchemaValidationError
```
13-28: Well-structured class definition and constructor!

The class is well-designed, with:

- an appropriate MAX_TOKENS constant for Gemini
- well-typed constructor parameters
- a clean implementation
```python
except litellm.exceptions.OpenAIError as e:
    logger.error(f"API error in structured output generation: {str(e)}")
    raise ValueError(f"Failed to generate structured output: {str(e)}")
```
🛠️ Refactor suggestion
Remove redundant error handling.

This catch block for `OpenAIError` is redundant, as it's already handled in the retry loop. Consider removing it, or handling specific error cases not covered by the retry mechanism.

```diff
-except litellm.exceptions.OpenAIError as e:
-    logger.error(f"API error in structured output generation: {str(e)}")
-    raise ValueError(f"Failed to generate structured output: {str(e)}")
```
Actionable comments posted: 0
🧹 Nitpick comments (4)
cognee/infrastructure/llm/gemini/adapter.py (4)

13-28: Add class and method documentation.

Add docstrings to describe the class purpose and the constructor parameters:

```diff
 class GeminiAdapter:
+    """Adapter for Google's Gemini LLM that provides structured output generation.
+
+    This adapter interfaces with Gemini through the LiteLLM library to generate
+    structured JSON responses based on input text and a system prompt.
+    """
     MAX_TOKENS = 8192

     def __init__(self,
                  api_key: str,
                  model: str,
                  endpoint: Optional[str] = None,
                  api_version: Optional[str] = None,
                  streaming: bool = False
                  ) -> None:
+        """Initialize the GeminiAdapter.
+
+        Args:
+            api_key (str): The API key for Gemini authentication
+            model (str): The specific Gemini model to use
+            endpoint (Optional[str]): Custom endpoint URL if needed
+            api_version (Optional[str]): API version to use
+            streaming (bool): Whether to enable streaming responses
+        """
         self.api_key = api_key
```
29-34: Add method documentation.

Add a docstring to describe the method's purpose, parameters, return type, and possible exceptions:

```diff
 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
     response_model: Type[BaseModel]
 ) -> BaseModel:
+    """Generate structured output from text input using Gemini.
+
+    Args:
+        text_input (str): The input text to process
+        system_prompt (str): The prompt that guides the model's behavior
+        response_model (Type[BaseModel]): Pydantic model for response validation
+
+    Returns:
+        BaseModel: Validated structured output matching the response_model
+
+    Raises:
+        ValueError: If the response fails schema validation or the request is invalid
+    """
```
107-135: Consider adding a request timeout.

The retry logic looks good with exponential backoff, but there's no timeout for the API request. Consider adding a timeout to prevent hanging on slow responses.

```diff
 response = await acompletion(
     model=f"gemini/{self.model}",
     messages=messages,
     api_key=self.api_key,
     max_tokens=self.MAX_TOKENS,
     temperature=0.1,
+    timeout=30,  # 30 seconds timeout
     response_format={
         "type": "json_object",
         "schema": response_schema
     }
 )
```
121-123: Add debug logging for response validation.

Consider adding debug logging before validation to help troubleshoot any validation failures:

```diff
 if response.choices and response.choices[0].message.content:
     content = response.choices[0].message.content
+    logger.debug(f"Validating response content: {content}")
     return response_model.model_validate_json(content)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)

- cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
🔇 Additional comments (2)
cognee/infrastructure/llm/gemini/adapter.py (2)

1-8: Organize imports and remove duplicates.

The imports should be organized into standard library and third-party groups, with a blank line separating them. Apply this diff to fix the imports:

```diff
 from typing import Type, Optional
-from pydantic import BaseModel
 import json
 import logging
+import asyncio
+
+from pydantic import BaseModel
 import litellm
-import asyncio
-from litellm import acompletion, JSONSchemaValidationError
+from litellm import acompletion, JSONSchemaValidationError
```
13-15: Verify the MAX_TOKENS limit for Gemini.

The MAX_TOKENS value of 8192 seems high. According to Google's documentation, different Gemini models have different token limits. Please verify this value against the specific model being used.

✅ Verification successful

The MAX_TOKENS value is appropriate for output limitation. The value of 8192 tokens is reasonable, as it's used specifically to limit the maximum number of output tokens in the completion call, not the model's context window. This is a sensible constraint that prevents excessive token generation while still allowing comprehensive responses.

🏁 Scripts executed

The following scripts were executed for the analysis:

```bash
#!/bin/bash
# Search for any documentation or configuration files that might specify token limits
rg -i "token.*limit|max.*token" --type yaml --type json
```

Length of output: 57

```bash
#!/bin/bash
# Search for Gemini model configurations and token-related settings in Python files
rg -i "gemini.*model|token.*limit|max.*token" --type py -A 2 -B 2

# Search for any Gemini-specific configuration or model initialization
ast-grep --pattern 'model = "$_"'
```

Length of output: 15419
Hi @holchan |
This pull request adds support for Gemini (Google AI Studio) to Cognee.
Summary by CodeRabbit
New Features
Bug Fixes
Improvements