# main.py (17.34 KiB), authored by Timon Römer.
# NOTE: the lines originally at the top of this file ("Select Git revision",
# "MainWindow.xaml.cs", "-") were Git web-UI residue rather than Python.
from my_flask_app import app
from .models.models import CustomTable, CustomColumn, Theme
from flask_sqlalchemy import SQLAlchemy
from flask import jsonify, redirect, render_template, request, url_for, session
from sqlalchemy import Inspector, MetaData, create_engine, text, inspect
import pydot, base64, os
from sqlalchemy.orm import scoped_session, sessionmaker
# Set up database. NOTE: `db` is never bound to `app` here (init_app is
# commented out), so `db.engine` is not configured by this module; index()
# instead replaces `db.session` with a scoped_session bound to the engine
# built from the user-supplied database URI.
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql://postgres:password@localhost:5432/test"
db = SQLAlchemy()
# db.init_app(app)
# Needed for session management. Read from the environment so a real secret can
# be supplied in deployment; the hard-coded value is only a development fallback.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'my_secret_key')
# Names of the tables the user has dropped onto the ERD canvas. This is
# module-level state: it is shared by every client of this process, not stored
# per browser session.
dropped_items = []
@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the main page; on POST, connect to the submitted database and list its tables.

    POST form fields:
        database_uri: SQLAlchemy URL of the database to inspect (required).
        schema: optional schema name used to restrict the table list.
        show_all: the string 'True' to also include tables referenced via
            foreign keys.

    The URI is stored in the Flask session under ``db_uri`` so later AJAX
    requests can reconnect. ``db.session`` is rebound to the new engine because
    helper queries (getTableSchema, query_database_for_table_content) execute
    through ``db.session``.
    """
    if request.method != 'POST':
        # Plain GET: show the empty connection form.
        return render_template('app.html')

    database_uri = request.form['database_uri']
    session['db_uri'] = database_uri

    engine = create_engine(database_uri)
    db.session = scoped_session(sessionmaker(bind=engine))
    insp = inspect(engine)

    database = database_name_from_uri(engine, database_uri)
    schemas = getSchema(insp)
    schema_selected = request.form.get('schema', None)
    show_all = request.form.get('show_all') == 'True'
    # No explicit table selection: importMetadata lists every table
    # (restricted to `schema_selected` when one was chosen).
    tables = importMetadata(engine, schema_selected, None, show_all)

    return render_template(
        'app.html',
        database=database,
        schemas=schemas,
        show_all=show_all,
        schema_Selected=schema_selected,
        tables=tables,
        dropped_items=dropped_items,
    )
@app.route('/handle-drop', methods=['POST'])
def handle_drop():
    """Add/remove a table from the ERD selection and return the regenerated ERD.

    Expects JSON ``{"item": <table name>, "action": "added" | "removed"}``.
    Responds with ``{"image2": <base64-encoded PNG>}``. ``image2`` is an empty
    string when nothing is selected or when no database URI has been stored in
    the session yet.
    """
    data = request.get_json(silent=True) or {}
    item_name = data.get('item')
    action = data.get('action')

    if action == 'added':
        dropped_items.append(item_name)
    elif action == 'removed' and item_name in dropped_items:
        dropped_items.remove(item_name)

    # Reconnect with the URI index() stored in the Flask session: `db` is never
    # bound to the app, and an engine URL's `.database` is only a bare name,
    # which create_engine() cannot parse.
    database_uri = session.get('db_uri')
    # With an empty selection importMetadata would return *every* table, so
    # return an empty image instead of an ERD of the whole database.
    if not database_uri or not dropped_items:
        return jsonify(image2="")

    engine = create_engine(database_uri)
    try:
        tables2 = importMetadata(engine, None, dropped_items, False)
    finally:
        # Reflection results are plain objects; release pooled connections.
        engine.dispose()

    graph_dot = createGraph(tables2, getThemes()["Blue Navy"], True, True, True)
    return jsonify(image2=generate_erd(graph_dot))
@app.route('/get-table-data', methods=['POST'])
def get_table_data():
    """Return the first rows of a table as an HTML table fragment.

    Expects JSON ``{"table_name": <name>}`` and responds with
    ``{"html_table": <html>}``. Responds with HTTP 400 and an ``error`` message
    when ``table_name`` is missing or the body is not JSON.
    """
    data = request.get_json(silent=True) or {}
    table_name = data.get('table_name')
    if not table_name:
        return jsonify({'error': 'table_name is required'}), 400

    content = query_database_for_table_content(table_name)
    return jsonify({'html_table': generate_html_table(content)})
def database_name_from_uri(engine, database_uri: str):
    """Derive a human-readable database name for display.

    - PostgreSQL: the database component of the engine URL.
    - SQLite: the file name (without extension) of the path after ``///``.
    - Anything else: the string ``'Unknown'``.
    """
    dialect = engine.dialect.name
    if dialect == 'sqlite':
        file_part = database_uri.split('///')[-1]
        stem, _extension = os.path.splitext(os.path.basename(file_part))
        return stem
    if dialect == 'postgresql':
        return engine.url.database
    return 'Unknown'
def generate_html_table(content):
    """Render query results as a UIkit-styled HTML table.

    Args:
        content: list whose first element is the sequence of column names and
            whose remaining elements are row value sequences (the shape
            returned by query_database_for_table_content).

    Returns:
        The HTML string, or ``"No data found."`` when ``content`` is empty/None.

    Every header and cell is HTML-escaped: the values come straight from the
    database, and the fragment is presumably inserted into the page as HTML,
    so unescaped values would allow markup/script injection.
    """
    import html

    if not content:
        return "No data found."

    header, *rows = content
    head_cells = "".join(f"<th>{html.escape(str(col))}</th>" for col in header)
    body_rows = "".join(
        "<tr>" + "".join(f"<td>{html.escape(str(item))}</td>" for item in row) + "</tr>"
        for row in rows
    )
    return (
        "<table class='uk-table uk-table-small uk-table-hover uk-table-divider'>"
        f"<thead><tr>{head_cells}</tr></thead>"
        f"<tbody>{body_rows}</tbody></table>"
    )
def query_database_for_table_content(table_name, number=20):
    """Fetch the first ``number`` rows of ``table_name`` through ``db.session``.

    Returns:
        ``[]`` when the table has no rows; otherwise a list whose first element
        is the list of column names, followed by one list of values per row.

    ``table_name`` arrives from the client, so it is never interpolated into
    SQL text. The statement is built with SQLAlchemy's ``table()``/``select()``
    constructs, which quote identifiers for the active dialect, and the limit
    is sent as a bound parameter. Column names are taken from the result
    itself, so they always line up with the row values; an information_schema
    lookup by table name alone could mix in columns from a same-named table in
    another schema and has no guaranteed order.
    """
    from sqlalchemy import literal_column, select, table

    # information_schema (used by getTableSchema) is a PostgreSQL concept here;
    # other dialects such as SQLite are queried in their default schema.
    dialect_name = db.session.get_bind().dialect.name
    schema = getTableSchema(table_name) if dialect_name == 'postgresql' else None

    stmt = (
        select(literal_column('*'))
        .select_from(table(table_name, schema=schema))
        .limit(number)
    )
    result = db.session.execute(stmt)
    columns = list(result.keys())
    rows = result.fetchall()
    if not rows:
        return []
    return [columns] + [list(row) for row in rows]
# Only postgresql needs this function (database_type = 'postgresql')
def getTableSchema(table_name):
    """Look up which schema contains ``table_name`` via information_schema.

    Intended for PostgreSQL only; executed through ``db.session``.

    Returns:
        The schema name, or ``None`` when no table with that name is visible.
        If the same table name exists in several schemas, the first row the
        server returns is used (the query does not filter by schema).
    """
    sql = text("""
        SELECT table_schema
        FROM information_schema.tables
        WHERE table_name = :table_name;
    """)
    row = db.session.execute(sql, {'table_name': table_name}).fetchone()
    # fetchone() yields None when nothing matched; don't subscript it blindly.
    return row[0] if row is not None else None
def getSchema(insp):
    """Return the schema names reported by the SQLAlchemy inspector ``insp``."""
    return insp.get_schema_names()
def importMetadata(engine, schema=None, tables_selected=None, show_all=False):
    """Reflect tables, columns and constraints into CustomTable objects.

    Args:
        engine: SQLAlchemy Engine or database URL string; ``None`` yields ``{}``.
        schema: restrict the initial table list to this schema (only honoured
            for PostgreSQL by fetch_initial_tables).
        tables_selected: iterable of table names to keep; empty/None keeps all.
        show_all: if True, also add tables referenced by foreign keys of the
            initially selected tables.

    Returns:
        dict mapping table name -> CustomTable.
    """
    if engine is None:
        return {}

    # Build the engine once here instead of letting each helper call
    # create_engine() on the same URL; dispose of it afterwards since we own it.
    owns_engine = isinstance(engine, str)
    if owns_engine:
        engine = create_engine(engine)

    try:
        # A list keeps compatibility with the helpers' membership filtering.
        tables_selected_list = list(tables_selected) if tables_selected else None
        tables = fetch_initial_tables(engine, schema, tables_selected_list)
        if show_all:
            expand_to_include_related_tables(engine, tables)
        fetch_columns_for_tables(engine, tables)
        fetch_constraints_for_tables(engine, tables)
        return tables
    finally:
        if owns_engine:
            engine.dispose()
def fetch_initial_tables(engine, schema=None, tables_selected_list=None):
    """Create a CustomTable for each table to display, keyed by table name.

    Args:
        engine: SQLAlchemy Engine or database URL string.
        schema: PostgreSQL only — list tables of this schema instead of all
            schemas. Ignored for other dialects.
        tables_selected_list: if truthy, keep only tables whose name is in it.

    Returns:
        dict of table name -> CustomTable, each with ``label`` set to
        ``n1``, ``n2``, ... in insertion order.
    """
    if isinstance(engine, str):
        engine = create_engine(engine)
    insp = inspect(engine)
    database_type = engine.dialect.name

    # (schema, table name) pairs. The schema is recorded while enumerating so it
    # is exactly the schema the table was found in. Looking it up afterwards by
    # name alone (getTableSchema) ignores an explicitly requested schema, picks
    # an arbitrary one when a name exists in several schemas, and costs one
    # query per table.
    if database_type == 'postgresql':
        schema_names = [schema] if schema is not None else insp.get_schema_names()
        candidates = [
            (schema_name, table_name)
            for schema_name in schema_names
            for table_name in insp.get_table_names(schema=schema_name)
        ]
    else:  # e.g. SQLite: everything lives in the inspector's default schema
        default_schema = insp.default_schema_name
        candidates = [(default_schema, name) for name in insp.get_table_names()]

    wanted = set(tables_selected_list) if tables_selected_list else None
    tables = {}
    for table_schema, table_name in candidates:
        if wanted is not None and table_name not in wanted:
            continue
        table = CustomTable(table_name, table_schema)
        tables[table_name] = table
        table.label = f"n{len(tables)}"
    return tables
def expand_to_include_related_tables(engine, tables):
    """Add the tables referenced by foreign keys of ``tables`` (one level deep).

    ``tables`` (name -> CustomTable) is modified in place and also returned.
    Each newly added table is created with the ``referred_schema`` reported by
    FK reflection, and gets a ``label`` numbered the same way
    fetch_initial_tables numbers its tables.
    """
    if isinstance(engine, str):
        engine = create_engine(engine)
    insp = inspect(engine)

    # Collect first, insert afterwards: a dict must not grow while it is
    # being iterated.
    related = {}
    for table_name, table in tables.items():
        for fk in insp.get_foreign_keys(table_name, schema=table.schema):
            referred = fk['referred_table']
            if referred not in tables and referred not in related:
                related[referred] = fk['referred_schema']

    for table_name, table_schema in related.items():
        table = CustomTable(table_name, table_schema)
        tables[table_name] = table
        # fetch_initial_tables labels every table it creates; do the same so
        # tables added here are not left without a label.
        table.label = f"n{len(tables)}"
    return tables
def fetch_columns_for_tables(engine, tables):
    """Attach reflected columns to every CustomTable in ``tables``.

    For each reflected column, a CustomColumn is created and given a datatype
    dict with the keys "type" (string form of the reflected type), "nullable"
    and "default", then appended to ``table.columns``. ``tables`` is modified
    in place and returned.
    """
    inspector = inspect(create_engine(engine) if isinstance(engine, str) else engine)
    for table_name, custom_table in tables.items():
        for col_info in inspector.get_columns(table_name, schema=custom_table.schema):
            name, col_type, is_nullable, default_value = (
                col_info['name'],
                col_info['type'],
                col_info['nullable'],
                col_info['default'],
            )
            new_column = CustomColumn(custom_table, name, '')
            new_column.setDataType({
                "type": str(col_type),
                "nullable": is_nullable,
                "default": default_value,
            })
            custom_table.columns.append(new_column)
    return tables
def fetch_constraints_for_tables(engine, tables):
    """Annotate the columns of ``tables`` with unique, primary-key and foreign-key info.

    Runs three passes over all tables, in this order:
      1. unique constraints -> ``column.isunique = True`` and the column is
         appended to ``table.uniques[<constraint name>]``;
      2. primary key -> ``column.ispk = True`` and ``column.pkconstraint``;
      3. foreign keys -> ``column.fkof = "<referred_table>.<referred_column>"``
         and the column is appended to ``table.fks[<constraint name>]``.
    Constraint columns that ``table.getColumn`` does not find are skipped.
    ``tables`` is modified in place and returned.

    NOTE(review): reflected constraint names may be None for unnamed
    constraints (e.g. on SQLite); all such constraints of a table then share
    the ``None`` key in ``uniques``/``fks`` — confirm that is acceptable.
    """
    if isinstance(engine, str):
        engine = create_engine(engine)
    insp = inspect(engine)
    _apply_unique_constraints(insp, tables)
    _apply_primary_keys(insp, tables)
    _apply_foreign_keys(insp, tables)
    return tables


def _apply_unique_constraints(insp, tables):
    """Flag columns covered by a unique constraint and group them by constraint name."""
    for table_name, table in tables.items():
        for constraint in insp.get_unique_constraints(table_name, schema=table.schema):
            for col_name in constraint['column_names']:
                col = table.getColumn(col_name)
                if not col:
                    continue
                col.isunique = True
                table.uniques.setdefault(constraint['name'], []).append(col)


def _apply_primary_keys(insp, tables):
    """Flag primary-key columns and record the PK constraint name on each."""
    for table_name, table in tables.items():
        pk = insp.get_pk_constraint(table_name, schema=table.schema)
        for col_name in pk['constrained_columns']:
            col = table.getColumn(col_name)
            if not col:
                continue
            col.ispk = True
            col.pkconstraint = pk['name']


def _apply_foreign_keys(insp, tables):
    """Record each FK column's "table.column" target and group columns by FK name."""
    for table_name, table in tables.items():
        for fk in insp.get_foreign_keys(table_name, schema=table.schema):
            local_names = fk['constrained_columns']
            target_table = fk['referred_table']
            remote_names = fk['referred_columns']
            for local_name, remote_name in zip(local_names, remote_names):
                col = table.getColumn(local_name)
                if not col:
                    continue
                col.fkof = f"{target_table}.{remote_name}"
                table.fks.setdefault(fk['name'], []).append(col)
