Skip to content

Commit

Permalink
fixes suggested by jp
Browse files (browse the repository at this point in the history)
  • Loading branch information
rishsriv committed Jul 3, 2024
1 parent ba93985 commit 14e45e5
Showing 1 changed file with 78 additions and 74 deletions.
152 changes: 78 additions & 74 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def get_golden_queries(


def create_table_ddl(
table_name: str, columns: List[Dict[str, str]], if_not_exists=True
table_name: str, columns: List[Dict[str, str]], add_exists=True
) -> str:
"""
Return a DDL statement for creating a table from a list of columns
Expand All @@ -272,7 +272,7 @@ def create_table_ddl(
- column_description: str
"""
md_create = ""
if if_not_exists:
if add_exists:
md_create += f"CREATE TABLE IF NOT EXISTS {table_name} (\n"
else:
md_create += f"CREATE TABLE {table_name} (\n"
Expand All @@ -291,7 +291,9 @@ def create_table_ddl(
return md_create


def create_ddl_from_metadata(metadata: Dict[str, List[Dict[str, str]]]) -> str:
def create_ddl_from_metadata(
metadata: Dict[str, List[Dict[str, str]]], add_exists=True
) -> str:
"""
Return a DDL statement for creating tables from metadata
`metadata` is a dictionary with table names as keys and lists of dictionaries as values.
Expand All @@ -303,11 +305,11 @@ def create_ddl_from_metadata(metadata: Dict[str, List[Dict[str, str]]]) -> str:
md_create = ""
for table_name, columns in metadata.items():
if "." in table_name:
table_name = table_name.split(".")[-1]
table_name = table_name.split(".", 1)[1]
schema_name = table_name.split(".")[0]

md_create += f"CREATE SCHEMA IF NOT EXISTS {schema_name};\n"
md_create += create_table_ddl(table_name, columns)
md_create += create_table_ddl(table_name, columns, add_exists=add_exists)
return md_create


Expand All @@ -317,75 +319,77 @@ def create_empty_tables(self, dev: bool = False):
"""
metadata = self.get_metadata(format="json", dev=dev)
if self.db_type == "sqlserver":
ddl = create_ddl_from_metadata(metadata, if_not_exists=False)
ddl = create_ddl_from_metadata(metadata, add_exists=False)
else:
ddl = create_ddl_from_metadata(metadata)

if self.db_type == "postgres" or self.db_type == "redshift":
import psycopg2

conn = psycopg2.connect(**self.db_creds)
cur = conn.cursor()
cur.execute(ddl)
conn.commit()
conn.close()
return True
elif self.db_type == "mysql":
import mysql.connector

conn = mysql.connector.connect(**self.db_creds)
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
elif self.db_type == "databricks":
from databricks import sql

con = sql.connect(**self.db_creds)
con.execute(ddl)
conn.commit()
conn.close()
return True
elif self.db_type == "snowflake":
import snowflake.connector

conn = snowflake.connector.connect(
user=self.db_creds["user"],
password=self.db_creds["password"],
account=self.db_creds["account"],
)
conn.cursor().execute(
f"USE WAREHOUSE {self.db_creds['warehouse']}"
) # set the warehouse
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
elif self.db_type == "bigquery":
from google.cloud import bigquery

client = bigquery.Client.from_service_account_json(
self.db_creds["json_key_path"]
)
for statement in ddl.split(";"):
client.query(statement)
return True
elif self.db_type == "sqlserver":
import pyodbc

print(ddl)

connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
conn = pyodbc.connect(connection_string)
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
else:
raise ValueError(f"Unsupported DB type: {self.db_type}")
try:
if self.db_type == "postgres" or self.db_type == "redshift":
import psycopg2

conn = psycopg2.connect(**self.db_creds)
cur = conn.cursor()
cur.execute(ddl)
conn.commit()
conn.close()
return True
elif self.db_type == "mysql":
import mysql.connector

conn = mysql.connector.connect(**self.db_creds)
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
elif self.db_type == "databricks":
from databricks import sql

con = sql.connect(**self.db_creds)
con.execute(ddl)
conn.commit()
conn.close()
return True
elif self.db_type == "snowflake":
import snowflake.connector

conn = snowflake.connector.connect(
user=self.db_creds["user"],
password=self.db_creds["password"],
account=self.db_creds["account"],
)
conn.cursor().execute(
f"USE WAREHOUSE {self.db_creds['warehouse']}"
) # set the warehouse
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
elif self.db_type == "bigquery":
from google.cloud import bigquery

client = bigquery.Client.from_service_account_json(
self.db_creds["json_key_path"]
)
for statement in ddl.split(";"):
client.query(statement)
return True
elif self.db_type == "sqlserver":
import pyodbc

connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
conn = pyodbc.connect(connection_string)
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True
else:
raise ValueError(f"Unsupported DB type: {self.db_type}")
except Exception as e:
print(f"Error: {e}")
return False

0 comments on commit 14e45e5

Please sign in to comment.