diff --git a/defog/generate_schema.py b/defog/generate_schema.py index bdb4de3..fcee97c 100644 --- a/defog/generate_schema.py +++ b/defog/generate_schema.py @@ -12,6 +12,7 @@ def generate_postgres_schema( return_format: str = "csv", scan: bool = True, return_tables_only: bool = False, + schemas: list[str] = ["public"], ) -> str: # when upload is True, we send the schema to the defog servers and generate a CSV # when its false, we return the schema as a dict @@ -24,60 +25,48 @@ def generate_postgres_schema( conn = psycopg2.connect(**self.db_creds) cur = conn.cursor() - schemas = {} + schemas = tuple(schemas) if len(tables) == 0: # get all tables - cur.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" - ) - tables = [row[0] for row in cur.fetchall()] + for schema in schemas: + cur.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", + (schema,), + ) + if schema == "public": + tables += [row[0] for row in cur.fetchall()] + else: + tables += [schema + "." + row[0] for row in cur.fetchall()] if return_tables_only: return tables print("Getting schema for each table that you selected...") - # get the schema for each table - for table_name in tables: - cur.execute( - "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;", - (table_name,), - ) - rows = cur.fetchall() - rows = [row for row in rows] - rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] - if scan: - rows = identify_categorical_columns(cur, table_name, rows) - schemas[table_name] = rows - # get foreign key relationships - print("Getting foreign keys for each table that you selected...") - tables_regclass_str = ", ".join( - [f"'{table_name}'::regclass" for table_name in tables] - ) - query = f"""SELECT - conrelid::regclass AS table_from, - pg_get_constraintdef(oid) AS foreign_key_definition - FROM pg_constraint - WHERE contype = 'f' - AND conrelid::regclass IN ({tables_regclass_str}) - AND confrelid::regclass IN ({tables_regclass_str}); - """ - cur.execute(query) - foreign_keys = list(cur.fetchall()) - foreign_keys = [fk[0] + " " + fk[1] for fk in foreign_keys] - - # get indexes for each table - print("Getting indexes for each table that you selected...") - tables_str = ", ".join([f"'{table_name}'" for table_name in tables]) - query = f"""SELECT indexdef FROM pg_indexes WHERE tablename IN ({tables_str});""" - cur.execute(query) - indexes = list(cur.fetchall()) - if len(indexes) > 0: - indexes = [index[0] for index in indexes] - else: - indexes = [] - # print("No indexes found.") + table_columns = {} + + # get the columns for each table + for schema in schemas: + for table_name in tables: + if "." in table_name: + _, table_name = table_name.split(".", 1) + cur.execute( + "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s AND table_schema = %s;", + ( + table_name, + schema, + ), + ) + rows = cur.fetchall() + rows = [row for row in rows] + rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] + if scan: + rows = identify_categorical_columns(cur, table_name, rows) + if schema == "public": + table_columns[table_name] = rows + else: + table_columns[schema + table_name] = rows conn.close() print( @@ -89,9 +78,7 @@ def generate_postgres_schema( f"{self.base_url}/get_schema_csv", json={ "api_key": self.api_key, - "schemas": schemas, - "foreign_keys": foreign_keys, - "indexes": indexes, + "schemas": table_columns, }, ) resp = r.json() @@ -110,7 +97,7 @@ def generate_postgres_schema( f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." ) else: - return schemas + return table_columns def generate_redshift_schema( diff --git a/defog/query_methods.py b/defog/query_methods.py index 4da6e24..e9e2838 100644 --- a/defog/query_methods.py +++ b/defog/query_methods.py @@ -28,14 +28,10 @@ def get_query( "previous_context": previous_context, "db_type": self.db_type if self.db_type != "databricks" else "postgres", "glossary": glossary, - "language": language, "hard_filters": hard_filters, "dev": dev, "ignore_cache": ignore_cache, } - if schema != {}: - data["schema"] = schema - data["is_direct"] = True t_start = datetime.now() r = requests.post(