
Add Gemini as an Adapter #419

Open
wants to merge 13 commits into base: dev

Conversation

@holchan commented on Jan 9, 2025

This pull request adds support for Gemini (Google AI Studio) as an LLM provider in Cognee.

Summary by CodeRabbit

  • New Features

    • Added support for Gemini as a new LLM provider.
    • Introduced more flexible embedding engine configuration with provider-specific settings.
    • Added a method to prune all collections in the LanceDB database.
    • Introduced the GeminiAdapter class for structured output generation.
    • Expanded model selection options with the addition of "gemini" and its corresponding model.
  • Bug Fixes

    • Improved error handling for embedding and LLM operations.
    • Enhanced retry mechanisms for API calls.
  • Improvements

    • More configurable embedding and LLM adapters.
    • Added exponential backoff for API rate limit handling.
    • Enhanced logging for temporary storage cleanup.


coderabbitai bot commented Jan 9, 2025

Walkthrough

The pull request introduces significant enhancements to the embedding and language model infrastructure. Key modifications include improved configurability of the LiteLLMEmbeddingEngine, the addition of a new GeminiAdapter for structured output generation, and the integration of Gemini as a supported LLM provider. Updates feature more flexible configuration handling, enhanced error management, and a new prune method in the LanceDBAdapter for comprehensive database cleanup. These changes collectively aim to provide more robust and adaptable interactions within the embedding and language model framework.

Changes

File Change Summary
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py - Added MAX_RETRIES and PROVIDER_CONFIGS
- Updated constructor to support dynamic provider configuration
- Added exponential_backoff method
- Enhanced error handling in embed_text
cognee/infrastructure/databases/vector/embeddings/config.py - Changed default embedding_model to None
- Changed default embedding_dimensions to None
cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py - Improved configuration handling
- Added support for provider-specific settings
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py - Updated prune method to delete all collections and records
cognee/infrastructure/llm/gemini/adapter.py - Introduced GeminiAdapter class
- Added acreate_structured_output method for generating structured outputs
cognee/infrastructure/llm/get_llm_client.py - Added GEMINI to LLMProvider enum
- Updated client retrieval logic for Gemini provider (sketched below)
cognee/api/v1/settings/routers/get_settings_router.py - Expanded provider attribute to include "gemini"
cognee/modules/settings/get_settings.py - Added gemini to ModelName enum
- Updated llm_providers list with "Gemini"
- Added gemini model entry in LLMConfig
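
For orientation, the provider enum and client-routing rows above amount to roughly the following (a hedged sketch assembled from this table and the diffs further down in the review; the pre-existing enum members are inferred from the provider searches below, and get_llm_config is an assumed config accessor):

from enum import Enum

class LLMProvider(Enum):
    OPENAI = "openai"
    OLLAMA = "ollama"
    ANTHROPIC = "anthropic"
    GEMINI = "gemini"

def get_llm_client():
    # Assumed config accessor; the real module may expose it differently.
    llm_config = get_llm_config()
    provider = LLMProvider(llm_config.llm_provider)

    if provider == LLMProvider.GEMINI:
        # Constructor arguments follow the GeminiAdapter diff reviewed below.
        return GeminiAdapter(
            api_key=llm_config.llm_api_key,
            model=llm_config.llm_model,
        )
    # ... existing branches for the other providers remain unchanged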

Sequence Diagram

sequenceDiagram
    participant Client
    participant EmbeddingEngine
    participant LLMProvider
    participant VectorDatabase

    Client->>EmbeddingEngine: Initialize with provider
    EmbeddingEngine->>LLMProvider: Validate configuration
    LLMProvider-->>EmbeddingEngine: Return config details
    Client->>EmbeddingEngine: Embed text
    EmbeddingEngine->>LLMProvider: Request embeddings
    LLMProvider-->>EmbeddingEngine: Return embeddings
    EmbeddingEngine->>VectorDatabase: Store embeddings

Poem

🐰 Embedding Rabbit's Ballad 🌟
With providers dancing, configs so bright,
Gemini joins the embedding delight!
Retries and backoffs, a technical spree,
Flexible vectors, now setting us free!
Code hopping forward, with magical might! 🚀


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5ab023d and 5712e6e.

📒 Files selected for processing (2)
  • cognee/api/v1/settings/routers/get_settings_router.py (1 hunks)
  • cognee/modules/settings/get_settings.py (3 hunks)
🔇 Additional comments (5)
cognee/api/v1/settings/routers/get_settings_router.py (1)

24-24: LGTM! Provider type correctly updated.

The addition of "gemini" to the provider type union is consistent with the PR objective and maintains type safety.
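
For reference, the updated union on line 24 presumably reads roughly as follows (hedged; the pre-existing members are inferred from the provider searches below, not copied from the file):

from typing import Literal

provider: Literal["openai", "ollama", "anthropic", "gemini"]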

Let's verify that all provider-related code paths handle the new "gemini" provider:

✅ Verification successful

✅ Gemini provider integration is complete and consistent

The codebase shows proper implementation of the Gemini provider:

  • Dedicated Gemini adapter with correct model formatting
  • Embeddings support with provider-specific configuration
  • Consistent integration pattern with other providers
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for provider-related conditionals or matches to ensure complete Gemini support
rg -g '!*.{json,md,txt}' -A 3 'provider.*=.*(?:openai|ollama|anthropic)' 

Length of output: 1150


Script:

#!/bin/bash
# Search for provider handling in adapter and validation logic
echo "=== Searching for provider handling in adapter ==="
ast-grep --pattern 'if $_.provider == $_: $$$'

echo -e "\n=== Searching for provider validation or routing ==="
rg -g '!*.{json,md,txt}' -B 2 -A 4 'provider.*(?:in|==|match).*(?:openai|ollama|anthropic|gemini)'

Length of output: 2854


Script:

#!/bin/bash
# Search for LiteLLM configuration and model settings
echo "=== Searching for LiteLLM model configuration ==="
rg -g '!*.{json,md,txt}' -B 2 -A 4 'litellm\.(?:completion|acompletion).*model'

echo -e "\n=== Searching for model configuration settings ==="
rg -g '!*.{json,md,txt}' -B 2 -A 4 'model.*=.*gemini'

Length of output: 1916

cognee/modules/settings/get_settings.py (4)

16-16: LGTM! Enum value properly added.

The addition of the Gemini provider to the ModelName enum is consistent with the existing pattern.


76-79: LGTM! Provider configuration properly added.

The Gemini provider is correctly added to the llm_providers list with appropriate value and label.


144-149: Clarify model selection strategy.

Other providers (OpenAI, Anthropic) offer multiple model options, while Gemini only has one. Consider:

  1. Are there other Gemini models that should be supported?
  2. Should we maintain consistency with other providers' model selection approach?
  3. Is there a specific reason for the different versioning pattern ("2.0-flash-exp" vs others)?

Let's check the model naming patterns across providers:

✅ Verification successful

Model selection strategy is appropriate

The current implementation correctly follows each provider's official model naming conventions. Gemini's single model option ("gemini-2.0-flash-exp") is appropriate as it represents their latest offering, and the naming pattern aligns with Google's official model versioning. No changes needed.
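
Concretely, the Gemini entry under discussion presumably mirrors the value/label structure used for the other providers' model lists (hedged sketch; the label text is an assumption):

"gemini": [
    {
        "value": "gemini-2.0-flash-exp",
        "label": "Gemini 2.0 Flash (Experimental)",
    }
],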

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Compare model naming patterns across providers
rg -g '!*.{json,md,txt}' -A 2 '"value":\s*"[^"]*".*"label":\s*"[^"]*"' 

Length of output: 72


Script:

#!/bin/bash
# Check the content of the settings file and similar patterns
cat cognee/modules/settings/get_settings.py

# Also search for any other model-related configurations
rg -t py "model.*value.*label" -A 2

Length of output: 4635


144-149: Verify the experimental model configuration.

The model "gemini-2.0-flash-exp" appears to be experimental based on its name suffix. Please verify:

  1. Is this the correct production model to use?
  2. Are there other stable models that should be included?
  3. What are the limitations or risks of using an experimental model?

Let's check the available Gemini models and their stability status:


coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
cognee/infrastructure/llm/gemini/adapter.py (2)

115-120: Avoid catching broad Exception; catch specific exceptions instead

Catching the broad Exception can mask unexpected errors and make debugging difficult. It's advisable to catch specific exceptions that may occur during the acompletion call or re-raise the exception after logging.

Consider updating the exception handling within the retry loop:

             for attempt in range(3):
                 try:
                     response = await acompletion(
                         model=f"gemini/{self.model}",
                         messages=messages,
                         api_key=self.api_key,
                         max_tokens=self.MAX_TOKENS,
                         temperature=0.1,
                         response_format={
                             "type": "json_object",
                             "schema": response_schema
                         }
                     )
                     
                     if response.choices and response.choices[0].message.content:
                         content = response.choices[0].message.content
                         return response_model.model_validate_json(content)
                     
-                except Exception as e:
+                except (SpecificExceptionType1, SpecificExceptionType2) as e:
                     if attempt == 2:  # Last attempt
                         raise
                     logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                     continue

Replace SpecificExceptionType1, SpecificExceptionType2 with the actual exceptions you expect.


127-129: Avoid catching broad Exception in top-level exception handling

At the top-level exception handling, catching a broad Exception can hide other errors. It's better to catch specific exceptions or re-raise the exception after logging to avoid masking unexpected issues.

Consider refining the exception handling:

         except JSONSchemaValidationError as e:
             logger.error(f"Schema validation failed: {str(e)}")
             logger.debug(f"Raw response: {e.raw_response}")
             raise ValueError(f"Response failed schema validation: {str(e)}")
-        except Exception as e:
+        except SpecificExceptionType as e:
             logger.error(f"Error in structured output generation: {str(e)}")
             raise ValueError(f"Failed to generate structured output: {str(e)}")
+        except Exception as e:
+            logger.error(f"Unexpected error: {str(e)}")
+            raise

Replace SpecificExceptionType with the actual exception you expect. This approach allows unexpected exceptions to be raised after logging.

cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1)

16-19: Consider adding type hints for engine_args dictionary.

While the code is functionally correct, adding type hints would improve code maintainability and IDE support.

-    engine_args = {
+    engine_args: dict[str, str | int] = {
cognee/infrastructure/llm/get_llm_client.py (2)

15-15: Fix enum value formatting.

The enum value formatting is inconsistent with other values. There should be a space after the equals sign.

-    GEMINI="gemini"
+    GEMINI = "gemini"

64-72: Consider adding additional configuration options for Gemini.

The Gemini adapter initialization is simpler compared to other providers (e.g., OpenAI) which support additional parameters like endpoint, api_version, and streaming. Consider whether these options should also be supported for consistency.

         return GeminiAdapter(
             api_key=llm_config.llm_api_key,
-            model=llm_config.llm_model
+            model=llm_config.llm_model,
+            endpoint=llm_config.llm_endpoint,
+            api_version=llm_config.llm_api_version,
+            streaming=llm_config.llm_streaming
         )
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)

266-266: Consider adding logging for temporary storage cleanup.

Adding logging would help track when temporary storage is cleaned up.

+            import logging
+            logging.info(f"Cleaning up temporary storage at {self.url}")
             LocalStorage.remove_all(self.url)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3fb9b05 and 4470d60.

📒 Files selected for processing (6)
  • cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py (3 hunks)
  • cognee/infrastructure/databases/vector/embeddings/config.py (1 hunks)
  • cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1 hunks)
  • cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1 hunks)
  • cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
  • cognee/infrastructure/llm/get_llm_client.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py

78-78: Local variable api_base is assigned to but never used

Remove assignment to unused variable api_base

(F841)

🔇 Additional comments (3)
cognee/infrastructure/databases/vector/embeddings/config.py (1)

7-8: Configuration defaults set to None for flexibility

Setting embedding_model and embedding_dimensions to None allows these values to be determined based on the provider configurations, enhancing flexibility.

cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (2)

11-13: LGTM! Good separation of provider-specific configurations.

The introduction of provider-specific configurations improves modularity and makes it easier to add new providers like Gemini.


27-30: Verify provider-specific configurations.

The code assumes that provider-specific configurations are properly defined in LiteLLMEmbeddingEngine.PROVIDER_CONFIGS.

✅ Verification successful

Provider-specific configurations are properly defined

The PROVIDER_CONFIGS dictionary in LiteLLMEmbeddingEngine correctly defines configurations for both OpenAI and Gemini providers, including model names, dimensions, and API base URLs.
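
Based on that description, the dictionary presumably has a shape along these lines (hedged sketch; the model names, dimensions, and API base URL are illustrative assumptions rather than values copied from the PR):

PROVIDER_CONFIGS = {
    "openai": {
        "model": "text-embedding-3-small",
        "dimensions": 1536,
        "api_base": None,  # default OpenAI endpoint via LiteLLM
    },
    "gemini": {
        "model": "text-embedding-004",
        "dimensions": 768,
        "api_base": "https://generativelanguage.googleapis.com/v1beta",
    },
}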

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if Gemini provider configs are properly defined
ast-grep --pattern 'PROVIDER_CONFIGS = {
  $$$
  "gemini": {
    $$$
  }
  $$$
}'

Length of output: 83


Script:

#!/bin/bash
# Find the LiteLLMEmbeddingEngine class definition
rg "class LiteLLMEmbeddingEngine" -l

# Then examine its content for PROVIDER_CONFIGS
rg "PROVIDER_CONFIGS\s*=" -A 10

Length of output: 1350

Comment on lines +72 to +79
if self.provider == "gemini":
    model_name = f"gemini/{self.model}"
    # For Gemini, we need to ensure we're using their specific endpoint format
    api_base = f"{self.endpoint}/models/{self.model}:embedContent"
else:
    model_name = self.model
    api_base = self.endpoint


⚠️ Potential issue

Fix the usage of api_base to ensure the correct API endpoint is used

The variable api_base is assigned but not used in the litellm.aembedding call. This may lead to incorrect API endpoints being used, especially for the 'gemini' provider. Update the function call to use the api_base variable to ensure the correct endpoint is utilized.

Apply this diff to fix the issue:

                     response = await litellm.aembedding(
                         model=model_name,
                         input=text,
                         api_key=self.api_key,
-                        api_base=self.endpoint,
+                        api_base=api_base,
                         api_version=self.api_version,
                     )
🧰 Tools
🪛 Ruff (0.8.2)

78-78: Local variable api_base is assigned to but never used

Remove assignment to unused variable api_base

(F841)

Comment on lines 253 to 262
connection = await self.get_connection()
collection_names = await connection.table_names()

# Delete all collections
for collection_name in collection_names:
    collection = await connection.open_table(collection_name)
    # Delete all records in the collection
    await collection.delete("id IS NOT NULL")
    # Drop the collection itself
    await connection.drop_table(collection_name)

🛠️ Refactor suggestion

Add error handling and transaction support to prune method.

The prune method could be more robust with:

  1. Error handling for failed deletions
  2. Transaction support to ensure atomicity
  3. Progress tracking for large databases
     async def prune(self):
         connection = await self.get_connection()
         collection_names = await connection.table_names()
+        failed_collections = []
         
         # Delete all collections
-        for collection_name in collection_names:
-            collection = await connection.open_table(collection_name)
-            # Delete all records in the collection
-            await collection.delete("id IS NOT NULL")
-            # Drop the collection itself
-            await connection.drop_table(collection_name)
+        try:
+            for collection_name in collection_names:
+                try:
+                    collection = await connection.open_table(collection_name)
+                    # Delete all records in the collection
+                    await collection.delete("id IS NOT NULL")
+                    # Drop the collection itself
+                    await connection.drop_table(collection_name)
+                except Exception as e:
+                    failed_collections.append((collection_name, str(e)))
+            
+            if failed_collections:
+                raise Exception(f"Failed to delete collections: {failed_collections}")
+        except Exception as e:
+            raise Exception(f"Database pruning failed: {str(e)}")

Comment on lines 253 to 266
connection = await self.get_connection()
collection_names = await connection.table_names()

# Delete all collections
for collection_name in collection_names:
    collection = await connection.open_table(collection_name)
    # Delete all records in the collection
    await collection.delete("id IS NOT NULL")
    # Drop the collection itself
    await connection.drop_table(collection_name)

# Clean up the database if it was set up as temporary
if self.url.startswith("/"):
    LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
    LocalStorage.remove_all(self.url)

💡 Codebase verification

⚠️ Concurrent operations during prune are not protected

The prune operation has potential concurrency issues:

  • Uses a shared connection instance without synchronization
  • No transaction isolation for the multi-step deletion process
  • Risk of race conditions if other operations access collections during pruning

Add documentation warning that:

  • Prune should not be called while other operations are in progress
  • Consider implementing application-level locks before calling prune (see the sketch after this list)
  • Best used during maintenance windows or system initialization
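
A minimal sketch of such an application-level lock, assuming an asyncio context and a shared adapter instance (hypothetical; not part of this PR):

import asyncio

# One lock per process; other read/write operations would still need their own coordination.
_prune_lock = asyncio.Lock()

async def prune_safely(adapter):
    # Serialize prune calls so the multi-step delete/drop sequence
    # is never interleaved with another prune on the same database.
    async with _prune_lock:
        await adapter.prune()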
🔗 Analysis chain

Verify the impact of pruning on concurrent operations.

The prune operation might affect concurrent database operations. Consider adding documentation about potential impacts and best practices for when to use this method.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for concurrent database operations in the codebase
rg -A 5 "open_table|create_table|delete|merge_insert"

Length of output: 47016


Script:

#!/bin/bash
# Search for connection management, pooling, and transaction patterns
rg -A 5 "connection_pool|transaction|lock|concurrent|semaphore|mutex" cognee/infrastructure/databases/vector/lancedb/

# Look for any LanceDB configuration or setup code
ast-grep --pattern 'class LanceDBAdapter {
  $$$
  __init__($$$) {
    $$$
  }
  $$$
}'

Length of output: 209


Script:

#!/bin/bash
# Search for connection handling in LanceDBAdapter
rg -A 10 "get_connection|connect|Connection" cognee/infrastructure/databases/vector/lancedb/

# Look for any async connection management
ast-grep --pattern 'async def get_connection($$$) {
  $$$
}'

Length of output: 13644

coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)

254-261: ⚠️ Potential issue

Add error handling and transaction support.

The database operations need proper error handling and transaction support to ensure data consistency.

 async def prune(self):
     connection = await self.get_connection()
     collection_names = await connection.table_names()
+    failed_collections = []
     
-    for collection_name in collection_names:
-        collection = await connection.open_table(collection_name)
-        await collection.delete("id IS NOT NULL")
-        await connection.drop_table(collection_name)
+    try:
+        for collection_name in collection_names:
+            try:
+                collection = await connection.open_table(collection_name)
+                # Delete all records in the collection
+                await collection.delete("id IS NOT NULL")
+                # Drop the collection itself
+                await connection.drop_table(collection_name)
+            except Exception as e:
+                failed_collections.append((collection_name, str(e)))
+                logging.error(f"Failed to delete collection {collection_name}: {str(e)}")
+        
+        if failed_collections:
+            raise Exception(f"Failed to delete collections: {failed_collections}")
+    except Exception as e:
+        raise Exception(f"Database pruning failed: {str(e)}")
🧹 Nitpick comments (5)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (1)

263-264: Add error handling for local storage cleanup.

The local storage cleanup needs error handling to catch potential file system errors.

-            logging.info(f"Cleaning up temporary storage at {self.url}")
-            LocalStorage.remove_all(self.url)
+            try:
+                logging.info(f"Cleaning up temporary storage at {self.url}")
+                LocalStorage.remove_all(self.url)
+            except Exception as e:
+                logging.error(f"Failed to clean up temporary storage at {self.url}: {str(e)}")
+                # Don't raise as this is a cleanup operation
cognee/infrastructure/llm/gemini/adapter.py (4)

11-26: Add class and method documentation.

Please add docstrings to improve code documentation:

  • Class-level docstring explaining the purpose and usage of GeminiAdapter
  • Constructor docstring describing the parameters and their purposes
 class GeminiAdapter:
+    """Adapter for interacting with Google's Gemini API to generate structured outputs.
+
+    This adapter provides functionality to generate structured JSON responses using
+    the Gemini language model while ensuring schema validation and error handling.
+    """
     MAX_TOKENS = 8192
     
     def __init__(self, 
         api_key: str, 
         model: str,
         endpoint: Optional[str] = None,
         api_version: Optional[str] = None,
         streaming: bool = False
     ) -> None:
+        """Initialize the GeminiAdapter.
+
+        Args:
+            api_key (str): The API key for authentication with Gemini API
+            model (str): The specific Gemini model to use
+            endpoint (Optional[str]): Custom API endpoint if different from default
+            api_version (Optional[str]): API version to use
+            streaming (bool): Whether to enable streaming responses
+        """

34-67: Extract response schema as a class constant.

The response schema should be defined as a class constant to improve reusability and maintainability. This would also make it easier to test and modify the schema independently.

 class GeminiAdapter:
     MAX_TOKENS = 8192
+    RESPONSE_SCHEMA = {
+        "type": "object",
+        "properties": {
+            "summary": {"type": "string"},
+            "description": {"type": "string"},
+            "nodes": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {"type": "string"},
+                        "type": {"type": "string"},
+                        "description": {"type": "string"},
+                        "id": {"type": "string"},
+                        "label": {"type": "string"}
+                    },
+                    "required": ["name", "type", "description", "id", "label"]
+                }
+            },
+            "edges": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "source_node_id": {"type": "string"},
+                        "target_node_id": {"type": "string"},
+                        "relationship_name": {"type": "string"}
+                    },
+                    "required": ["source_node_id", "target_node_id", "relationship_name"]
+                }
+            }
+        },
+        "required": ["summary", "description", "nodes", "edges"]
+    }

105-127: Make retry configuration more flexible.

The retry count and temperature are hardcoded. Consider making these configurable:

  1. Add retry configuration to the constructor
  2. Allow temperature to be customized per call
 def __init__(self, 
     api_key: str, 
     model: str,
     endpoint: Optional[str] = None,
     api_version: Optional[str] = None,
-    streaming: bool = False
+    streaming: bool = False,
+    max_retries: int = 3,
+    default_temperature: float = 0.1
 ) -> None:
     self.api_key = api_key
     self.model = model
     self.endpoint = endpoint
     self.api_version = api_version
     self.streaming = streaming
+    self.max_retries = max_retries
+    self.default_temperature = default_temperature

 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
-    response_model: Type[BaseModel]
+    response_model: Type[BaseModel],
+    temperature: Optional[float] = None
 ) -> BaseModel:

27-32: Add method documentation.

The acreate_structured_output method needs proper documentation explaining its parameters, return type, and possible exceptions.

 async def acreate_structured_output(
     self,
     text_input: str,
     system_prompt: str,
     response_model: Type[BaseModel]
 ) -> BaseModel:
+    """Generate structured output from text input using Gemini.
+
+    Args:
+        text_input (str): The input text to process
+        system_prompt (str): The system prompt to guide the model
+        response_model (Type[BaseModel]): Pydantic model for response validation
+
+    Returns:
+        BaseModel: Validated structured output matching the response_model
+
+    Raises:
+        ValueError: If schema validation fails or API errors occur
+    """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4470d60 and d7ff660.

📒 Files selected for processing (4)
  • cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py (1 hunks)
  • cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (2 hunks)
  • cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
  • cognee/infrastructure/llm/get_llm_client.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • cognee/infrastructure/llm/get_llm_client.py
  • cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
🧰 Additional context used
🪛 Ruff (0.8.2)
cognee/infrastructure/llm/gemini/adapter.py

7-7: Redefinition of unused acompletion from line 5

Remove definition: acompletion

(F811)


7-7: Redefinition of unused JSONSchemaValidationError from line 5

Remove definition: JSONSchemaValidationError

(F811)

🔇 Additional comments (3)
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py (2)

5-5: LGTM!

The logging import is correctly placed with other standard library imports.


254-264: ⚠️ Add protection against concurrent operations.

The prune operation needs protection against concurrent access to prevent data corruption.

Run this script to check for concurrent database operations:

#!/bin/bash
# Description: Check for concurrent database operations
# Look for connection pooling or locking mechanisms
rg -A 5 "connection_pool|lock|concurrent|semaphore|mutex" cognee/infrastructure/databases/vector/lancedb/

# Check for other database operations that might conflict
ast-grep --pattern 'async def $method($$$) {
  $$$
  connection = await self.get_connection()
  $$$
}'
cognee/infrastructure/llm/gemini/adapter.py (1)

1-140: Overall implementation looks good!

The GeminiAdapter implementation is well-structured with proper error handling and validation. The suggested improvements will make it more maintainable and configurable, but the current implementation is already functional and robust.
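
As a usage illustration (hedged: the Pydantic model and prompt below are invented for the example, while the constructor and acreate_structured_output signature follow the adapter as reviewed):

from pydantic import BaseModel

from cognee.infrastructure.llm.gemini.adapter import GeminiAdapter

class EntitySummary(BaseModel):
    summary: str
    description: str

async def summarize(text: str) -> EntitySummary:
    adapter = GeminiAdapter(api_key="...", model="gemini-2.0-flash-exp")
    return await adapter.acreate_structured_output(
        text_input=text,
        system_prompt="Summarize the text and return JSON matching the schema.",
        response_model=EntitySummary,
    )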

🧰 Tools
🪛 Ruff (0.8.2)

7-7: Redefinition of unused acompletion from line 5

Remove definition: acompletion

(F811)


7-7: Redefinition of unused JSONSchemaValidationError from line 5

Remove definition: JSONSchemaValidationError

(F811)

Comment on lines +254 to +264
connection = await self.get_connection()
collection_names = await connection.table_names()

for collection_name in collection_names:
    collection = await connection.open_table(collection_name)
    await collection.delete("id IS NOT NULL")
    await connection.drop_table(collection_name)

if self.url.startswith("/"):
    LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
    logging.info(f"Cleaning up temporary storage at {self.url}")
    LocalStorage.remove_all(self.url)

🛠️ Refactor suggestion

Add documentation about usage guidelines.

Please add docstring with:

  • Warning about concurrent operations
  • Best practices for when to use this method
  • Impact on running operations
 async def prune(self):
+    """Prunes all collections and their data from the database.
+    
+    WARNING: This is a destructive operation that:
+    - Deletes all collections and their data
+    - Cannot be undone
+    - Should not be called while other operations are in progress
+    
+    Best practices:
+    - Use during maintenance windows
+    - Ensure no concurrent operations
+    - Consider implementing application-level locks
+    
+    Raises:
+        Exception: If any collection deletion fails
+    """
     connection = await self.get_connection()

Comment on lines 5 to 7
from litellm import acompletion, JSONSchemaValidationError
import litellm
from litellm import acompletion, JSONSchemaValidationError

⚠️ Potential issue

Remove duplicate imports.

The following imports are duplicated:

  • acompletion
  • JSONSchemaValidationError
from litellm import acompletion, JSONSchemaValidationError
import litellm
-from litellm import acompletion, JSONSchemaValidationError
🧰 Tools
🪛 Ruff (0.8.2)

7-7: Redefinition of unused acompletion from line 5

Remove definition: acompletion

(F811)


7-7: Redefinition of unused JSONSchemaValidationError from line 5

Remove definition: JSONSchemaValidationError

(F811)

Comment on lines 123 to 127
except (litellm.exceptions.OpenAIError, litellm.exceptions.BadRequestError) as e:
    if attempt == 2:
        raise
    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
    continue

🛠️ Refactor suggestion

Improve error handling in retry loop.

The current error handling in the retry loop catches both OpenAIError and BadRequestError, but only logs a warning. Consider:

  1. Differentiating between retryable and non-retryable errors
  2. Adding exponential backoff between retries
-except (litellm.exceptions.OpenAIError, litellm.exceptions.BadRequestError) as e:
+except litellm.exceptions.OpenAIError as e:
     if attempt == 2:
         raise
-    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
+    backoff_time = (2 ** attempt) * 0.1  # exponential backoff: 0.1s, 0.2s, 0.4s
+    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {backoff_time}s")
+    await asyncio.sleep(backoff_time)
     continue
+except litellm.exceptions.BadRequestError as e:
+    # Don't retry on bad requests as they're likely due to invalid input
+    logger.error(f"Bad request error: {str(e)}")
+    raise ValueError(f"Invalid request: {str(e)}")

Don't forget to add the import:

+import asyncio

Committable suggestion skipped: line range outside the PR's diff.

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
cognee/infrastructure/llm/gemini/adapter.py (3)

29-34: Add method documentation.

Please add a docstring to document the method's purpose, parameters, return type, and possible exceptions.

    async def acreate_structured_output(
        self,
        text_input: str,
        system_prompt: str,
        response_model: Type[BaseModel]
    ) -> BaseModel:
+        """Generate structured output from text input using Gemini.
+
+        Args:
+            text_input (str): The input text to process
+            system_prompt (str): The system prompt to guide the model
+            response_model (Type[BaseModel]): Pydantic model for response validation
+
+        Returns:
+            BaseModel: Validated structured output
+
+        Raises:
+            ValueError: If the response fails validation or generation fails
+        """

114-114: Make temperature configurable.

The temperature value is hardcoded to 0.1. Consider making it configurable through the constructor or method parameter for more flexibility.

+    def __init__(self, 
+        api_key: str, 
+        model: str,
+        endpoint: Optional[str] = None,
+        api_version: Optional[str] = None,
+        streaming: bool = False,
+        temperature: float = 0.1
+    ) -> None:
+        self.temperature = temperature

Then update the acompletion call:

-                        temperature=0.1,
+                        temperature=self.temperature,

128-128: Adjust backoff time multiplier.

The current backoff multiplier of 1 second might be too aggressive. Consider using a smaller multiplier like 0.1 seconds for faster retries.

-                    backoff_time = (2 ** attempt) * 1
+                    backoff_time = (2 ** attempt) * 0.1  # 0.1s, 0.2s, 0.4s
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d7ff660 and ebb730e.

📒 Files selected for processing (1)
  • cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
🔇 Additional comments (2)
cognee/infrastructure/llm/gemini/adapter.py (2)

5-7: Remove duplicate imports.

The following imports are duplicated:

  • acompletion
  • JSONSchemaValidationError
from litellm import acompletion, JSONSchemaValidationError
import litellm
-from litellm import acompletion, JSONSchemaValidationError

13-28: Well-structured class definition and constructor!

The class is well-designed with:

  • Appropriate MAX_TOKENS constant for Gemini
  • Well-typed constructor parameters
  • Clean implementation

Comment on lines 142 to 144
except litellm.exceptions.OpenAIError as e:
    logger.error(f"API error in structured output generation: {str(e)}")
    raise ValueError(f"Failed to generate structured output: {str(e)}")

🛠️ Refactor suggestion

Remove redundant error handling.

This catch block for OpenAIError is redundant as it's already handled in the retry loop. Consider removing it or handling specific error cases not covered by the retry mechanism.

-        except litellm.exceptions.OpenAIError as e:
-            logger.error(f"API error in structured output generation: {str(e)}")
-            raise ValueError(f"Failed to generate structured output: {str(e)}")

coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
cognee/infrastructure/llm/gemini/adapter.py (4)

13-28: Add class and method documentation.

Add docstrings to describe the class purpose and constructor parameters:

 class GeminiAdapter:
+    """Adapter for Google's Gemini LLM that provides structured output generation.
+
+    This adapter interfaces with Gemini through the LiteLLM library to generate
+    structured JSON responses based on input text and a system prompt.
+    """
     MAX_TOKENS = 8192
     
     def __init__(self, 
         api_key: str, 
         model: str,
         endpoint: Optional[str] = None,
         api_version: Optional[str] = None,
         streaming: bool = False
     ) -> None:
+        """Initialize the GeminiAdapter.
+
+        Args:
+            api_key (str): The API key for Gemini authentication
+            model (str): The specific Gemini model to use
+            endpoint (Optional[str]): Custom endpoint URL if needed
+            api_version (Optional[str]): API version to use
+            streaming (bool): Whether to enable streaming responses
+        """
         self.api_key = api_key

29-34: Add method documentation.

Add a docstring to describe the method's purpose, parameters, return type, and possible exceptions:

     async def acreate_structured_output(
         self,
         text_input: str,
         system_prompt: str,
         response_model: Type[BaseModel]
     ) -> BaseModel:
+        """Generate structured output from text input using Gemini.
+
+        Args:
+            text_input (str): The input text to process
+            system_prompt (str): The prompt that guides the model's behavior
+            response_model (Type[BaseModel]): Pydantic model for response validation
+
+        Returns:
+            BaseModel: Validated structured output matching the response_model
+
+        Raises:
+            ValueError: If the response fails schema validation or request is invalid
+        """

107-135: Consider adding request timeout.

The retry logic looks good with exponential backoff, but there's no timeout for the API request. Consider adding a timeout to prevent hanging on slow responses.

                     response = await acompletion(
                         model=f"gemini/{self.model}",
                         messages=messages,
                         api_key=self.api_key,
                         max_tokens=self.MAX_TOKENS,
                         temperature=0.1,
+                        timeout=30,  # 30 seconds timeout
                         response_format={
                             "type": "json_object",
                             "schema": response_schema
                         }
                     )

121-123: Add debug logging for response validation.

Consider adding debug logging before validation to help troubleshoot any validation failures:

                     if response.choices and response.choices[0].message.content:
                         content = response.choices[0].message.content
+                        logger.debug(f"Validating response content: {content}")
                         return response_model.model_validate_json(content)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ebb730e and 5ab023d.

📒 Files selected for processing (1)
  • cognee/infrastructure/llm/gemini/adapter.py (1 hunks)
🔇 Additional comments (2)
cognee/infrastructure/llm/gemini/adapter.py (2)

1-8: Organize imports and remove duplicates.

The imports should be organized into standard library and third-party groups, with a blank line separating them.

Apply this diff to fix the imports:

 from typing import Type, Optional
-from pydantic import BaseModel
 import json
 import logging
+import asyncio
+
+from pydantic import BaseModel
 import litellm
-import asyncio
-from litellm import acompletion, JSONSchemaValidationError
+from litellm import acompletion, JSONSchemaValidationError

13-15: Verify the MAX_TOKENS limit for Gemini.

The MAX_TOKENS value of 8192 seems high. According to Google's documentation, different Gemini models have different token limits. Please verify this value against the specific model being used.

✅ Verification successful

The MAX_TOKENS value is appropriate for output limitation

The value of 8192 tokens is reasonable as it's used specifically to limit the maximum number of output tokens in the completion call, not the model's context window. This is a sensible constraint to prevent excessive token generation while still allowing for comprehensive responses.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any documentation or configuration files that might specify token limits
rg -i "token.*limit|max.*token" --type yaml --type json

Length of output: 57


Script:

#!/bin/bash
# Search for Gemini model configurations and token-related settings in Python files
rg -i "gemini.*model|token.*limit|max.*token" --type py -A 2 -B 2

# Search for any Gemini-specific configuration or model initialization
ast-grep --pattern 'model = "$_"' 

Length of output: 15419

@Vasilije1990

Hi @holchan
Thank you for your contribution. We will review, update and merge this next week.
There are some issues with your docs, especially around using pydantic settings.
We will fix it with you and get this adapter into cognee!
