From 8fa193e1fcd6c46528391d846b282c20ec3d46c2 Mon Sep 17 00:00:00 2001 From: Florian Maurer <f.maurer@outlook.de> Date: Tue, 17 Sep 2024 01:37:02 +0200 Subject: [PATCH] relations between tables Draw relations between specific columns rather than just between the tables (because when there are multiple relations between two tables, that ends up looking very messy). --- eralchemy/main.py | 8 ++-- eralchemy/models.py | 69 +++++++++++++++++++----------- eralchemy/parser.py | 4 +- eralchemy/sqla.py | 6 ++- tests/common.py | 18 +++++--- tests/test_intermediary_to_dot.py | 27 ++++++------ tests/test_intermediary_to_er.py | 5 ++- tests/test_parser.py | 8 ++-- tests/test_sqla_to_intermediary.py | 14 +++--- 9 files changed, 94 insertions(+), 65 deletions(-) diff --git a/eralchemy/main.py b/eralchemy/main.py index b6b874a..ce8bc55 100644 --- a/eralchemy/main.py +++ b/eralchemy/main.py @@ -309,10 +309,10 @@ def filter_resources( _relationships = [ r for r in _relationships - if not exclude_tables_re.fullmatch(r.right_col) - and not exclude_tables_re.fullmatch(r.left_col) - and include_tables_re.fullmatch(r.right_col) - and include_tables_re.fullmatch(r.left_col) + if not exclude_tables_re.fullmatch(r.right_table) + and not exclude_tables_re.fullmatch(r.left_table) + and include_tables_re.fullmatch(r.right_table) + and include_tables_re.fullmatch(r.left_table) ] def check_column(name): diff --git a/eralchemy/models.py b/eralchemy/models.py index d668263..b54b36e 100644 --- a/eralchemy/models.py +++ b/eralchemy/models.py @@ -113,10 +113,11 @@ class Column(Drawable): def to_dot(self) -> str: base = ROW_TAGS.format( - ' ALIGN="LEFT"', + ' ALIGN="LEFT" {port}', "{key_opening}{col_name}{key_closing} {type}{null}", ) return base.format( + port=f'PORT="{self.name}"' if self.name else "", key_opening="<u>" if self.is_key else "", key_closing="</u>" if self.is_key else "", col_name=FONT_TAGS.format(self.name), @@ -129,7 +130,19 @@ class Relation(Drawable): """Represents a Relation in the intermediaty syntax.""" RE = re.compile( - r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)", + r""" + (?P<left_table>[^\s]+?) + (?:\.\"(?P<left_column>.+)\")? + \s* + (?P<left_cardinality>[*?+1]) + -- + (?P<right_cardinality>[*?+1]) + \s* + (?P<right_table>[^\s]+?) + (?:\.\"(?P<right_column>.+)\")? + \s*$ + """, + re.VERBOSE, ) cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None} cardinalities_mermaid = { @@ -145,41 +158,47 @@ class Relation(Drawable): @staticmethod def make_from_match(match: re.Match) -> Relation: - return Relation( - right_col=match.group("right_name"), - left_col=match.group("left_name"), - right_cardinality=match.group("right_cardinality"), - left_cardinality=match.group("left_cardinality"), - ) + return Relation(**match.groupdict()) def __init__( self, - right_col, - left_col, + right_table, + left_table, right_cardinality=None, left_cardinality=None, + right_column=None, + left_column=None, ): if ( right_cardinality not in self.cardinalities.keys() or left_cardinality not in self.cardinalities.keys() ): raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}") - self.right_col = right_col - self.left_col = left_col + self.right_table = right_table + self.right_column = right_column or "" + self.left_table = left_table + self.left_column = left_column or "" self.right_cardinality = right_cardinality self.left_cardinality = left_cardinality def to_markdown(self) -> str: - return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}" + return "{}{} {}--{} {}{}".format( + self.left_table, + "" if not self.left_column else f'."{self.left_column}"', + self.left_cardinality, + self.right_cardinality, + self.right_table, + "" if not self.right_column else f'."{self.right_column}"', + ) def to_mermaid(self) -> str: normalized = ( Relation.cardinalities_mermaid.get(k, k) for k in ( - sanitize_mermaid(self.left_col), + sanitize_mermaid(self.left_table), self.left_cardinality, self.right_cardinality, - sanitize_mermaid(self.right_col), + sanitize_mermaid(self.right_table), ) ) return '{} "{}" -- "{}" {}'.format(*normalized) @@ -187,15 +206,13 @@ class Relation(Drawable): def to_mermaid_er(self) -> str: left = Relation.cardinalities_crowfoot.get( self.left_cardinality, - self.left_cardinality, ) right = Relation.cardinalities_crowfoot.get( self.right_cardinality, - self.right_cardinality, ) - left_col = sanitize_mermaid(self.left_col, is_er=True) - right_col = sanitize_mermaid(self.right_col, is_er=True) + left_col = sanitize_mermaid(self.left_table, is_er=True) + right_col = sanitize_mermaid(self.right_table, is_er=True) return f"{left_col} {left}--{right} {right_col} : has" def graphviz_cardinalities(self, card) -> str: @@ -211,10 +228,10 @@ class Relation(Drawable): cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality)) if self.right_cardinality != "": cards.append("head" + self.graphviz_cardinalities(self.right_cardinality)) - return '"{}" -- "{}" [{}];'.format( - self.left_col, - self.right_col, - ",".join(cards), + left_col = f':"{self.left_column}"' if self.left_column else "" + right_col = f':"{self.right_column}"' if self.right_column else "" + return ( + f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];' ) def __eq__(self, other: object) -> bool: @@ -223,8 +240,10 @@ class Relation(Drawable): if not isinstance(other, Relation): return False other_inversed = Relation( - right_col=other.left_col, - left_col=other.right_col, + right_table=other.left_table, + right_column=other.left_column, + left_table=other.right_table, + left_column=other.right_column, right_cardinality=other.left_cardinality, left_cardinality=other.right_cardinality, ) diff --git a/eralchemy/parser.py b/eralchemy/parser.py index 5e67667..43831a8 100644 --- a/eralchemy/parser.py +++ b/eralchemy/parser.py @@ -124,8 +124,8 @@ def update_models( if isinstance(new_obj, Relation): tables_names = [t.name for t in tables] - _check_colname_in_lst(new_obj.right_col, tables_names) - _check_colname_in_lst(new_obj.left_col, tables_names) + _check_colname_in_lst(new_obj.right_table, tables_names) + _check_colname_in_lst(new_obj.left_table, tables_names) return current_table, tables, relations + [new_obj] if isinstance(new_obj, Column): diff --git a/eralchemy/sqla.py b/eralchemy/sqla.py index 6f5209f..afda6d3 100644 --- a/eralchemy/sqla.py +++ b/eralchemy/sqla.py @@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation: # if this is the case, we are not optional and must be unique right_cardinality = "1" if check_all_compound_same_parent(fk) else "*" return Relation( - right_col=format_name(fk.parent.table.fullname), - left_col=format_name(fk.column.table.fullname), + right_table=format_name(fk.parent.table.fullname), + right_column=format_name(fk.parent.name), + left_table=format_name(fk.column.table.fullname), + left_column=format_name(fk.column.name), right_cardinality=right_cardinality, left_cardinality="?" if fk.parent.nullable else "1", ) diff --git a/tests/common.py b/tests/common.py index 40b9dcb..be3cebd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -68,10 +68,12 @@ child_parent_id = ERColumn( ) relation = Relation( - right_col="parent", - left_col="child", - right_cardinality="?", + left_table="child", + left_column="parent_id", + right_table="parent", + right_column="id", left_cardinality="*", + right_cardinality="?", ) exclude_id = ERColumn(name="id", type="INTEGER", is_key=True) @@ -82,8 +84,10 @@ exclude_parent_id = ERColumn( ) exclude_relation = Relation( - right_col="parent", - left_col="exclude", + right_table="parent", + right_column="id", + left_table="exclude", + left_column="parent_id", right_cardinality="?", left_cardinality="*", ) @@ -117,8 +121,8 @@ markdown = """ [exclude] *id {label:"INTEGER"} parent_id {label:"INTEGER"} - parent ?--* child - parent ?--* exclude + child."parent_id" *--? parent."id" + exclude."parent_id" *--? parent."id" """ diff --git a/tests/test_intermediary_to_dot.py b/tests/test_intermediary_to_dot.py index 19e411a..1eed503 100644 --- a/tests/test_intermediary_to_dot.py +++ b/tests/test_intermediary_to_dot.py @@ -5,7 +5,6 @@ from multiprocessing import Process import pytest from pygraphviz import AGraph -from eralchemy.cst import DOT_GRAPH_BEGINNING from eralchemy.main import _intermediary_to_dot from tests.common import ( child, @@ -17,17 +16,13 @@ from tests.common import ( relation, ) -GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }" -column_re = re.compile( - '\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>', -) +column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>") header_re = re.compile( - '\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)' - "\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>", + r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>" ) column_inside = re.compile( - "(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>" - "(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>", + r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>" + r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>" ) @@ -78,7 +73,7 @@ def assert_column_well_rendered_to_dot(col): col_parsed = column_inside.match(col_no_table[0]) assert col_parsed.group("key_opening") == ("<u>" if col.is_key else "") assert col_parsed.group("name") == col.name - assert col_parsed.group("key_closing") == ("</u>" if col.is_key else "") + assert col_parsed.group("key_closing") == ("</u> " if col.is_key else " ") assert col_parsed.group("type") == col.type @@ -91,14 +86,16 @@ def test_column_is_dot_format(): def test_relation(): relation_re = re.compile( - '\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ ' - "\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>" - "\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;", + r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ " + r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>" + r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;" ) dot = relation.to_dot() r = relation_re.match(dot) - assert r.group("l_name") == "child" - assert r.group("r_name") == "parent" + assert r.group("l_table") == "child" + assert r.group("l_column") == "parent_id" + assert r.group("r_table") == "parent" + assert r.group("r_column") == "id" assert r.group("l_card") == "0..N" assert r.group("r_card") == "{0,1}" diff --git a/tests/test_intermediary_to_er.py b/tests/test_intermediary_to_er.py index 6d574c5..193f28f 100644 --- a/tests/test_intermediary_to_er.py +++ b/tests/test_intermediary_to_er.py @@ -38,7 +38,10 @@ def test_column_to_er(): def test_relation(): - assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"] + assert relation.to_markdown() in [ + 'parent."id" *--? child."parent_id"', + 'child."parent_id" *--? parent."id"', + ] def assert_table_well_rendered_to_er(table): diff --git a/tests/test_parser.py b/tests/test_parser.py index e4ac6e3..efede79 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -73,9 +73,11 @@ def test_parse_line(): assert isinstance(rv, Column) for s in relations_lst: - rv = parse_line(s) - assert rv.right_col == s[16:].strip() - assert rv.left_col == s[:12].strip() + rv = parse_line(s) # type: Relation + assert rv.right_table == s[16:].strip() + assert rv.right_column == "" + assert rv.left_table == s[:12].strip() + assert rv.left_column == "" assert rv.right_cardinality == s[15] assert rv.left_cardinality == s[12] assert isinstance(rv, Relation) diff --git a/tests/test_sqla_to_intermediary.py b/tests/test_sqla_to_intermediary.py index 8d52309..3d9c295 100644 --- a/tests/test_sqla_to_intermediary.py +++ b/tests/test_sqla_to_intermediary.py @@ -115,24 +115,26 @@ def test_flask_sqlalchemy(): check_intermediary_representation_simple_all_table(tables, relationships) +@pytest.mark.external_db def test_table_names_in_relationships(pg_db_uri): tables, relationships = database_to_intermediary(pg_db_uri) table_names = [t.name for t in tables] # Assert column names are table names - assert all(r.right_col in table_names for r in relationships) - assert all(r.left_col in table_names for r in relationships) + assert all(r.right_table in table_names for r in relationships) + assert all(r.left_table in table_names for r in relationships) # Assert column names match table names for r in relationships: - r_name = table_names[table_names.index(r.right_col)] - l_name = table_names[table_names.index(r.left_col)] + r_name = table_names[table_names.index(r.right_table)] + l_name = table_names[table_names.index(r.left_table)] # Table name in relationship should *NOT* have a schema assert r_name.find(".") == -1 assert l_name.find(".") == -1 +@pytest.mark.external_db def test_table_names_in_relationships_with_schema(pg_db_uri): schema_name = "test" matcher = re.compile(rf"{schema_name}\.[\S+]", re.I) @@ -140,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri): table_names = [t.name for t in tables] # Assert column names match table names, including schema - assert all(r.right_col in table_names for r in relationships) - assert all(r.left_col in table_names for r in relationships) + assert all(r.right_table in table_names for r in relationships) + assert all(r.left_table in table_names for r in relationships) # Assert column names match table names, including schema for r in relationships: -- GitLab