Skip to content

Commit

Permalink
added a method for creating an empty table, given metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Jul 3, 2024
1 parent 8fe724f commit 6d2e1d8
Showing 1 changed file with 130 additions and 0 deletions.
130 changes: 130 additions & 0 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def get_metadata(self, format="markdown", export_path=None, dev=False):
].to_csv(export_path, index=False)
print(f"Metadata exported to {export_path}")
return True
elif format == "json":
return resp["table_metadata"]


def get_feedback(self, n_rows: int = 50, start_from: int = 0):
Expand Down Expand Up @@ -257,3 +259,131 @@ def get_golden_queries(
return golden_queries
else:
raise ValueError("format must be either 'csv' or 'json'.")


def create_table_ddl(
    table_name: str, columns: List[Dict[str, str]], if_not_exists=True
) -> str:
    """
    Build the CREATE TABLE statement for a single table.

    `columns` is a list of dictionaries, each with the following keys:
    - column_name: str
    - data_type: str
    - column_description: str
    Only `column_name` and `data_type` end up in the generated statement.

    When `if_not_exists` is truthy the statement opens with
    CREATE TABLE IF NOT EXISTS, otherwise with a plain CREATE TABLE.
    """
    opener = "CREATE TABLE IF NOT EXISTS" if if_not_exists else "CREATE TABLE"

    column_defs = []
    for spec in columns:
        identifier = spec["column_name"]
        # names containing spaces must be double-quoted, unless the caller
        # already supplied them quoted
        needs_quotes = " " in identifier and not identifier.startswith('"')
        if needs_quotes:
            identifier = f'"{identifier}"'
        column_defs.append(f" {identifier} {spec['data_type']}")

    # joining puts commas between definitions only, so the last one has none
    body = ",\n".join(column_defs)
    if column_defs:
        body += "\n"
    return f"{opener} {table_name} (\n{body});\n"


def create_ddl_from_metadata(
    metadata: Dict[str, List[Dict[str, str]]], if_not_exists: bool = True
) -> str:
    """
    Return the DDL statements for creating every table described in metadata.

    `metadata` is a dictionary with table names as keys and lists of dictionaries as values.
    Each dictionary in the list has the following keys:
    - column_name: str
    - data_type: str
    - column_description: str

    A table name may be qualified (e.g. `sales.orders`). Everything before the
    last dot is treated as the schema: one CREATE SCHEMA statement is emitted
    for it (once per distinct schema, ahead of the first table that needs it)
    and the table is created under its qualified name so that it lands in
    that schema.

    `if_not_exists` controls whether the generated statements carry the
    IF NOT EXISTS clause. Pass False for dialects that do not support it.
    """
    guard = " IF NOT EXISTS" if if_not_exists else ""
    statements = []
    seen_schemas = set()
    for table_name, columns in metadata.items():
        # Derive the schema from the *original* name; the table name itself is
        # left untouched so the CREATE TABLE stays schema-qualified.
        schema_name, dot, _ = table_name.rpartition(".")
        if dot and schema_name and schema_name not in seen_schemas:
            seen_schemas.add(schema_name)
            statements.append(f"CREATE SCHEMA{guard} {schema_name};\n")
        statements.append(
            create_table_ddl(table_name, columns, if_not_exists=if_not_exists)
        )
    return "".join(statements)


def _split_ddl_statements(ddl: str) -> List[str]:
    """Split a DDL script on ';' and drop fragments that are only whitespace."""
    return [part.strip() for part in ddl.split(";") if part.strip()]


def _execute_and_commit(conn, statements) -> None:
    """Run each statement on a DB-API connection, commit, and always close it."""
    try:
        cur = conn.cursor()
        for statement in statements:
            cur.execute(statement)
        conn.commit()
    finally:
        conn.close()


def create_empty_tables(self, dev: bool = False):
    """
    Create empty tables in the target database based on the stored metadata.

    The metadata is fetched with `self.get_metadata(format="json", dev=dev)`,
    turned into DDL, and executed against the database selected by
    `self.db_type` using the credentials in `self.db_creds`.

    Returns:
        True once every statement has been executed.

    Raises:
        ValueError: if `self.db_type` is not one of postgres, redshift, mysql,
            databricks, snowflake, bigquery or sqlserver.
    """
    metadata = self.get_metadata(format="json", dev=dev)
    if self.db_type == "sqlserver":
        # SQL Server has no CREATE TABLE IF NOT EXISTS, so ask for plain DDL
        ddl = create_ddl_from_metadata(metadata, if_not_exists=False)
    else:
        ddl = create_ddl_from_metadata(metadata)

    # The script always ends with ';', so a naive split leaves an empty
    # trailing fragment; several drivers reject an empty query.
    statements = _split_ddl_statements(ddl)

    if self.db_type == "postgres" or self.db_type == "redshift":
        import psycopg2

        conn = psycopg2.connect(**self.db_creds)
        _execute_and_commit(conn, statements)
        return True
    elif self.db_type == "mysql":
        import mysql.connector

        conn = mysql.connector.connect(**self.db_creds)
        _execute_and_commit(conn, statements)
        return True
    elif self.db_type == "databricks":
        from databricks import sql

        conn = sql.connect(**self.db_creds)
        try:
            # statements are executed through a cursor, one at a time
            cur = conn.cursor()
            try:
                for statement in statements:
                    cur.execute(statement)
            finally:
                cur.close()
        finally:
            conn.close()
        return True
    elif self.db_type == "snowflake":
        import snowflake.connector

        use_warehouse = f"USE WAREHOUSE {self.db_creds['warehouse']}"
        conn = snowflake.connector.connect(
            user=self.db_creds["user"],
            password=self.db_creds["password"],
            account=self.db_creds["account"],
        )
        # set the warehouse first, then run the DDL in the same session
        _execute_and_commit(conn, [use_warehouse] + statements)
        return True
    elif self.db_type == "bigquery":
        from google.cloud import bigquery

        client = bigquery.Client.from_service_account_json(
            self.db_creds["json_key_path"]
        )
        for statement in statements:
            # query() only starts the job; result() waits for it and raises
            # if the statement failed
            client.query(statement).result()
        return True
    elif self.db_type == "sqlserver":
        import pyodbc

        connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
        conn = pyodbc.connect(connection_string)
        _execute_and_commit(conn, statements)
        return True
    else:
        raise ValueError(f"Unsupported DB type: {self.db_type}")

0 comments on commit 6d2e1d8

Please sign in to comment.