Skip to content
Snippets Groups Projects
Select Git revision
  • c98dcd4a0bab98779128b33e1c9b39d6ee7ef667
  • master default protected
2 results

MainWindow.xaml.cs

Blame
  • main.py 17.34 KiB
    import base64
    import html
    import os

    import pydot
    from flask import jsonify, redirect, render_template, request, url_for, session
    from flask_sqlalchemy import SQLAlchemy
    from sqlalchemy import Inspector, MetaData, create_engine, text, inspect
    from sqlalchemy.orm import scoped_session, sessionmaker

    from my_flask_app import app
    from .models.models import CustomTable, CustomColumn, Theme
    
    
    # --- Application-level setup ---
    # Disable SQLAlchemy's event/change-tracking system (not used here).
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql://postgres:password@localhost:5432/test"

    # Created WITHOUT init_app(app): db.session is rebound per request in
    # index() from the URI the user submits, so no app binding happens here.
    db = SQLAlchemy()
    # db.init_app(app)

    app.secret_key = 'my_secret_key'  # Needed for session management
    # Module-level (shared across all clients) list of table names dropped
    # onto the ERD canvas; mutated by the /handle-drop endpoint.
    dropped_items = []
    
    @app.route('/', methods=['POST', 'GET'])
    def index():
        """Landing page: accept a database URI, introspect it, render the UI.

        GET  -> render the empty connection form.
        POST -> connect to the submitted URI, rebind ``db.session`` to that
                engine, and render the schema/table browser.
        """
        if request.method != 'POST':
            # Display the bare connection form.
            return render_template('app.html')

        # Handle the form submission: connect to the user-supplied database.
        database_uri = request.form['database_uri']
        session['db_uri'] = database_uri  # remembered for later AJAX endpoints
        engine = create_engine(database_uri)

        # Rebind the global SQLAlchemy session to the freshly created engine
        # so helpers such as getTableSchema() query the right database.
        session_factory = sessionmaker(bind=engine)
        db.session = scoped_session(session_factory)
        insp = inspect(engine)

        # Determine the name of the database (postgresql, sqlite, ...).
        database = database_name_from_uri(engine, database_uri)

        tables_selected = []  # no explicit table filter on the initial load
        schemas = getSchema(insp)

        schema_Selected = request.form.get('schema', None)
        show_all = request.form.get('show_all') == 'True'

        tables1 = importMetadata(engine, schema_Selected, tables_selected, show_all)

        return render_template('app.html', database=database, schemas=schemas,
                               show_all=show_all, schema_Selected=schema_Selected,
                               tables=tables1, dropped_items=dropped_items)
        
    
    @app.route('/handle-drop', methods=['POST'])
    def handle_drop():
        """Track tables dragged on/off the ERD canvas and return a fresh ERD.

        Expects JSON ``{"item": <table name>, "action": "added"|"removed"}``
        and responds with ``{"image2": <base64 PNG>}`` for the updated diagram.
        """
        data = request.json
        item_name = data.get('item')
        action = data.get('action')

        if action == 'added':
            dropped_items.append(item_name)
        elif action == 'removed' and item_name in dropped_items:
            dropped_items.remove(item_name)

        # Regenerate the ERD from the updated selection.  BUG FIX: the old
        # code passed the database *name* (db.engine.url.database) to
        # importMetadata(), which expects an engine or URI string — and
        # db.engine is unbound because init_app() is never called.  Rebuild
        # the engine from the URI saved in session by index() instead.
        engine = create_engine(session['db_uri'])
        themes = getThemes()
        tables2 = importMetadata(engine, None, dropped_items, False)
        graph_DOT2 = createGraph(tables2, themes["Blue Navy"], True, True, True)
        image2 = generate_erd(graph_DOT2)

        return jsonify(image2=image2)
    
    
    @app.route('/get-table-data', methods=['POST'])
    def get_table_data():
        """Return the first rows of the requested table as an HTML fragment."""
        payload = request.json
        table_name = payload.get('table_name')
        print(table_name)

        # Fetch the rows for this table, then render them as an HTML snippet.
        rows = query_database_for_table_content(table_name)
        html_table = generate_html_table(rows)

        return jsonify({'html_table': html_table})
    
    
    def database_name_from_uri(engine, database_uri: str):
        """Derive a human-readable database name from an engine and its URI."""
        dialect = engine.dialect.name
        if dialect == 'postgresql':
            return engine.url.database
        if dialect == 'sqlite':
            # The sqlite URI tail is a file path; strip directory + extension.
            filename = os.path.basename(database_uri.split('///')[-1])
            return os.path.splitext(filename)[0]
        return 'Unknown'
        
    def generate_html_table(content):
        """Render query content as an HTML ``<table>`` fragment.

        ``content`` is a list whose first row holds the column names and whose
        remaining rows hold the data.  Returns "No data found." for empty
        input.  All cell values are HTML-escaped — the data comes straight
        from the database and is injected into the page, so unescaped values
        were an XSS vector.
        """
        if not content:
            return "No data found."

        header_cells = "".join(f"<th>{html.escape(str(col))}</th>" for col in content[0])
        body_rows = "".join(
            "<tr>" + "".join(f"<td>{html.escape(str(item))}</td>" for item in row) + "</tr>"
            for row in content[1:]
        )
        return (
            "<table class='uk-table uk-table-small uk-table-hover uk-table-divider'>"
            f"<thead><tr>{header_cells}</tr></thead>"
            f"<tbody>{body_rows}</tbody></table>"
        )
    
    
    def query_database_for_table_content(table_name, number=20):
        """Return up to ``number`` rows of ``table_name`` plus its columns.

        The result is a list whose first element is the list of column names
        and whose remaining elements are the data rows; ``[]`` when the table
        is empty.

        Raises ValueError for table names that are not plain identifiers:
        ``table_name`` arrives from client JSON and is interpolated into SQL
        (identifiers cannot be bound as parameters), so it must be validated
        to block SQL injection.
        """
        if not table_name or not all(c.isalnum() or c == '_' for c in table_name):
            raise ValueError(f"Invalid table name: {table_name!r}")

        # Get the schema that owns the table (PostgreSQL needs the prefix).
        schema = getTableSchema(table_name)

        # LIMIT can be bound as a parameter; the identifiers cannot (guarded
        # above).  The old code passed bind params the query never used.
        sql_content = text(f"""SELECT * FROM {schema}.{table_name} LIMIT :number;""")
        result = db.session.execute(sql_content, {'number': number}).fetchall()

        if not result:
            return []

        # ORDER BY ordinal_position + schema filter so the header row lines
        # up with the SELECT * column order of *this* table.
        sql_columns = text("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = :table_name AND table_schema = :schema
                ORDER BY ordinal_position;
            """)
        column_names = db.session.execute(
            sql_columns, {'table_name': table_name, 'schema': schema}).fetchall()

        # First element: column names; remaining elements: the data rows.
        content_list = [[row[0] for row in column_names]]
        for row in result:
            content_list.append(list(row))

        return content_list
    
    # Only postgresql needs this function (database_type = 'postgresql')
    def getTableSchema(table_name):
        """Return the schema that contains ``table_name``, or None if unknown.

        NOTE(review): if the same table name exists in several schemas, an
        arbitrary one is returned (the query has no ORDER BY) — confirm this
        is acceptable for the callers.
        """
        sql = text("""
            SELECT table_schema 
            FROM information_schema.tables 
            WHERE table_name = :table_name;
        """)
        row = db.session.execute(sql, {'table_name': table_name}).fetchone()
        # fetchone() yields None for unknown tables; the old fetchone()[0]
        # crashed with a TypeError in that case.
        return row[0] if row else None
        
    
    def getSchema(insp):
        """Return the list of schema names visible to the given inspector."""
        return insp.get_schema_names()
    
    
    def importMetadata(engine, schema=None, tables_selected=None, show_all=False):
        """Build the {table_name: CustomTable} mapping used for the ERD.

        Args:
            engine: SQLAlchemy engine (or URI string) to introspect; ``None``
                yields an empty mapping.
            schema: restrict discovery to one schema (PostgreSQL only).
            tables_selected: iterable of table names to keep; falsy keeps all.
            show_all: also pull in tables referenced by foreign keys.

        Returns:
            dict mapping table name -> CustomTable with columns/constraints.
        """
        # ``is None`` (identity), not ``== None``, per PEP 8.
        if engine is None:
            return {}

        # Materialize to a list so membership tests downstream are well-defined.
        tables_selected_list = list(tables_selected) if tables_selected else None

        # Fetch initial tables based on schema and table names.
        tables = fetch_initial_tables(engine, schema, tables_selected_list)

        # Optionally expand the set with FK-referenced tables.
        if show_all:
            expand_to_include_related_tables(engine, tables)

        # Attach columns, then constraints (PK, FK, Unique), to every table.
        fetch_columns_for_tables(engine, tables)
        fetch_constraints_for_tables(engine, tables)

        return tables
    
    
    def fetch_initial_tables(engine, schema=None, tables_selected_list=None):
        """Discover tables and wrap them in CustomTable objects.

        Accepts either an engine or a URI string.  For PostgreSQL a schema may
        restrict discovery; with no schema, every schema is scanned.
        NOTE(review): tables sharing a name across schemas collide in the
        returned dict (keyed by bare table name) — confirm this is intended.
        """
        if isinstance(engine, str):
            engine = create_engine(engine)
        tables = {}
        insp = inspect(engine)
        database_type = engine.dialect.name

        # Collect candidate table names (per-schema for PostgreSQL).
        # ``is (not) None`` replaces the non-idiomatic ==/!= None checks.
        all_tables = []
        if schema is not None and database_type == 'postgresql':
            all_tables = insp.get_table_names(schema=schema)
        elif schema is None and database_type == 'postgresql':
            for schema_of_schemas in insp.get_schema_names():
                for table_name in insp.get_table_names(schema=schema_of_schemas):
                    all_tables.append(table_name)
        else:       # For SQLite
            all_tables = insp.get_table_names()

        # Keep only the requested tables when a selection is provided.
        if tables_selected_list:
            table_names = [table for table in all_tables if table in tables_selected_list]
        else:
            table_names = all_tables

        for table_name in table_names:
            # PostgreSQL: resolve the owning schema; otherwise use the default.
            table_schema = getTableSchema(table_name) if database_type == 'postgresql' else insp.default_schema_name
            table = CustomTable(table_name, table_schema)
            tables[table_name] = table
            # DOT node label: n1, n2, ... in insertion order.
            table.label = f"n{len(tables)}"

        return tables
    
    
    def expand_to_include_related_tables(engine, tables):
        """Add every table referenced via foreign keys to ``tables`` (in place)."""
        if isinstance(engine, str):
            engine = create_engine(engine)
        insp = inspect(engine)

        # Collect FK-referenced tables that are not yet part of the mapping.
        pending = {}
        for name, table in tables.items():
            for fk in insp.get_foreign_keys(name, schema=table.schema):
                target = fk['referred_table']
                if target not in tables and target not in pending:
                    pending[target] = fk['referred_schema']

        # Wrap each newly discovered table and add it to the mapping.
        for target_name, target_schema in pending.items():
            tables[target_name] = CustomTable(target_name, target_schema)

        return tables
    
    
    def fetch_columns_for_tables(engine, tables):
        """Populate each CustomTable's column list via SQLAlchemy inspection."""
        if isinstance(engine, str):
            engine = create_engine(engine)
        insp = inspect(engine)

        for name, table in tables.items():
            # One CustomColumn per reflected column, carrying type/null/default.
            for col_info in insp.get_columns(name, schema=table.schema):
                column = CustomColumn(table, col_info['name'], '')
                column.setDataType({
                    "type": str(col_info['type']),
                    "nullable": col_info['nullable'],
                    "default": col_info['default']
                })
                table.columns.append(column)

        return tables
    
    
    def fetch_constraints_for_tables(engine, tables):
        """Mark unique, primary-key and foreign-key columns on every table."""
        if isinstance(engine, str):
            engine = create_engine(engine)
        insp = inspect(engine)

        # Unique constraints: flag columns, grouped by constraint name.
        for name, table in tables.items():
            for constraint in insp.get_unique_constraints(name, schema=table.schema):
                for col_name in constraint['column_names']:
                    column = table.getColumn(col_name)
                    if column:
                        column.isunique = True
                        table.uniques.setdefault(constraint['name'], []).append(column)

        # Primary keys: flag columns and record the constraint name.
        for name, table in tables.items():
            pk = insp.get_pk_constraint(name, schema=table.schema)
            for col_name in pk['constrained_columns']:
                column = table.getColumn(col_name)
                if column:
                    column.ispk = True
                    column.pkconstraint = pk['name']

        # Foreign keys: record the referenced "table.column" per FK column.
        for name, table in tables.items():
            for fk in insp.get_foreign_keys(name, schema=table.schema):
                pairs = zip(fk['constrained_columns'], fk['referred_columns'])
                for local_col, remote_col in pairs:
                    column = table.getColumn(local_col)
                    if column:
                        column.fkof = f"{fk['referred_table']}.{remote_col}"
                        table.fks.setdefault(fk['name'], []).append(column)

        return tables
    
    
    # def fetch_constraints_for_tables(engine, tables):
    #     # Fetching Unique Constraints
    #     for tableName, table in tables.items():
    #         sql_unique = text("""
    #             SELECT kcu.column_name, tc.constraint_name
    #             FROM information_schema.table_constraints AS tc
    #             JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
    #             WHERE tc.table_name = :table_name AND tc.constraint_type = 'UNIQUE';
    #         """)
    #         unique_result = db.session.execute(sql_unique, {'table_name': tableName})
    #         for col in unique_result:
    #             name, constraintName = col
    #             column = table.getColumn(name)
    #             if column:
    #                 column.isunique = True
    #                 if constraintName not in table.uniques:
    #                     table.uniques[constraintName] = []
    #                 table.uniques[constraintName].append(column)
    
    
    #     # Primary Keys
    #     for tableName, table in tables.items():
    #         sql_pk = text("""
    #             SELECT kcu.column_name, tc.constraint_name, kcu.ordinal_position
    #             FROM information_schema.table_constraints AS tc
    #             JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
    #             WHERE tc.table_name = :table_name AND tc.constraint_type = 'PRIMARY KEY';
    #         """)
    #         pk_result = db.session.execute(sql_pk, {'table_name': tableName})
    #         for col in pk_result:
    #             name, constraintName, ordinal_position = col
    #             column = table.getColumn(name)
    #             if column:
    #                 column.ispk = True
    #                 column.pkconstraint = constraintName
    #                 # Assuming you want to order PKs, though not directly used in provided class
    
    
    #     # Fetching Foreign Keys for each table
    #     for tableName, table in tables.items():
    #         sql_fk = text("""
    #             SELECT 
    #                 tc.constraint_name, 
    #                 tc.table_name AS fk_table_name, 
    #                 kcu.column_name AS fk_column_name, 
    #                 ccu.table_name AS pk_table_name, 
    #                 ccu.column_name AS pk_column_name,
    #                 ccu.table_schema AS pk_table_schema
    #             FROM 
    #                 information_schema.table_constraints AS tc 
    #                 JOIN information_schema.key_column_usage AS kcu 
    #                 ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema AND tc.table_name = kcu.table_name
    #                 JOIN information_schema.constraint_column_usage AS ccu 
    #                 ON ccu.constraint_name = tc.constraint_name
    #             WHERE 
    #                 tc.constraint_type = 'FOREIGN KEY' 
    #                 AND tc.table_name = :table_name
    #         """)
    
    #         fk_result = db.session.execute(sql_fk, {'table_name': tableName})
    #         for row in fk_result:
    #             constraintName, fkTableName, fkColumnName, pkTableName, pkColumnName, pkTableSchema = row
    
    #             # Ensure the foreign key table is the current table being processed
    #             if fkTableName != tableName:
    #                 continue
    
    
    #             fkTable = tables.get(fkTableName)
    #             pkTable = tables.get(pkTableName)
    
    #             if fkTable and pkTable:
    #                 fkColumn = fkTable.getColumn(fkColumnName)
    #                 pkColumn = pkTable.getColumn(pkColumnName)
    
    #                 if fkColumn and pkColumn:
    #                     # Here, instead of assigning pkColumn directly, store relevant info
    #                     fkColumn.fkof = pkColumn  # Adjust based on your application's needs
    #                     if constraintName not in fkTable.fks:
    #                         fkTable.fks[constraintName] = []
    #                     fkTable.fks[constraintName].append(fkColumn)
    #     return tables
    
    
    def createGraph(tables, theme, showColumns, showTypes, useUpperCase):
        """Assemble the DOT source for the ERD from the table mapping."""
        header = (
            'digraph {\n'
            '  graph [ rankdir="LR" bgcolor="#ffffff" ]\n'
            f'  node [ style="filled" shape="{theme.shape}" gradientangle="180" ]\n'
            '  edge [ arrowhead="none" arrowtail="none" dir="both" ]\n\n'
        )
        # First every node shape, then a separator, then every edge.
        shapes = [tables[name].getDotShape(theme, showColumns, showTypes, useUpperCase)
                  for name in tables]
        links = [tables[name].getDotLinks(theme) for name in tables]
        return header + "".join(shapes) + "\n" + "".join(links) + "}\n"
    
    
    def generate_erd(graph_DOT):
        """Render DOT source to a PNG via pydot; return it base64-encoded."""
        graph = pydot.graph_from_dot_data(graph_DOT)[0]
        png_bytes = graph.create_png()
        return base64.b64encode(png_bytes).decode('utf-8')
    
    
    def getThemes():
        """Return the named Theme presets available for ERD rendering."""
        themes = {}
        themes["Common Gray"] = Theme("#6c6c6c", "#e0e0e0", "#f5f5f5",
            "#e0e0e0", "#000000", "#000000", "rounded", "Mrecord", "#696969", "1")
        themes["Blue Navy"] = Theme("#1a5282", "#1a5282", "#ffffff",
            "#1a5282", "#000000", "#ffffff", "rounded", "Mrecord", "#0078d7", "2")
        themes["Common Gray Box"] = Theme("#6c6c6c", "#e0e0e0", "#f5f5f5",
            "#e0e0e0", "#000000", "#000000", "rounded", "record", "#696969", "1")
        return themes
    
    
    if __name__ == "__main__":
        # Run the Flask development server (debug mode — not for production).
        app.run(debug=True)