diff --git a/cognee/modules/search/graph/search_adjacent.py b/cognee/modules/search/graph/search_adjacent.py index 94a88a0c..b198c344 100644 --- a/cognee/modules/search/graph/search_adjacent.py +++ b/cognee/modules/search/graph/search_adjacent.py @@ -1,24 +1,45 @@ """ This module contains the function to find the neighbours of a given node in the graph""" -async def search_adjacent(graph, query: str, other_param: dict = None) -> dict: - """ Find the neighbours of a given node in the graph - :param graph: A NetworkX graph object - - :return: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node +from typing import Union, Dict +import networkx as nx +from neo4j import AsyncSession +from cognee.shared.data_models import GraphDBType +async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]: """ + Find the neighbours of a given node in the graph and return their descriptions. + Supports both NetworkX graphs and Neo4j graph databases based on the configuration. + + Parameters: + - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. + - query (str): Unused in this implementation but could be used for future enhancements. + - infrastructure_config (Dict): Configuration that includes the graph engine type. + - other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node. + Returns: + - Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node. + """ node_id = other_param.get('node_id') if other_param else None - if node_id is None or node_id not in graph: + if node_id is None: return {} - neighbors = list(graph.neighbors(node_id)) - neighbor_descriptions = {} - - for neighbor in neighbors: - # Access the 'description' attribute for each neighbor - # The get method returns None if 'description' attribute does not exist for the node - neighbor_descriptions[neighbor] = graph.nodes[neighbor].get('description') - - return neighbor_descriptions \ No newline at end of file + if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: + if node_id not in graph: + return {} + + neighbors = list(graph.neighbors(node_id)) + neighbor_descriptions = {neighbor: graph.nodes[neighbor].get('description') for neighbor in neighbors} + return neighbor_descriptions + + elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: + cypher_query = """ + MATCH (node {id: $node_id})-[:CONNECTED_TO]->(neighbor) + RETURN neighbor.id AS neighbor_id, neighbor.description AS description + """ + results = await graph.run(cypher_query, node_id=node_id) + neighbor_descriptions = {record["neighbor_id"]: record["description"] for record in await results.list() if "description" in record} + return neighbor_descriptions + + else: + raise ValueError("Unsupported graph engine type in the configuration.") \ No newline at end of file diff --git a/cognee/modules/search/graph/search_categories.py b/cognee/modules/search/graph/search_categories.py index a85b3a4e..cccb4fbc 100644 --- a/cognee/modules/search/graph/search_categories.py +++ b/cognee/modules/search/graph/search_categories.py @@ -1,18 +1,41 @@ +from typing import Union, Dict +""" Search categories in the graph and return their summary attributes. """ +from neo4j import AsyncSession +from cognee.shared.data_models import GraphDBType +import networkx as nx -async def search_categories(graph, query:str, other_param:str = None): +async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict): """ - Filter nodes that contain 'LABEL' in their identifiers and return their summary attributes. + Filter nodes in the graph that contain the specified label and return their summary attributes. + This function supports both NetworkX graphs and Neo4j graph databases. Parameters: - - G (nx.Graph): The graph from which to filter nodes. + - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. + - query_label (str): The label to filter nodes by. + - infrastructure_config (Dict): Configuration that includes the graph engine type. Returns: - - dict: A dictionary where keys are nodes containing 'SUMMARY' in their identifiers, - and values are their 'summary' attributes. + - Union[Dict, List[Dict]]: For NetworkX, returns a dictionary where keys are node identifiers, + and values are their 'content_labels' attributes. For Neo4j, returns a list of dictionaries, + each representing a node with 'nodeId' and 'summary'. """ - return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if 'LABEL' in node and 'content_labels' in data} - - - + # Determine which client is in use based on the configuration + if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: + # Logic for NetworkX + return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data} + + elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: + # Logic for Neo4j + cypher_query = """ + MATCH (n) + WHERE $label IN labels(n) AND EXISTS(n.summary) + RETURN id(n) AS nodeId, n.summary AS summary + """ + result = await graph.run(cypher_query, label=query_label) + nodes_summary = [{"nodeId": record["nodeId"], "summary": record["summary"]} for record in await result.list()] + return nodes_summary + + else: + raise ValueError("Unsupported graph engine type.") diff --git a/cognee/modules/search/graph/search_neighbour.py b/cognee/modules/search/graph/search_neighbour.py index 43f853be..44e67346 100644 --- a/cognee/modules/search/graph/search_neighbour.py +++ b/cognee/modules/search/graph/search_neighbour.py @@ -1,22 +1,66 @@ """ Fetches the context of a given node in the graph""" +from typing import Union + +from neo4j import AsyncSession + from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client -async def search_neighbour(graph, id,other_param:dict = None): +import networkx as nx +from cognee.shared.data_models import GraphDBType + +async def search_neighbour(graph: Union[nx.Graph, AsyncSession], id: str, infrastructure_config: Dict, + other_param: dict = None): + """ + Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions. + Adapts to both NetworkX graphs and Neo4j graph databases based on the configuration. + + Parameters: + - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. + - id (str): The identifier of the node to match against. + - infrastructure_config (Dict): Configuration that includes the graph engine type. + - other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node. + Returns: + - List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node. + """ node_id = other_param.get('node_id') if other_param else None - if node_id is None or node_id not in graph: - return {} + if node_id is None: + return [] + + if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: + if isinstance(graph, nx.Graph): + if node_id not in graph: + return [] + + relevant_context = [] + target_layer_uuid = graph.nodes[node_id].get('layer_uuid') + + for n, attr in graph.nodes(data=True): + if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr: + relevant_context.append(attr['description']) + + return relevant_context + else: + raise ValueError("Graph object does not match the specified graph engine type in the configuration.") - relevant_context = [] - for n,attr in graph.nodes(data=True): - if id in n: - for n_, attr_ in graph.nodes(data=True): - relevant_layer = attr['layer_uuid'] + elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: + if isinstance(graph, AsyncSession): + cypher_query = """ + MATCH (target {id: $node_id}) + WITH target.layer_uuid AS layer + MATCH (n) + WHERE n.layer_uuid = layer AND EXISTS(n.description) + RETURN n.description AS description + """ + result = await graph.run(cypher_query, node_id=node_id) + descriptions = [record["description"] for record in await result.list()] - if attr_.get('layer_uuid') == relevant_layer: - relevant_context.append(attr_['description']) + return descriptions + else: + raise ValueError("Graph session does not match the specified graph engine type in the configuration.") - return relevant_context + else: + raise ValueError("Unsupported graph engine type in the configuration.") diff --git a/cognee/modules/search/graph/search_summary.py b/cognee/modules/search/graph/search_summary.py index db5dce62..8d606602 100644 --- a/cognee/modules/search/graph/search_summary.py +++ b/cognee/modules/search/graph/search_summary.py @@ -1,18 +1,37 @@ -async def search_summary(graph, query:str, other_param:str = None): +from typing import Union, Dict +import networkx as nx +from neo4j import AsyncSession +from cognee.shared.data_models import GraphDBType + +async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]: """ - Filter nodes that contain 'SUMMARY' in their identifiers and return their summary attributes. + Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes. + Supports both NetworkX graphs and Neo4j graph databases based on the configuration. Parameters: - - G (nx.Graph): The graph from which to filter nodes. + - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. + - query (str): The query string to filter nodes by, e.g., 'SUMMARY'. + - infrastructure_config (Dict): Configuration that includes the graph engine type. + - other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements. Returns: - - dict: A dictionary where keys are nodes containing 'SUMMARY' in their identifiers, - and values are their 'summary' attributes. + - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes. """ - return {node: data.get('summary') for node, data in graph.nodes(data=True) if 'SUMMARY' in node and 'summary' in data} - - - + if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: + return {node: data.get('summary') for node, data in graph.nodes(data=True) if query in node and 'summary' in data} + + elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: + cypher_query = f""" + MATCH (n) + WHERE n.id CONTAINS $query AND EXISTS(n.summary) + RETURN n.id AS nodeId, n.summary AS summary + """ + results = await graph.run(cypher_query, query=query) + summary_data = {record["nodeId"]: record["summary"] for record in await results.list()} + return summary_data + + else: + raise ValueError("Unsupported graph engine type in the configuration.")