From c23a314b631fbf4f9015342aedd19c4ae2d7b031 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:57:30 +0100 Subject: [PATCH] COG-414: fix postgres database deletion (#163) * fix: Add deletion of all tables in all schemas for postgres Added deletion of all tables in postgres database, but this fix causes an issue regrading creation of duplicate tables on next run Fix #COG-414 * fix: Resolve issue with database deletion Resolve issue with database deletion by cleaning Metadata after every schema Fix #COG-414 * fix: Move cleaning of MetaData out of drop table loop Moved cleaning of MetaData to be after all tables have been dropped Fix #COG-414 * refactor: Remove unnecessary print statement Removed unnecessary print statement Refactor #COG-414 * fix: Fix table deletion for SqlAlchemyAdapter Fixed deletion of tables in SqlAlchemyAdapter so it works for sqlite and postgres Fix #COG-414 * feat: Add deletion by id for SqlAlchemyAdapter Added ability to delete data from database by id Feature #COG-414 * fix: Add support for postgresql syntax for getting table function in SqlAlchemyAdapter Added support for schema namespace for getting tables Fix #COG-414 --- .../sqlalchemy/SqlAlchemyAdapter.py | 83 ++++++++++++++++--- 1 file changed, 71 insertions(+), 12 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 81a828bd8..edde07565 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -1,7 +1,9 @@ from os import path -from typing import AsyncGenerator +from uuid import UUID +from typing import Optional +from typing import AsyncGenerator, List from contextlib import asynccontextmanager -from sqlalchemy import text, select +from sqlalchemy import text, select, MetaData, Table from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker @@ -50,11 +52,14 @@ async def create_table(self, schema_name: str, table_name: str, table_config: li await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});")) await connection.close() - async def delete_table(self, table_name: str): + async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"): async with self.engine.begin() as connection: - await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;")) - - await connection.close() + if self.engine.dialect.name == "sqlite": + # SQLite doesn’t support schema namespaces and the CASCADE keyword. + # However, foreign key constraint can be defined with ON DELETE CASCADE during table creation. + await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};")) + else: + await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")) async def insert_data(self, schema_name: str, table_name: str, data: list[dict]): columns = ", ".join(data[0].keys()) @@ -65,6 +70,55 @@ async def insert_data(self, schema_name: str, table_name: str, data: list[dict]) await connection.execute(insert_query, data) await connection.close() + async def get_schema_list(self) -> List[str]: + """ + Return a list of all schema names in database + """ + if self.engine.dialect.name == "postgresql": + async with self.engine.begin() as connection: + result = await connection.execute( + text(""" + SELECT schema_name FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema'); + """) + ) + return [schema[0] for schema in result.fetchall()] + return [] + + async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"): + """ + Delete data in given table based on id. Table must have an id Column. + """ + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + + async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table: + """ + Dynamically loads a table using the given table name and schema name. + """ + async with self.engine.begin() as connection: + if self.engine.dialect.name == "sqlite": + # Load the schema information into the MetaData object + await connection.run_sync(Base.metadata.reflect) + if table_name in Base.metadata.tables: + return Base.metadata.tables[table_name] + else: + raise ValueError(f"Table '{table_name}' not found.") + else: + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect, schema=schema_name) + # Define the full table name + full_table_name = f"{schema_name}.{table_name}" + # Check if table is in list of tables for the given schema + if full_table_name in metadata.tables: + return metadata.tables[full_table_name] + raise ValueError(f"Table '{full_table_name}' not found.") + + async def get_data(self, table_name: str, filters: dict = None): async with self.engine.begin() as connection: query = f"SELECT * FROM {table_name}" @@ -119,12 +173,17 @@ async def delete_database(self): self.db_path = None else: async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) - for table in Base.metadata.sorted_tables: - drop_table_query = text(f"DROP TABLE IF EXISTS {table.name} CASCADE") - await connection.execute(drop_table_query) - + schema_list = await self.get_schema_list() + # Create a MetaData instance to load table information + metadata = MetaData() + # Drop all tables from all schemas + for schema_name in schema_list: + # Load the schema information into the MetaData object + await connection.run_sync(metadata.reflect, schema=schema_name) + for table in metadata.sorted_tables: + drop_table_query = text(f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE") + await connection.execute(drop_table_query) + metadata.clear() except Exception as e: print(f"Error deleting database: {e}")