diff --git a/defog/generate_schema.py b/defog/generate_schema.py index 090f633..55f33d7 100644 --- a/defog/generate_schema.py +++ b/defog/generate_schema.py @@ -130,14 +130,22 @@ def generate_redshift_schema( "psycopg2 not installed. Please install it with `pip install psycopg2-binary`." ) - conn = psycopg2.connect(**self.db_creds) + if "schema" not in self.db_creds or self.db_creds["schema"].lower() == "public": + schema = "public" + conn = psycopg2.connect(**self.db_creds) + else: + schema = self.db_creds["schema"] + del self.db_creds["schema"] + conn = psycopg2.connect(**self.db_creds) cur = conn.cursor() + schemas = {} if len(tables) == 0: # get all tables cur.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" + "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", + (schema,), ) tables = [row[0] for row in cur.fetchall()] @@ -149,8 +157,11 @@ def generate_redshift_schema( for table_name in tables: try: cur.execute( - "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;", - (table_name,), + "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s AND table_schema= %s;", + ( + table_name, + schema, + ), ) rows = cur.fetchall() rows = [row for row in rows] @@ -159,7 +170,8 @@ def generate_redshift_schema( rows = [] if len(rows) == 0: cur.execute( - f"SELECT CAST(columnname AS TEXT), CAST(external_type AS TEXT) FROM svv_external_columns WHERE table_name = '{table_name}';" + "SELECT CAST(columnname AS TEXT), CAST(external_type AS TEXT) FROM svv_external_columns WHERE table_name = %s;", + (table_name,), ) rows = cur.fetchall() rows = [row for row in rows] diff --git a/defog/query.py b/defog/query.py index b067bb4..2f3ada1 100644 --- a/defog/query.py +++ b/defog/query.py @@ -9,13 +9,38 @@ def execute_query_once(db_type: str, db_creds, query: str): """ Executes the query once and returns the column names and results. """ - if db_type == "postgres" or db_type == "redshift": + if db_type == "postgres": try: import psycopg2 except: raise Exception("psycopg2 not installed.") conn = psycopg2.connect(**db_creds) cur = conn.cursor() + cur.execute(query) + colnames = [desc[0] for desc in cur.description] + results = cur.fetchall() + cur.close() + conn.close() + return colnames, results + elif db_type == "redshift": + try: + import psycopg2 + except: + raise Exception("redshift_connector not installed.") + + if "schema" not in db_creds or db_creds["schema"].lower() == "public": + schema = None + conn = psycopg2.connect(**db_creds) + else: + schema = db_creds["schema"] + del db_creds["schema"] + conn = psycopg2.connect(**db_creds) + + cur = conn.cursor() + + if schema is not None: + cur.execute(f"SET search_path TO {schema}") + cur.execute(query) colnames = [desc[0] for desc in cur.description] results = cur.fetchall() diff --git a/setup.py b/setup.py index 325e0d1..f98ec5b 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def package_files(directory): name="defog", packages=find_packages(), package_data={"defog": ["gcp/*", "aws/*"] + next_static_files}, - version="0.60.0", + version="0.62.0", description="Defog is a Python library that helps you generate data queries from natural language questions.", author="Full Stack Data Pte. Ltd.", license="MIT",