-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove linking between nodes because it takes too much time
- Loading branch information
1 parent
5378541
commit af33296
Showing
4 changed files
with
151 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,45 @@ | ||
""" This module contains the function to find the neighbours of a given node in the graph""" | ||
|
||
|
||
async def search_adjacent(graph, query: str, other_param: dict = None) -> dict: | ||
""" Find the neighbours of a given node in the graph | ||
:param graph: A NetworkX graph object | ||
:return: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node | ||
from typing import Union, Dict | ||
import networkx as nx | ||
from neo4j import AsyncSession | ||
from cognee.shared.data_models import GraphDBType | ||
async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]: | ||
""" | ||
Find the neighbours of a given node in the graph and return their descriptions. | ||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration. | ||
Parameters: | ||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. | ||
- query (str): Unused in this implementation but could be used for future enhancements. | ||
- infrastructure_config (Dict): Configuration that includes the graph engine type. | ||
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node. | ||
Returns: | ||
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node. | ||
""" | ||
node_id = other_param.get('node_id') if other_param else None | ||
|
||
if node_id is None or node_id not in graph: | ||
if node_id is None: | ||
return {} | ||
|
||
neighbors = list(graph.neighbors(node_id)) | ||
neighbor_descriptions = {} | ||
|
||
for neighbor in neighbors: | ||
# Access the 'description' attribute for each neighbor | ||
# The get method returns None if 'description' attribute does not exist for the node | ||
neighbor_descriptions[neighbor] = graph.nodes[neighbor].get('description') | ||
|
||
return neighbor_descriptions | ||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: | ||
if node_id not in graph: | ||
return {} | ||
|
||
neighbors = list(graph.neighbors(node_id)) | ||
neighbor_descriptions = {neighbor: graph.nodes[neighbor].get('description') for neighbor in neighbors} | ||
return neighbor_descriptions | ||
|
||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: | ||
cypher_query = """ | ||
MATCH (node {id: $node_id})-[:CONNECTED_TO]->(neighbor) | ||
RETURN neighbor.id AS neighbor_id, neighbor.description AS description | ||
""" | ||
results = await graph.run(cypher_query, node_id=node_id) | ||
neighbor_descriptions = {record["neighbor_id"]: record["description"] for record in await results.list() if "description" in record} | ||
return neighbor_descriptions | ||
|
||
else: | ||
raise ValueError("Unsupported graph engine type in the configuration.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,41 @@ | ||
from typing import Union, Dict | ||
|
||
""" Search categories in the graph and return their summary attributes. """ | ||
|
||
from neo4j import AsyncSession | ||
from cognee.shared.data_models import GraphDBType | ||
import networkx as nx | ||
|
||
async def search_categories(graph, query:str, other_param:str = None): | ||
async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict): | ||
""" | ||
Filter nodes that contain 'LABEL' in their identifiers and return their summary attributes. | ||
Filter nodes in the graph that contain the specified label and return their summary attributes. | ||
This function supports both NetworkX graphs and Neo4j graph databases. | ||
Parameters: | ||
- G (nx.Graph): The graph from which to filter nodes. | ||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. | ||
- query_label (str): The label to filter nodes by. | ||
- infrastructure_config (Dict): Configuration that includes the graph engine type. | ||
Returns: | ||
- dict: A dictionary where keys are nodes containing 'SUMMARY' in their identifiers, | ||
and values are their 'summary' attributes. | ||
- Union[Dict, List[Dict]]: For NetworkX, returns a dictionary where keys are node identifiers, | ||
and values are their 'content_labels' attributes. For Neo4j, returns a list of dictionaries, | ||
each representing a node with 'nodeId' and 'summary'. | ||
""" | ||
return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if 'LABEL' in node and 'content_labels' in data} | ||
|
||
|
||
|
||
# Determine which client is in use based on the configuration | ||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: | ||
# Logic for NetworkX | ||
return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data} | ||
|
||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: | ||
# Logic for Neo4j | ||
cypher_query = """ | ||
MATCH (n) | ||
WHERE $label IN labels(n) AND EXISTS(n.summary) | ||
RETURN id(n) AS nodeId, n.summary AS summary | ||
""" | ||
result = await graph.run(cypher_query, label=query_label) | ||
nodes_summary = [{"nodeId": record["nodeId"], "summary": record["summary"]} for record in await result.list()] | ||
return nodes_summary | ||
|
||
else: | ||
raise ValueError("Unsupported graph engine type.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,37 @@ | ||
|
||
|
||
|
||
async def search_summary(graph, query:str, other_param:str = None): | ||
from typing import Union, Dict | ||
import networkx as nx | ||
from neo4j import AsyncSession | ||
from cognee.shared.data_models import GraphDBType | ||
|
||
async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]: | ||
""" | ||
Filter nodes that contain 'SUMMARY' in their identifiers and return their summary attributes. | ||
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes. | ||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration. | ||
Parameters: | ||
- G (nx.Graph): The graph from which to filter nodes. | ||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. | ||
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'. | ||
- infrastructure_config (Dict): Configuration that includes the graph engine type. | ||
- other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements. | ||
Returns: | ||
- dict: A dictionary where keys are nodes containing 'SUMMARY' in their identifiers, | ||
and values are their 'summary' attributes. | ||
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes. | ||
""" | ||
return {node: data.get('summary') for node, data in graph.nodes(data=True) if 'SUMMARY' in node and 'summary' in data} | ||
|
||
|
||
|
||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: | ||
return {node: data.get('summary') for node, data in graph.nodes(data=True) if query in node and 'summary' in data} | ||
|
||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: | ||
cypher_query = f""" | ||
MATCH (n) | ||
WHERE n.id CONTAINS $query AND EXISTS(n.summary) | ||
RETURN n.id AS nodeId, n.summary AS summary | ||
""" | ||
results = await graph.run(cypher_query, query=query) | ||
summary_data = {record["nodeId"]: record["summary"] for record in await results.list()} | ||
return summary_data | ||
|
||
else: | ||
raise ValueError("Unsupported graph engine type in the configuration.") |