diff --git a/eralchemy/main.py b/eralchemy/main.py index d962617a79a1a29ef65fd1cfe195e365df5da4c3..af71246f9dd48d4ed8149cb901084f4707d907da 100644 --- a/eralchemy/main.py +++ b/eralchemy/main.py @@ -332,10 +332,10 @@ def filter_resources( _relationships = [ r for r in _relationships - if not exclude_tables_re.fullmatch(r.right_col) - and not exclude_tables_re.fullmatch(r.left_col) - and include_tables_re.fullmatch(r.right_col) - and include_tables_re.fullmatch(r.left_col) + if not exclude_tables_re.fullmatch(r.right_table) + and not exclude_tables_re.fullmatch(r.left_table) + and include_tables_re.fullmatch(r.right_table) + and include_tables_re.fullmatch(r.left_table) ] def check_column(name): diff --git a/eralchemy/models.py b/eralchemy/models.py index d6682638b4a3cee67d20eadaa53aacabd08be94d..b54b36e9c64950d92454cb57bdf38d797cf9f2f3 100644 --- a/eralchemy/models.py +++ b/eralchemy/models.py @@ -113,10 +113,11 @@ class Column(Drawable): def to_dot(self) -> str: base = ROW_TAGS.format( - ' ALIGN="LEFT"', + ' ALIGN="LEFT" {port}', "{key_opening}{col_name}{key_closing} {type}{null}", ) return base.format( + port=f'PORT="{self.name}"' if self.name else "", key_opening="<u>" if self.is_key else "", key_closing="</u>" if self.is_key else "", col_name=FONT_TAGS.format(self.name), @@ -129,7 +130,19 @@ class Relation(Drawable): """Represents a Relation in the intermediaty syntax.""" RE = re.compile( - r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)", + r""" + (?P<left_table>[^\s]+?) + (?:\.\"(?P<left_column>.+)\")? + \s* + (?P<left_cardinality>[*?+1]) + -- + (?P<right_cardinality>[*?+1]) + \s* + (?P<right_table>[^\s]+?) + (?:\.\"(?P<right_column>.+)\")? + \s*$ + """, + re.VERBOSE, ) cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None} cardinalities_mermaid = { @@ -145,41 +158,47 @@ class Relation(Drawable): @staticmethod def make_from_match(match: re.Match) -> Relation: - return Relation( - right_col=match.group("right_name"), - left_col=match.group("left_name"), - right_cardinality=match.group("right_cardinality"), - left_cardinality=match.group("left_cardinality"), - ) + return Relation(**match.groupdict()) def __init__( self, - right_col, - left_col, + right_table, + left_table, right_cardinality=None, left_cardinality=None, + right_column=None, + left_column=None, ): if ( right_cardinality not in self.cardinalities.keys() or left_cardinality not in self.cardinalities.keys() ): raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}") - self.right_col = right_col - self.left_col = left_col + self.right_table = right_table + self.right_column = right_column or "" + self.left_table = left_table + self.left_column = left_column or "" self.right_cardinality = right_cardinality self.left_cardinality = left_cardinality def to_markdown(self) -> str: - return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}" + return "{}{} {}--{} {}{}".format( + self.left_table, + "" if not self.left_column else f'."{self.left_column}"', + self.left_cardinality, + self.right_cardinality, + self.right_table, + "" if not self.right_column else f'."{self.right_column}"', + ) def to_mermaid(self) -> str: normalized = ( Relation.cardinalities_mermaid.get(k, k) for k in ( - sanitize_mermaid(self.left_col), + sanitize_mermaid(self.left_table), self.left_cardinality, self.right_cardinality, - sanitize_mermaid(self.right_col), + sanitize_mermaid(self.right_table), ) ) return '{} "{}" -- "{}" {}'.format(*normalized) @@ -187,15 +206,13 @@ class Relation(Drawable): def to_mermaid_er(self) -> str: left = Relation.cardinalities_crowfoot.get( self.left_cardinality, - self.left_cardinality, ) right = Relation.cardinalities_crowfoot.get( self.right_cardinality, - self.right_cardinality, ) - left_col = sanitize_mermaid(self.left_col, is_er=True) - right_col = sanitize_mermaid(self.right_col, is_er=True) + left_col = sanitize_mermaid(self.left_table, is_er=True) + right_col = sanitize_mermaid(self.right_table, is_er=True) return f"{left_col} {left}--{right} {right_col} : has" def graphviz_cardinalities(self, card) -> str: @@ -211,10 +228,10 @@ class Relation(Drawable): cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality)) if self.right_cardinality != "": cards.append("head" + self.graphviz_cardinalities(self.right_cardinality)) - return '"{}" -- "{}" [{}];'.format( - self.left_col, - self.right_col, - ",".join(cards), + left_col = f':"{self.left_column}"' if self.left_column else "" + right_col = f':"{self.right_column}"' if self.right_column else "" + return ( + f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];' ) def __eq__(self, other: object) -> bool: @@ -223,8 +240,10 @@ class Relation(Drawable): if not isinstance(other, Relation): return False other_inversed = Relation( - right_col=other.left_col, - left_col=other.right_col, + right_table=other.left_table, + right_column=other.left_column, + left_table=other.right_table, + left_column=other.right_column, right_cardinality=other.left_cardinality, left_cardinality=other.right_cardinality, ) diff --git a/eralchemy/parser.py b/eralchemy/parser.py index 5e67667f6ce60acfbfd347ee8aa82e261ca65dcc..43831a8469be6fca664d302877eb522cfe065a23 100644 --- a/eralchemy/parser.py +++ b/eralchemy/parser.py @@ -124,8 +124,8 @@ def update_models( if isinstance(new_obj, Relation): tables_names = [t.name for t in tables] - _check_colname_in_lst(new_obj.right_col, tables_names) - _check_colname_in_lst(new_obj.left_col, tables_names) + _check_colname_in_lst(new_obj.right_table, tables_names) + _check_colname_in_lst(new_obj.left_table, tables_names) return current_table, tables, relations + [new_obj] if isinstance(new_obj, Column): diff --git a/eralchemy/sqla.py b/eralchemy/sqla.py index 6f5209fa2017a723bfa00e39fda6312e0f4e7328..afda6d338b714a2f8e6da7293adef6b873b21b27 100644 --- a/eralchemy/sqla.py +++ b/eralchemy/sqla.py @@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation: # if this is the case, we are not optional and must be unique right_cardinality = "1" if check_all_compound_same_parent(fk) else "*" return Relation( - right_col=format_name(fk.parent.table.fullname), - left_col=format_name(fk.column.table.fullname), + right_table=format_name(fk.parent.table.fullname), + right_column=format_name(fk.parent.name), + left_table=format_name(fk.column.table.fullname), + left_column=format_name(fk.column.name), right_cardinality=right_cardinality, left_cardinality="?" if fk.parent.nullable else "1", ) diff --git a/tests/common.py b/tests/common.py index 40b9dcbd0bdbba8b16314040f3b4c76aea1bced2..be3cebd1c37d030093f39af6c1d7d37fb955ddc8 100644 --- a/tests/common.py +++ b/tests/common.py @@ -68,10 +68,12 @@ child_parent_id = ERColumn( ) relation = Relation( - right_col="parent", - left_col="child", - right_cardinality="?", + left_table="child", + left_column="parent_id", + right_table="parent", + right_column="id", left_cardinality="*", + right_cardinality="?", ) exclude_id = ERColumn(name="id", type="INTEGER", is_key=True) @@ -82,8 +84,10 @@ exclude_parent_id = ERColumn( ) exclude_relation = Relation( - right_col="parent", - left_col="exclude", + right_table="parent", + right_column="id", + left_table="exclude", + left_column="parent_id", right_cardinality="?", left_cardinality="*", ) @@ -117,8 +121,8 @@ markdown = """ [exclude] *id {label:"INTEGER"} parent_id {label:"INTEGER"} - parent ?--* child - parent ?--* exclude + child."parent_id" *--? parent."id" + exclude."parent_id" *--? parent."id" """ diff --git a/tests/test_intermediary_to_dot.py b/tests/test_intermediary_to_dot.py index ea4261af0704b97f22516fe56b38b05c4bf9f852..7695edff205081c3026da4a889e1605a5d0da043 100644 --- a/tests/test_intermediary_to_dot.py +++ b/tests/test_intermediary_to_dot.py @@ -5,7 +5,6 @@ from multiprocessing import Process import pytest from pygraphviz import AGraph -from eralchemy.cst import DOT_GRAPH_BEGINNING from eralchemy.main import _intermediary_to_dot from tests.common import ( child, @@ -17,17 +16,13 @@ from tests.common import ( relation, ) -GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }" -column_re = re.compile( - '\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>', -) +column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>") header_re = re.compile( - '\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)' - "\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>", + r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>" ) column_inside = re.compile( - "(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>" - "(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>", + r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>" + r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>" ) @@ -79,7 +74,7 @@ def assert_column_well_rendered_to_dot(col): col_parsed = column_inside.match(col_no_table[0]) assert col_parsed.group("key_opening") == ("<u>" if col.is_key else "") assert col_parsed.group("name") == col.name - assert col_parsed.group("key_closing") == ("</u>" if col.is_key else "") + assert col_parsed.group("key_closing") == ("</u> " if col.is_key else " ") assert col_parsed.group("type") == col.type @@ -92,14 +87,16 @@ def test_column_is_dot_format(): def test_relation(): relation_re = re.compile( - '\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ ' - "\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>" - "\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;", + r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ " + r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>" + r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;" ) dot = relation.to_dot() r = relation_re.match(dot) - assert r.group("l_name") == "child" - assert r.group("r_name") == "parent" + assert r.group("l_table") == "child" + assert r.group("l_column") == "parent_id" + assert r.group("r_table") == "parent" + assert r.group("r_column") == "id" assert r.group("l_card") == "0..N" assert r.group("r_card") == "{0,1}" diff --git a/tests/test_intermediary_to_er.py b/tests/test_intermediary_to_er.py index 6d574c5ddbb48890e329d72c5ace77e6d176a4b9..193f28fe9886d36963f195ed5c40ba49a07a80d1 100644 --- a/tests/test_intermediary_to_er.py +++ b/tests/test_intermediary_to_er.py @@ -38,7 +38,10 @@ def test_column_to_er(): def test_relation(): - assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"] + assert relation.to_markdown() in [ + 'parent."id" *--? child."parent_id"', + 'child."parent_id" *--? parent."id"', + ] def assert_table_well_rendered_to_er(table): diff --git a/tests/test_parser.py b/tests/test_parser.py index e4ac6e3e749b0326fb2d0f3be901d373284fa6a2..efede7967c0bcdb31b9aa6940d9da201c9af77a5 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -73,9 +73,11 @@ def test_parse_line(): assert isinstance(rv, Column) for s in relations_lst: - rv = parse_line(s) - assert rv.right_col == s[16:].strip() - assert rv.left_col == s[:12].strip() + rv = parse_line(s) # type: Relation + assert rv.right_table == s[16:].strip() + assert rv.right_column == "" + assert rv.left_table == s[:12].strip() + assert rv.left_column == "" assert rv.right_cardinality == s[15] assert rv.left_cardinality == s[12] assert isinstance(rv, Relation) diff --git a/tests/test_sqla_to_intermediary.py b/tests/test_sqla_to_intermediary.py index 8566e838e887247119fd852cad188bcfde9d28f7..3d9c2950a07938789185c2961fa7eb7a969c11ac 100644 --- a/tests/test_sqla_to_intermediary.py +++ b/tests/test_sqla_to_intermediary.py @@ -121,13 +121,13 @@ def test_table_names_in_relationships(pg_db_uri): table_names = [t.name for t in tables] # Assert column names are table names - assert all(r.right_col in table_names for r in relationships) - assert all(r.left_col in table_names for r in relationships) + assert all(r.right_table in table_names for r in relationships) + assert all(r.left_table in table_names for r in relationships) # Assert column names match table names for r in relationships: - r_name = table_names[table_names.index(r.right_col)] - l_name = table_names[table_names.index(r.left_col)] + r_name = table_names[table_names.index(r.right_table)] + l_name = table_names[table_names.index(r.left_table)] # Table name in relationship should *NOT* have a schema assert r_name.find(".") == -1 @@ -142,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri): table_names = [t.name for t in tables] # Assert column names match table names, including schema - assert all(r.right_col in table_names for r in relationships) - assert all(r.left_col in table_names for r in relationships) + assert all(r.right_table in table_names for r in relationships) + assert all(r.left_table in table_names for r in relationships) # Assert column names match table names, including schema for r in relationships: