diff --git a/README.md b/README.md
index efbf921..9b8f7bc 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,9 @@ For a BigQuery installation, use
 For a Databricks installation, use
 `pip install --upgrade 'defog[databricks]'`
 
+For a SQLServer installation, use
+`pip install --upgrade 'defog[sqlserver]'`
+
 ## API Key
 
 You can get your API key by going to [https://defog.ai/signup](https://defog.ai/signup) and creating an account. If you fail to verify your email, you can email us at support(at)defog.ai
diff --git a/defog/__init__.py b/defog/__init__.py
index 92da8ad..8f166ca 100644
--- a/defog/__init__.py
+++ b/defog/__init__.py
@@ -16,6 +16,7 @@
     "bigquery",
     "snowflake",
     "databricks",
+    "sqlserver",
 ]
 
 
@@ -163,7 +164,16 @@ def check_db_creds(db_type: str, db_creds: dict):
             raise KeyError("db_creds must contain a 'access_token' key.")
         if "http_path" not in db_creds:
             raise KeyError("db_creds must contain a 'http_path' key.")
-    elif db_type == "mongo" or db_type == "sqlserver":
+    elif db_type == "sqlserver":
+        if "server" not in db_creds:
+            raise KeyError("db_creds must contain a 'server' key.")
+        if "database" not in db_creds:
+            raise KeyError("db_creds must contain a 'database' key.")
+        if "user" not in db_creds:
+            raise KeyError("db_creds must contain a 'user' key.")
+        if "password" not in db_creds:
+            raise KeyError("db_creds must contain a 'password' key.")
+    elif db_type == "mongo":
         if "connection_string" not in db_creds:
             raise KeyError("db_creds must contain a 'connection_string' key.")
     elif db_type == "bigquery":
diff --git a/defog/cli.py b/defog/cli.py
index d3520a8..c2d2fcc 100644
--- a/defog/cli.py
+++ b/defog/cli.py
@@ -194,6 +194,17 @@ def init():
         db_creds = {
             "json_key_path": json_key_path,
         }
+    elif db_type == "sqlserver":
+        server = prompt("Please enter your database server host\n").strip()
+        database = prompt("Please enter your database name\n").strip()
+        user = prompt("Please enter your database user\n").strip()
+        password = pwinput.pwinput("Please enter your database password\n")
+        db_creds = {
+            "server": server,
+            "database": database,
+            "user": user,
+            "password": password,
+        }
 
     # write to filepath and print confirmation
     with open(filepath, "w") as f:
diff --git a/defog/generate_schema.py b/defog/generate_schema.py
index fcee97c..c7f9b5c 100644
--- a/defog/generate_schema.py
+++ b/defog/generate_schema.py
@@ -61,12 +61,13 @@ def generate_postgres_schema(
         rows = cur.fetchall()
         rows = [row for row in rows]
         rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
-        if scan:
-            rows = identify_categorical_columns(cur, table_name, rows)
-        if schema == "public":
-            table_columns[table_name] = rows
-        else:
-            table_columns[schema + table_name] = rows
+        if len(rows) > 0:
+            if scan:
+                rows = identify_categorical_columns(cur, table_name, rows)
+            if schema == "public":
+                table_columns[table_name] = rows
+            else:
+                table_columns[schema + table_name] = rows
 
     conn.close()
     print(
@@ -152,11 +153,12 @@ def generate_redshift_schema(
         rows = cur.fetchall()
         rows = [row for row in rows]
         rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
-        if scan:
-            cur.execute(f"SET search_path TO {schema}")
-            rows = identify_categorical_columns(cur, table_name, rows)
-        cur.close()
-        schemas[table_name] = rows
+        if len(rows) > 0:
+            if scan:
+                cur.execute(f"SET search_path TO {schema}")
+                rows = identify_categorical_columns(cur, table_name, rows)
+            cur.close()
+            schemas[table_name] = rows
 
     if upload:
         print(
@@ -231,7 +233,8 @@ def generate_mysql_schema(
         rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
         if scan:
             rows = identify_categorical_columns(cur, table_name, rows)
-        schemas[table_name] = rows
+        if len(rows) > 0:
+            schemas[table_name] = rows
 
     conn.close()
 
@@ -303,7 +306,8 @@ def generate_databricks_schema(
         ]
         if scan:
             rows = identify_categorical_columns(cur, table_name, rows)
-        schemas[table_name] = rows
+        if len(rows) > 0:
+            schemas[table_name] = rows
 
     conn.close()
 
@@ -391,7 +395,8 @@ def generate_snowflake_schema(
         if scan:
             rows = identify_categorical_columns(cur, table_name, rows)
         cur.close()
-        schemas[table_name] = rows
+        if len(rows) > 0:
+            schemas[table_name] = rows
 
     conn.close()
 
@@ -449,7 +454,8 @@ def generate_bigquery_schema(
         table = client.get_table(table_name)
         rows = table.schema
         rows = [{"column_name": i.name, "data_type": i.field_type} for i in rows]
-        schemas[table_name] = rows
+        if len(rows) > 0:
+            schemas[table_name] = rows
 
     client.close()
 
@@ -485,6 +491,83 @@
     return schemas
 
 
+def generate_sqlserver_schema(
+    self,
+    tables: list,
+    upload: bool = True,
+    return_format: str = "csv",
+    return_tables_only: bool = False,
+) -> str:
+    try:
+        import pyodbc
+    except:
+        raise Exception("pyodbc not installed.")
+
+    connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
+    conn = pyodbc.connect(connection_string)
+    cur = conn.cursor()
+    schemas = {}
+    schema = self.db_creds.get("schema", "dbo")
+
+    if len(tables) == 0:
+        # get all tables; pyodbc uses qmark-style placeholders, not %s
+        cur.execute(
+            "SELECT table_name FROM information_schema.tables WHERE table_schema = ?;",
+            (schema,),
+        )
+        if schema == "dbo":
+            tables += [row[0] for row in cur.fetchall()]
+        else:
+            tables += [schema + "." + row[0] for row in cur.fetchall()]
+
+    if return_tables_only:
+        return tables
+
+    print("Getting schema for each table in your database...")
+    # get the schema for each table
+    for table_name in tables:
+        cur.execute(
+            f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
+        )
+        rows = cur.fetchall()
+        rows = [row for row in rows]
+        rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
+        if len(rows) > 0:
+            schemas[table_name] = rows
+
+    conn.close()
+    if upload:
+        print(
+            "Sending the schema to Defog servers and generating column descriptions. This might take up to 2 minutes..."
+        )
+        r = requests.post(
+            f"{self.base_url}/get_schema_csv",
+            json={
+                "api_key": self.api_key,
+                "schemas": schemas,
+                "foreign_keys": [],
+                "indexes": [],
+            },
+        )
+        resp = r.json()
+        if "csv" in resp:
+            csv = resp["csv"]
+            if return_format == "csv":
+                pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False)
+                return "defog_metadata.csv"
+            else:
+                return csv
+        else:
+            print("We got an error!")
+            if "message" in resp:
+                print(f"Error message: {resp['message']}")
+            print(
+                "Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this is a generic library issue, or email support@defog.ai."
+            )
+    else:
+        return schemas
+
+
 def generate_db_schema(
     self,
     tables: list,
@@ -541,6 +624,13 @@ def generate_db_schema(
             upload=upload,
             return_tables_only=return_tables_only,
         )
+    elif self.db_type == "sqlserver":
+        return self.generate_sqlserver_schema(
+            tables,
+            return_format=return_format,
+            upload=upload,
+            return_tables_only=return_tables_only,
+        )
     else:
         raise ValueError(
             f"Creation of a DB schema for {self.db_type} is not yet supported via the library. If you are a premium user, please contact us at founder@defog.ai so we can manually add it."
diff --git a/defog/query.py b/defog/query.py
index b72ce82..722dfad 100644
--- a/defog/query.py
+++ b/defog/query.py
@@ -113,6 +113,21 @@
         colnames = [desc[0] for desc in cursor.description]
         results = cursor.fetchall()
         return colnames, results
+    elif db_type == "sqlserver":
+        try:
+            import pyodbc
+        except:
+            raise Exception("pyodbc not installed.")
+
+        connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};DATABASE={db_creds['database']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;"
+        conn = pyodbc.connect(connection_string)
+        cur = conn.cursor()
+        cur.execute(query)
+        colnames = [desc[0] for desc in cur.description]
+        results = cur.fetchall()
+        cur.close()
+        conn.close()
+        return colnames, results
     else:
         raise Exception(f"Database type {db_type} not yet supported.")
 
diff --git a/setup.py b/setup.py
index c84bfe6..432e4f2 100644
--- a/setup.py
+++ b/setup.py
@@ -8,6 +8,7 @@
     "bigquery": ["google-cloud-bigquery"],
     "redshift": ["psycopg2-binary"],
     "databricks": ["databricks-sql-connector"],
+    "sqlserver": ["pyodbc"],
 }
 
 
@@ -25,7 +26,7 @@ def package_files(directory):
     name="defog",
     packages=find_packages(),
     package_data={"defog": ["gcp/*", "aws/*"] + next_static_files},
-    version="0.64.2",
+    version="0.65.0",
     description="Defog is a Python library that helps you generate data queries from natural language questions.",
     author="Full Stack Data Pte. Ltd.",
     license="MIT",
diff --git a/tests/test_defog.py b/tests/test_defog.py
index 2e4dcf9..0ae6ee0 100644
--- a/tests/test_defog.py
+++ b/tests/test_defog.py
@@ -199,7 +199,12 @@ def test_check_db_creds_mongo(self):
         Defog.check_db_creds("mongo", {"account": "some_account"})
 
     def test_check_db_creds_sqlserver(self):
-        db_creds = {"connection_string": "some_connection_string"}
+        db_creds = {
+            "server": "some_server",
+            "database": "some_database",
+            "user": "some_user",
+            "password": "some_password",
+        }
         Defog.check_db_creds("sqlserver", db_creds)
         Defog.check_db_creds("sqlserver", {})
         with self.assertRaises(KeyError):
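
As a quick sanity check of the new code paths, here is a minimal usage sketch. It relies only on names shown in this diff (`Defog.check_db_creds` and `execute_query_once`); the server, database, user, and password values and the query are placeholders, and the host must have "ODBC Driver 18 for SQL Server" installed, since the connection string hard-codes that driver.

```python
# Hypothetical smoke test for the new sqlserver support; all values are placeholders.
from defog import Defog
from defog.query import execute_query_once

db_creds = {
    "server": "localhost",               # placeholder host
    "database": "master",                # placeholder database
    "user": "sa",                        # placeholder user
    "password": "YourStrong@Passw0rd",   # placeholder password
}

# Raises KeyError if any of the four required keys is missing (see check_db_creds above).
Defog.check_db_creds("sqlserver", db_creds)

# Routes to the new pyodbc branch in execute_query_once; requires
# "ODBC Driver 18 for SQL Server" to be installed on the host.
colnames, results = execute_query_once(
    db_type="sqlserver",
    db_creds=db_creds,
    query="SELECT TOP 5 name FROM sys.tables;",
)
print(colnames)
print(results)
```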