Skip to content

Commit

Permalink
Merge pull request #110 from lanl/pulido/issue108
Browse files Browse the repository at this point in the history
Updated file_writer.py with Dot class
  • Loading branch information
jpulidojr authored Aug 29, 2024
2 parents ab6f394 + 4e779a8 commit f201d21
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
103 changes: 103 additions & 0 deletions dsi/plugins/file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from hashlib import sha1
import json, csv
from math import isnan
import sqlite3
import subprocess

from dsi.plugins.metadata import StructuredMetadata

Expand All @@ -26,6 +28,107 @@ def __init__(self, filenames, **kwargs):
sha = sha1(open(filename, 'rb').read())
self.file_info[abspath(filename)] = sha.hexdigest()

class ER_Diagram(FileWriter):
    """
    File writer that renders the schema of a SQLite database as an
    entity-relationship diagram using the Graphviz `dot` executable.
    """

    # Output suffixes recognised at the end of `fname`; anything else means
    # "no suffix given" and the image defaults to PNG.
    _IMAGE_SUFFIXES = (".png", ".pdf", ".jpg", ".jpeg")

    def __init__(self, filenames, **kwargs):
        super().__init__(filenames, **kwargs)

    def export_erd(self, dbname, fname):
        """
        Function that writes a dot file describing the given database and
        renders it to an image with Graphviz.

        `dbname`: database to create an ER diagram for

        `fname`: name (including path) of the image file that will contain
        the generated ER diagram. A trailing .png, .pdf, .jpg or .jpeg picks
        the output format; with no recognised suffix a .png is produced.
        The intermediate dot file is written next to it as <fname>.dot.

        `return`: none. If the schema cannot be read, the SQLite error is
        written into the dot file and no image is rendered.
        """
        file_type = ".png"
        for suffix in self._IMAGE_SUFFIXES:
            if fname.endswith(suffix):
                file_type = suffix
                fname = fname[:-len(suffix)]
                break

        # Number of column cells placed on each row of a table node.
        num_cols_erd = 1

        db = sqlite3.connect(dbname)
        try:
            # The `with` block guarantees the dot file is flushed and closed
            # before Graphviz reads it, and on every error path as well.
            with open(fname + ".dot", "w") as dot_file:
                schema_written = self._write_dot(db, dot_file, num_cols_erd)
        finally:
            db.close()

        if not schema_written:
            # The dot file only holds a partial graph plus an error message;
            # rendering it would fail, so stop here.
            return

        # NOTE: the exit status of `dot` is not checked, so a failed render
        # does not raise; a missing `dot` executable raises FileNotFoundError.
        subprocess.run(["dot", "-T", file_type[1:], "-o", fname + file_type, fname + ".dot"])

    def _write_dot(self, db, dot_file, num_cols):
        """
        Write the whole digraph for `db` into `dot_file`.

        `return`: True when the graph was written completely, False when a
        SQLite error interrupted it (the error text is written to the file).
        """
        dot_file.write("digraph sqliteschema { ")
        dot_file.write("node [shape=plaintext]; ")
        dot_file.write("rankdir=LR ")
        dot_file.write("splines=true ")
        dot_file.write("overlap=false ")

        try:
            tbl_names = [
                row[0]
                for row in db.execute("SELECT tbl_name FROM sqlite_master WHERE type='table'")
            ]
        except sqlite3.Error as er:
            self._write_error(dot_file, er, "Can't prepare table list statement")
            return False

        # First pass: one HTML-like table node per database table.
        for tbl_name in tbl_names:
            try:
                columns = [
                    info_row[1]
                    for info_row in db.execute(f"PRAGMA table_info({self._sql_ident(tbl_name)})")
                ]
            except sqlite3.Error as er:
                self._write_error(dot_file, er, f"Can't prepare table info statement on table {tbl_name}")
                return False
            dot_file.write(self._table_node(tbl_name, columns, num_cols))

        # Second pass: one edge per foreign key column.
        for tbl_name in tbl_names:
            try:
                fkeys = db.execute(f"PRAGMA foreign_key_list({self._sql_ident(tbl_name)})").fetchall()
            except sqlite3.Error as er:
                self._write_error(dot_file, er, f"Can't prepare foreign key statement on table {tbl_name}")
                return False

            for fkey_row in fkeys:
                # Index 2 is the referenced table, 3 the local column and
                # 4 the referenced column.
                source = f"{self._dot_id(tbl_name)}:{self._dot_id(fkey_row[3])}"
                target = self._dot_id(fkey_row[2])
                # The referenced column is NULL when the key was declared
                # without an explicit parent column; point at the table
                # node itself instead of a port literally named "None".
                if fkey_row[4] is not None:
                    target += f":{self._dot_id(fkey_row[4])}"
                dot_file.write(f"{source} -> {target}; ")

        dot_file.write("}")
        return True

    def _table_node(self, tbl_name, columns, num_cols):
        """Return the dot statement for one table node listing `columns`."""
        from html import escape

        parts = [
            f"{self._dot_id(tbl_name)} [label=<<TABLE CELLSPACING=\"0\">"
            f"<TR><TD COLSPAN=\"{num_cols}\"><B>{escape(tbl_name)}</B></TD></TR>"
        ]
        # Lay the columns out `num_cols` cells per row.
        for start in range(0, len(columns), num_cols):
            parts.append("<TR>")
            for column in columns[start:start + num_cols]:
                parts.append(f"<TD PORT=\"{escape(column, quote=True)}\">{escape(column)}</TD>")
            parts.append("</TR>")
        parts.append("</TABLE>>]; ")
        return "".join(parts)

    @staticmethod
    def _write_error(dot_file, er, message):
        """Record a SQLite failure in the dot file."""
        # `sqlite_errorname` only exists on Python 3.11+; fall back to the
        # exception text on older interpreters.
        dot_file.write(getattr(er, "sqlite_errorname", None) or str(er))
        dot_file.write(message)

    @staticmethod
    def _sql_ident(name):
        """Quote `name` as a SQL identifier so any table name is accepted."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _dot_id(name):
        """Quote `name` as a dot ID so spaces and punctuation are accepted."""
        return '"' + str(name).replace('"', '\\"') + '"'

class Csv(FileWriter):
"""
Expand Down
23 changes: 23 additions & 0 deletions dsi/tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dsi.plugins import file_writer as fw
import cv2
import sqlite3
import numpy as np