# def fetch_constraints_for_tables(engine, tables):
# # Fetching Unique Constraints
# for tableName, table in tables.items():
# sql_unique = text("""
# SELECT kcu.column_name, tc.constraint_name
# FROM information_schema.table_constraints AS tc
# JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
# WHERE tc.table_name = :table_name AND tc.constraint_type = 'UNIQUE';
# """)
# unique_result = db.session.execute(sql_unique, {'table_name': tableName})
# for col in unique_result:
# name, constraintName = col
# column = table.getColumn(name)
# if column:
# column.isunique = True
# if constraintName not in table.uniques:
# table.uniques[constraintName] = []
# table.uniques[constraintName].append(column)
# # Primary Keys
# for tableName, table in tables.items():
# sql_pk = text("""
# SELECT kcu.column_name, tc.constraint_name, kcu.ordinal_position
# FROM information_schema.table_constraints AS tc
# JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
# WHERE tc.table_name = :table_name AND tc.constraint_type = 'PRIMARY KEY';
# """)
# pk_result = db.session.execute(sql_pk, {'table_name': tableName})
# for col in pk_result:
# name, constraintName, ordinal_position = col
# column = table.getColumn(name)
# if column:
# column.ispk = True
# column.pkconstraint = constraintName
# # Assuming you want to order PKs, though not directly used in provided class
# # Fetching Foreign Keys for each table
# for tableName, table in tables.items():
# sql_fk = text("""
# SELECT
# tc.constraint_name,
# tc.table_name AS fk_table_name,
# kcu.column_name AS fk_column_name,
# ccu.table_name AS pk_table_name,
# ccu.column_name AS pk_column_name,
# ccu.table_schema AS pk_table_schema
# FROM
# information_schema.table_constraints AS tc
# JOIN information_schema.key_column_usage AS kcu
# ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema AND tc.table_name = kcu.table_name
# JOIN information_schema.constraint_column_usage AS ccu
# ON ccu.constraint_name = tc.constraint_name
# WHERE
# tc.constraint_type = 'FOREIGN KEY'
# AND tc.table_name = :table_name
# """)
# fk_result = db.session.execute(sql_fk, {'table_name': tableName})
# for row in fk_result:
# constraintName, fkTableName, fkColumnName, pkTableName, pkColumnName, pkTableSchema = row
# # Ensure the foreign key table is the current table being processed
# if fkTableName != tableName:
# continue
# fkTable = tables.get(fkTableName)
# pkTable = tables.get(pkTableName)
# if fkTable and pkTable:
# fkColumn = fkTable.getColumn(fkColumnName)
# pkColumn = pkTable.getColumn(pkColumnName)
# if fkColumn and pkColumn:
# # Here, instead of assigning pkColumn directly, store relevant info
# fkColumn.fkof = pkColumn # Adjust based on your application's needs
# if constraintName not in fkTable.fks:
# fkTable.fks[constraintName] = []
# fkTable.fks[constraintName].append(fkColumn)
# return tables
def createGraph(tables, theme, showColumns, showTypes, useUpperCase):
    """Build Graphviz DOT source describing ``tables``.

    Args:
        tables: mapping of name -> CustomTable.
        theme: Theme; its ``shape`` becomes the default node shape.
        showColumns, showTypes, useUpperCase: passed through to each table's
            ``getDotShape``.

    Returns:
        A left-to-right digraph: fixed graph/node/edge defaults, every table's
        node definition, a blank line, then every table's edges.
    """
    header = (
        'digraph {\n'
        ' graph [ rankdir="LR" bgcolor="#ffffff" ]\n'
        f' node [ style="filled" shape="{theme.shape}" gradientangle="180" ]\n'
        ' edge [ arrowhead="none" arrowtail="none" dir="both" ]\n\n'
    )
    node_defs = [
        tables[name].getDotShape(theme, showColumns, showTypes, useUpperCase)
        for name in tables
    ]
    link_defs = [tables[name].getDotLinks(theme) for name in tables]
    return header + "".join(node_defs) + "\n" + "".join(link_defs) + "}\n"
