Commit 91fe3f5
SQLite works; made fixes in config so it becomes a basis, added a few mods on top

Vasilije1990 committed Feb 16, 2024
1 parent 3a33503 commit 91fe3f5
Showing 2 changed files with 64 additions and 60 deletions.
54 changes: 29 additions & 25 deletions cognitive_architecture/database/graphdb/graph.py
@@ -388,7 +388,7 @@ def delete_specific_memory_type(self, user_id, memory_type):
except Exception as e:
return f"An error occurred: {str(e)}"

def retrieve_semantic_memory(
async def retrieve_semantic_memory(
self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None:
@@ -418,9 +418,10 @@ def retrieve_semantic_memory(
MATCH (semantic)-[:HAS_KNOWLEDGE]->(knowledge)
RETURN knowledge
"""
return self.query(query, params={"user_id": user_id})
output = await self.query(query, params={"user_id": user_id})
return output

def retrieve_episodic_memory(
async def retrieve_episodic_memory(
self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None:
@@ -450,9 +451,10 @@ def retrieve_episodic_memory(
MATCH (episodic)-[:HAS_EVENT]->(event)
RETURN event
"""
return self.query(query, params={"user_id": user_id})
output = await self.query(query, params={"user_id": user_id})
return output

def retrieve_buffer_memory(
async def retrieve_buffer_memory(
self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None:
@@ -482,15 +484,17 @@ def retrieve_buffer_memory(
MATCH (buffer)-[:CURRENTLY_HOLDING]->(item)
RETURN item
"""
return self.query(query, params={"user_id": user_id})
output = self.query(query, params={"user_id": user_id})
return output

def retrieve_public_memory(self, user_id: str):
async def retrieve_public_memory(self, user_id: str):
query = """
MATCH (user:User {userId: $user_id})-[:HAS_PUBLIC_MEMORY]->(public:PublicMemory)
MATCH (public)-[:HAS_DOCUMENT]->(document)
RETURN document
"""
return self.query(query, params={"user_id": user_id})
output = await self.query(query, params={"user_id": user_id})
return output

def generate_graph_semantic_memory_document_summary(
self,
@@ -698,7 +702,7 @@ def create_document_node_cypher(

return cypher_query

def update_document_node_with_db_ids(
async def update_document_node_with_db_ids(
self, vectordb_namespace: str, document_id: str, user_id: str = None
):
"""
@@ -731,7 +735,7 @@ def update_document_node_with_db_ids(

return cypher_query

def run_merge_query(
async def run_merge_query(
self, user_id: str, memory_type: str, similarity_threshold: float
) -> str:
"""
@@ -769,7 +773,7 @@ def run_merge_query(
RETURN labels(n) AS NodeType, collect(n) AS Nodes
"""

node_results = self.query(query)
node_results = await self.query(query)

node_types = [record["NodeType"] for record in node_results]

@@ -785,11 +789,11 @@ def run_merge_query(
CALL apoc.refactor.mergeNodes([n1, n2], {{mergeRels: true}}) YIELD node
RETURN node
"""
self.query(query)
self.close()
await self.query(query)
await self.close()
return query

def get_namespaces_by_document_category(self, user_id: str, category: str):
async def get_namespaces_by_document_category(self, user_id: str, category: str):
"""
Retrieve a list of Vectordb namespaces for documents of a specified category associated with a given user.
@@ -812,7 +816,7 @@ def get_namespaces_by_document_category(self, user_id: str, category: str):
WHERE document.documentCategory = '{category}'
RETURN document.vectordbNamespace AS namespace
"""
result = self.query(query)
result = await self.query(query)
namespaces = [record["namespace"] for record in result]
return namespaces
except Exception as e:
@@ -850,10 +854,10 @@ async def create_memory_node(self, labels, topic=None):
"""

try:
result = self.query(memory_cypher)
result = await self.query(memory_cypher)
# Assuming the result is a list of records, where each record contains 'memoryId'
memory_id = result[0]["memoryId"] if result else None
self.close()
await self.close()
return memory_id
except Neo4jError as e:
logging.error(f"Error creating or finding memory node: {e}")
@@ -882,7 +886,7 @@ def link_user_to_public(
logging.error(f"Error linking Public node to user: {e}")
raise

def delete_memory_node(self, memory_id: int, topic: str) -> None:
async def delete_memory_node(self, memory_id: int, topic: str) -> None:
if not memory_id or not topic:
raise ValueError("Memory ID and Topic are required for deletion.")

@@ -892,12 +896,12 @@ def delete_memory_node(self, memory_id: int, topic: str) -> None:
DETACH DELETE {topic.lower()}
"""
logging.info("Delete Cypher Query: %s", delete_cypher)
self.query(delete_cypher)
await self.query(delete_cypher)
except Neo4jError as e:
logging.error(f"Error deleting {topic} memory node: {e}")
raise

def unlink_memory_from_user(
async def unlink_memory_from_user(
self, memory_id: int, user_id: str, topic: str = "PublicMemory"
) -> None:
"""
@@ -929,27 +933,27 @@ def unlink_memory_from_user(
MATCH (user:User {{userId: '{user_id}'}})-[r:{relationship_type}]->(memory:{topic}) WHERE id(memory) = {memory_id}
DELETE r
"""
self.query(unlink_cypher)
await self.query(unlink_cypher)
except Neo4jError as e:
logging.error(f"Error unlinking {topic} from user: {e}")
raise

def link_public_memory_to_user(self, memory_id, user_id):
async def link_public_memory_to_user(self, memory_id, user_id):
# Link an existing Public Memory node to a User node
link_cypher = f"""
MATCH (user:User {{userId: '{user_id}'}})
MATCH (publicMemory:PublicMemory) WHERE id(publicMemory) = {memory_id}
MERGE (user)-[:HAS_PUBLIC_MEMORY]->(publicMemory)
"""
self.query(link_cypher)
await self.query(link_cypher)

def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
async def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
link_cypher = f""" MATCH(publicMemory: {topic})
RETURN
id(publicMemory)
AS
memoryId """
node_ids = self.query(link_cypher)
node_ids = await self.query(link_cypher)
return node_ids


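The graph.py changes above follow one pattern: most of the Neo4j helper methods become async def, and their internal self.query(...) / self.close() calls are awaited. The implementations of query and close themselves are not part of this diff; the snippet below is a minimal sketch, assuming the official neo4j async driver, of an awaitable pair that would support these call sites. The class, attribute, and method names here are illustrative assumptions, not the repository's actual code.

    # Hypothetical sketch (not from this commit): an awaitable query/close pair
    # built on the official neo4j async driver, matching the call sites above.
    from typing import Optional

    from neo4j import AsyncGraphDatabase


    class AsyncNeo4jGraphDB:
        """Illustrative async wrapper; names are assumptions, not the repo's code."""

        def __init__(self, url: str, username: str, password: str):
            # The async driver mirrors the sync driver but returns awaitables.
            self._driver = AsyncGraphDatabase.driver(url, auth=(username, password))

        async def query(self, query: str, params: Optional[dict] = None) -> list:
            # Run the Cypher statement and materialize records inside the session,
            # so callers can simply `await self.query(...)` as in the diff above.
            async with self._driver.session() as session:
                result = await session.run(query, params or {})
                return await result.data()

        async def close(self) -> None:
            # Matches the `await self.close()` / `await neo4j_graph_db.close()` call sites.
            await self._driver.close()
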
70 changes: 35 additions & 35 deletions main.py
@@ -258,13 +258,13 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
)
result = neo4j_graph_db.query(cypher_query)

neo4j_graph_db.run_merge_query(
await neo4j_graph_db.run_merge_query(
user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8
)
neo4j_graph_db.run_merge_query(
await neo4j_graph_db.run_merge_query(
user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8
)
neo4j_graph_db.close()
await neo4j_graph_db.close()

await update_entity(session, Operation, job_id, "SUCCESS")

@@ -381,16 +381,16 @@ async def add_documents_to_graph_db(
await create_public_memory(
user_id=user_id, labels=["sr"], topic="PublicMemory"
)
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(
topic="PublicMemory"
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
print(ids)
else:
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(
topic="SemanticMemory"
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
print(ids)

for id in ids:
@@ -404,38 +404,38 @@
rs = neo4j_graph_db.create_document_node_cypher(
classification, user_id, public_memory_id=id.get("memoryId")
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
else:
rs = neo4j_graph_db.create_document_node_cypher(
classification, user_id, memory_type="SemanticMemory"
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
logging.info("Cypher query is %s", str(rs))
neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
neo4j_graph_db.query(rs)
neo4j_graph_db.close()
await neo4j_graph_db.query(rs)
await neo4j_graph_db.close()
logging.info("WE GOT HERE")
neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
if memory_details[0][1] == "PUBLIC":
neo4j_graph_db.update_document_node_with_db_ids(
await neo4j_graph_db.update_document_node_with_db_ids(
vectordb_namespace=memory_details[0][0], document_id=doc_id
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
else:
neo4j_graph_db.update_document_node_with_db_ids(
await neo4j_graph_db.update_document_node_with_db_ids(
vectordb_namespace=memory_details[0][0],
document_id=doc_id,
user_id=user_id,
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
# await update_entity_graph_summary(session, DocsModel, doc_id, True)
except Exception as e:
return e
@@ -518,14 +518,14 @@ async def user_context_enrichment(
# await user_query_to_graph_db(session, user_id, query)

semantic_mem = neo4j_graph_db.retrieve_semantic_memory(user_id=user_id)
neo4j_graph_db.close()
await neo4j_graph_db.close()
neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
episodic_mem = neo4j_graph_db.retrieve_episodic_memory(user_id=user_id)
neo4j_graph_db.close()
await neo4j_graph_db.close()
# public_mem = neo4j_graph_db.retrieve_public_memory(user_id=user_id)

if detect_language(query) != "en":
@@ -541,7 +541,7 @@
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(
user_id=user_id, memory_type=memory_type
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
logging.info("Summaries are is %s", summaries)
# logging.info("Context from graphdb is %s", context)
# result = neo4j_graph_db.query(document_categories_query)
@@ -571,7 +571,7 @@ async def user_context_enrichment(
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(
user_id, summary_id=relevant_summary_id, memory_type=memory_type
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
# postgres_id = neo4j_graph_db.query(get_doc_ids)
logging.info("Postgres ids are %s", postgres_id)
namespace_id = await get_memory_name_by_doc_id(session, postgres_id[0])
@@ -688,7 +688,7 @@ async def create_public_memory(
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
# Create the memory node
memory_id = await neo4j_graph_db.create_memory_node(labels=labels, topic=topic)
neo4j_graph_db.close()
await neo4j_graph_db.close()
return memory_id
except Neo4jError as e:
logging.error(f"Error creating public memory node: {e}")
@@ -729,19 +729,19 @@ async def attach_user_to_memory(
)

# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
neo4j_graph_db.close()
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
await neo4j_graph_db.close()

for id in ids:
neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
linked_memory = neo4j_graph_db.link_public_memory_to_user(
linked_memory = await neo4j_graph_db.link_public_memory_to_user(
memory_id=id.get("memoryId"), user_id=user_id
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
return 1
except Neo4jError as e:
logging.error(f"Error creating public memory node: {e}")
@@ -781,8 +781,8 @@ async def unlink_user_from_memory(
)

# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
neo4j_graph_db.close()
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
await neo4j_graph_db.close()

for id in ids:
neo4j_graph_db = Neo4jGraphDB(
@@ -793,7 +793,7 @@
linked_memory = neo4j_graph_db.unlink_memory_from_user(
memory_id=id.get("memoryId"), user_id=user_id
)
neo4j_graph_db.close()
await neo4j_graph_db.close()
return 1
except Neo4jError as e:
logging.error(f"Error creating public memory node: {e}")
@@ -879,14 +879,14 @@ class GraphQLQuery(BaseModel):
# print(out)
# load_doc_to_graph = await add_documents_to_graph_db(session, user_id)
# print(load_doc_to_graph)
user_id = "test_user"
loader_settings = {
"format": "PDF",
"source": "DEVICE",
"path": [".data"]
}
await load_documents_to_vectorstore(session, user_id, loader_settings=loader_settings)
# await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
# user_id = "test_user"
# loader_settings = {
# "format": "PDF",
# "source": "DEVICE",
# "path": [".data"]
# }
# await load_documents_to_vectorstore(session, user_id, loader_settings=loader_settings)
await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
# await add_documents_to_graph_db(session, user_id)
#
# neo4j_graph_db = Neo4jGraphDB(
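main.py updates most call sites to await the converted coroutines, but a few invocations shown in this diff are still synchronous (for example, retrieve_semantic_memory and retrieve_episodic_memory inside user_context_enrichment, and unlink_memory_from_user in the loop above). The self-contained snippet below, which does not use the repository's classes, illustrates why those call sites matter: calling an async def function without await returns a coroutine object rather than the query results.

    # Standalone illustration; the function is a stand-in, not the repo's method.
    import asyncio


    async def retrieve_semantic_memory(user_id: str):
        # Pretend Neo4j-backed coroutine that returns canned records.
        return [{"knowledge": f"fact about {user_id}"}]


    async def main():
        pending = retrieve_semantic_memory("test_user")  # coroutine object, nothing ran yet
        print(type(pending).__name__)                    # -> coroutine
        records = await pending                          # the lookup actually executes here
        print(records)                                   # -> [{'knowledge': 'fact about test_user'}]


    asyncio.run(main())
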
