Skip to content

Commit

Permalink
Re-enabled SQLServer (#52)
Browse files Browse the repository at this point in the history
* reenable sqlserver support from defog-python

* updated setup and README

* feat: Add SQL Server support to defog-python

This commit enables SQL Server support in defog-python by adding the necessary code changes to handle SQL Server database credentials and connections. Now users can specify a SQL Server host, database name, user, and password when initializing defog-python. This allows defog-python to connect to SQL Server databases and perform data extraction and manipulation tasks.

The changes include:
- Adding validation for SQL Server database credentials in the `Defog` class.
- Modifying the `init` function in the CLI module to prompt users for SQL Server database information.
- Updating the `generate_sqlserver_schema` function to use the provided SQL Server credentials for establishing a connection.
- Modifying the `execute_query_once` function to use the SQL Server credentials for connecting to the database.

These changes were made in response to user requests and will enhance the functionality of defog-python by expanding its database compatibility.

* defense against tables that return no column metadata (likely because of misspelled table names)

* better cli for sqlserver

* added TrustServerCertificate to pyodbc connection string for sqlserver

* updated tests
  • Loading branch information
rishsriv authored Jun 25, 2024
1 parent 563aedc commit 5d40e9c
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 18 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ For a BigQuery installation, use
For a Databricks installation, use
`pip install --upgrade 'defog[databricks]'`

For a SQLServer installation, use
`pip install --upgrade 'defog[sqlserver]'`

## API Key
You can get your API key by going to [https://defog.ai/signup](https://defog.ai/signup) and creating an account. If you fail to verify your email, you can email us at support(at)defog.ai

Expand Down
12 changes: 11 additions & 1 deletion defog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"bigquery",
"snowflake",
"databricks",
"sqlserver",
]


Expand Down Expand Up @@ -163,7 +164,16 @@ def check_db_creds(db_type: str, db_creds: dict):
raise KeyError("db_creds must contain a 'access_token' key.")
if "http_path" not in db_creds:
raise KeyError("db_creds must contain a 'http_path' key.")
elif db_type == "mongo" or db_type == "sqlserver":
elif db_type == "sqlserver":
if "server" not in db_creds:
raise KeyError("db_creds must contain a 'server' key.")
if "database" not in db_creds:
raise KeyError("db_creds must contain a 'database' key.")
if "user" not in db_creds:
raise KeyError("db_creds must contain a 'user' key.")
if "password" not in db_creds:
raise KeyError("db_creds must contain a 'password' key.")
elif db_type == "mongo":
if "connection_string" not in db_creds:
raise KeyError("db_creds must contain a 'connection_string' key.")
elif db_type == "bigquery":
Expand Down
11 changes: 11 additions & 0 deletions defog/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,17 @@ def init():
db_creds = {
"json_key_path": json_key_path,
}
elif db_type == "sqlserver":
server = prompt("Please enter your database server host\n").strip()
database = prompt("Please enter your database name\n").strip()
user = prompt("Please enter your database user\n").strip()
password = pwinput.pwinput("Please enter your database password\n")
db_creds = {
"server": server,
"database": database,
"user": user,
"password": password,
}

# write to filepath and print confirmation
with open(filepath, "w") as f:
Expand Down
120 changes: 105 additions & 15 deletions defog/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ def generate_postgres_schema(
rows = cur.fetchall()
rows = [row for row in rows]
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
if schema == "public":
table_columns[table_name] = rows
else:
table_columns[schema + table_name] = rows
if len(rows) > 0:
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
if schema == "public":
table_columns[table_name] = rows
else:
table_columns[schema + table_name] = rows
conn.close()

print(
Expand Down Expand Up @@ -152,11 +153,12 @@ def generate_redshift_schema(
rows = cur.fetchall()
rows = [row for row in rows]
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if scan:
cur.execute(f"SET search_path TO {schema}")
rows = identify_categorical_columns(cur, table_name, rows)
cur.close()
schemas[table_name] = rows
if len(rows) > 0:
if scan:
cur.execute(f"SET search_path TO {schema}")
rows = identify_categorical_columns(cur, table_name, rows)
cur.close()
schemas[table_name] = rows

if upload:
print(
Expand Down Expand Up @@ -231,7 +233,8 @@ def generate_mysql_schema(
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
schemas[table_name] = rows
if len(rows) > 0:
schemas[table_name] = rows

conn.close()

Expand Down Expand Up @@ -303,7 +306,8 @@ def generate_databricks_schema(
]
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
schemas[table_name] = rows
if len(rows) > 0:
schemas[table_name] = rows

conn.close()

Expand Down Expand Up @@ -391,7 +395,8 @@ def generate_snowflake_schema(
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
cur.close()
schemas[table_name] = rows
if len(rows) > 0:
schemas[table_name] = rows

conn.close()

Expand Down Expand Up @@ -449,7 +454,8 @@ def generate_bigquery_schema(
table = client.get_table(table_name)
rows = table.schema
rows = [{"column_name": i.name, "data_type": i.field_type} for i in rows]
schemas[table_name] = rows
if len(rows) > 0:
schemas[table_name] = rows

client.close()

Expand Down Expand Up @@ -485,6 +491,83 @@ def generate_bigquery_schema(
return schemas


def generate_sqlserver_schema(
    self,
    tables: list,
    upload: bool = True,
    return_format: str = "csv",
    return_tables_only: bool = False,
) -> str:
    """Generate a schema for the SQL Server database in ``self.db_creds``.

    Connects via pyodbc (ODBC Driver 18), reads column metadata from
    INFORMATION_SCHEMA, and either returns the discovered table names,
    the raw schema dict, or uploads the schema to the Defog servers to
    generate column descriptions.

    Args:
        tables: Tables to describe. If empty, all tables in the configured
            schema (``db_creds["schema"]``, default ``"dbo"``) are discovered;
            non-dbo tables are qualified as ``"schema.table"``.
        upload: If True, POST the schema to Defog and return a CSV
            (file path or string, per ``return_format``); otherwise return
            the ``{table: [{"column_name", "data_type"}, ...]}`` dict.
        return_format: ``"csv"`` writes ``defog_metadata.csv`` and returns
            its path; any other value returns the CSV content as a string.
        return_tables_only: If True, return only the list of table names.

    Returns:
        A file path, CSV string, table-name list, or schema dict depending
        on the flags above (implicitly ``None`` if the upload fails).
    """
    try:
        import pyodbc
    except ImportError as exc:
        raise Exception("pyodbc not installed.") from exc

    connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
    conn = pyodbc.connect(connection_string)
    cur = conn.cursor()
    schemas = {}
    schema = self.db_creds.get("schema", "dbo")

    if len(tables) == 0:
        # get all tables in the configured schema
        # NOTE: pyodbc uses the qmark ("?") paramstyle; "%s" placeholders
        # are not supported and would raise at runtime.
        cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = ?;",
            (schema,),
        )
        if schema == "dbo":
            tables += [row[0] for row in cur.fetchall()]
        else:
            tables += [schema + "." + row[0] for row in cur.fetchall()]

    if return_tables_only:
        # close the connection before the early return to avoid leaking it
        cur.close()
        conn.close()
        return tables

    print("Getting schema for each table in your database...")
    # get the schema for each table
    for table_name in tables:
        if "." in table_name:
            # schema-qualified name ("schema.table"): information_schema
            # stores the bare table name, so split and filter by schema.
            # (Querying with the qualified name would never match.)
            table_schema, bare_table = table_name.split(".", 1)
            cur.execute(
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_name = ? AND table_schema = ?;",
                (bare_table, table_schema),
            )
        else:
            # parameterized to avoid SQL injection via table names
            cur.execute(
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_name = ?;",
                (table_name,),
            )
        rows = [{"column_name": r[0], "data_type": r[1]} for r in cur.fetchall()]
        # skip tables that yielded no columns (likely misspelled names)
        if len(rows) > 0:
            schemas[table_name] = rows

    cur.close()
    conn.close()
    if upload:
        print(
            "Sending the schema to Defog servers and generating column descriptions. This might take up to 2 minutes..."
        )
        r = requests.post(
            f"{self.base_url}/get_schema_csv",
            json={
                "api_key": self.api_key,
                "schemas": schemas,
                "foreign_keys": [],
                "indexes": [],
            },
        )
        resp = r.json()
        if "csv" in resp:
            csv = resp["csv"]
            if return_format == "csv":
                pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False)
                return "defog_metadata.csv"
            else:
                return csv
        else:
            print("We got an error!")
            if "message" in resp:
                print(f"Error message: {resp['message']}")
            print(
                f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email [email protected]."
            )
    else:
        return schemas


def generate_db_schema(
self,
tables: list,
Expand Down Expand Up @@ -541,6 +624,13 @@ def generate_db_schema(
upload=upload,
return_tables_only=return_tables_only,
)
elif self.db_type == "sqlserver":
return self.generate_sqlserver_schema(
tables,
return_format=return_format,
upload=upload,
return_tables_only=return_tables_only,
)
else:
raise ValueError(
f"Creation of a DB schema for {self.db_type} is not yet supported via the library. If you are a premium user, please contact us at [email protected] so we can manually add it."
Expand Down
15 changes: 15 additions & 0 deletions defog/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def execute_query_once(db_type: str, db_creds, query: str):
colnames = [desc[0] for desc in cursor.description]
results = cursor.fetchall()
return colnames, results
elif db_type == "sqlserver":
try:
import pyodbc
except:
raise Exception("pyodbc not installed.")

connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};DATABASE={db_creds['database']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
conn = pyodbc.connect(connection_string)
cur = conn.cursor()
cur.execute(query)
colnames = [desc[0] for desc in cur.description]
results = cur.fetchall()
cur.close()
conn.close()
return colnames, results
else:
raise Exception(f"Database type {db_type} not yet supported.")

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"bigquery": ["google-cloud-bigquery"],
"redshift": ["psycopg2-binary"],
"databricks": ["databricks-sql-connector"],
"sqlserver": ["pyodbc"],
}


Expand All @@ -25,7 +26,7 @@ def package_files(directory):
name="defog",
packages=find_packages(),
package_data={"defog": ["gcp/*", "aws/*"] + next_static_files},
version="0.64.2",
version="0.65.0",
description="Defog is a Python library that helps you generate data queries from natural language questions.",
author="Full Stack Data Pte. Ltd.",
license="MIT",
Expand Down
7 changes: 6 additions & 1 deletion tests/test_defog.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ def test_check_db_creds_mongo(self):
Defog.check_db_creds("mongo", {"account": "some_account"})

def test_check_db_creds_sqlserver(self):
db_creds = {"connection_string": "some_connection_string"}
db_creds = {
"server": "some_server",
"database": "some_database",
"user": "some_user",
"password": "some_password",
}
Defog.check_db_creds("sqlserver", db_creds)
Defog.check_db_creds("sqlserver", {})
with self.assertRaises(KeyError):
Expand Down

0 comments on commit 5d40e9c

Please sign in to comment.