def generate_erd(graph_DOT):
    """Render DOT source to a PNG and return it base64-encoded as a str.

    Raises:
        ValueError: if pydot cannot parse ``graph_DOT`` into at least one graph.
    """
    graphs = pydot.graph_from_dot_data(graph_DOT)
    # Depending on the pydot version, a parse failure may be reported by
    # returning None/an empty list instead of raising; fail with a clear
    # message rather than a TypeError/IndexError on `graphs[0]`.
    if not graphs:
        raise ValueError("Could not parse DOT data into a graph")
    png_bytes = graphs[0].create_png()
    return base64.b64encode(png_bytes).decode('utf-8')
def getThemes():
    """Return the available ERD colour themes keyed by display name.

    A fresh dict of Theme objects is built on every call, in the order
    "Common Gray", "Blue Navy", "Common Gray Box".
    """
    # The two gray themes share everything except the node shape.
    gray_args = ("#6c6c6c", "#e0e0e0", "#f5f5f5",
                 "#e0e0e0", "#000000", "#000000", "rounded")
    themes = {}
    themes["Common Gray"] = Theme(*gray_args, "Mrecord", "#696969", "1")
    themes["Blue Navy"] = Theme("#1a5282", "#1a5282", "#ffffff",
                                "#1a5282", "#000000", "#ffffff", "rounded", "Mrecord", "#0078d7", "2")
    # Disabled themes, kept for reference:
    # "Gradient Green": Theme("#716f64", "#008080:#ffffff", "#008080:#ffffff",
    #     "transparent", "#000000", "#000000", "rounded", "Mrecord", "#696969", "1"),
    # "Blue Sky": Theme("#716f64", "#d3dcef:#ffffff", "#d3dcef:#ffffff",
    #     "transparent", "#000000", "#000000", "rounded", "Mrecord", "#696969", "1"),
    themes["Common Gray Box"] = Theme(*gray_args, "record", "#696969", "1")
    return themes
if __name__ == "__main__":
    # debug=True enables Flask's reloader and interactive debugger, which must
    # never be reachable publicly. Keep it on by default for local development,
    # but allow switching it off with FLASK_DEBUG=0/false/no.
    debug_enabled = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in ("0", "false", "no")
    app.run(debug=debug_enabled)