def test_export_db_erd():
    """
    Build a small one-table database, export its ER diagram and check that a
    non-blank image was produced.

    Everything is created inside a temporary directory so the test leaves no
    test.db / test1.dot / test1.png behind and always starts from an empty
    database.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as work_dir:
        db_path = os.path.join(work_dir, "test.db")
        image_stem = os.path.join(work_dir, "test1")

        connection = sqlite3.connect(db_path)
        try:
            cursor = connection.cursor()
            cursor.execute("CREATE TABLE IF NOT EXISTS example (id INTEGER, name TEXT, age INTEGER)")
            cursor.execute("INSERT INTO example VALUES (1, 'alice', 20)")
            cursor.execute("INSERT INTO example VALUES (2, 'bob', 30)")
            cursor.execute("INSERT INTO example VALUES (3, 'eve', 40)")
            connection.commit()
        finally:
            # Close even when an insert fails so the directory can be removed.
            connection.close()

        erd = fw.ER_Diagram(db_path)
        erd.export_erd(db_path, image_stem)

        # No suffix was given, so the exporter falls back to a .png image.
        # The pixels are read into memory before the directory is deleted.
        er_image = cv2.imread(image_stem + ".png")

    assert er_image is not None #check if image generated at all
    assert np.mean(er_image) != 255 #check if image is all white pixels (empty diagram)

0 comments on commit f201d21

Please sign in to comment.