Skip to content
Snippets Groups Projects
Unverified Commit 8fa193e1 authored by Florian Maurer's avatar Florian Maurer
Browse files

relations between tables

Draw relations between specific columns rather than just between the tables (because when there are multiple
relations between two tables, that ends up looking very messy).
parent ec80be3b
Branches
Tags
No related merge requests found
...@@ -309,10 +309,10 @@ def filter_resources( ...@@ -309,10 +309,10 @@ def filter_resources(
_relationships = [ _relationships = [
r r
for r in _relationships for r in _relationships
if not exclude_tables_re.fullmatch(r.right_col) if not exclude_tables_re.fullmatch(r.right_table)
and not exclude_tables_re.fullmatch(r.left_col) and not exclude_tables_re.fullmatch(r.left_table)
and include_tables_re.fullmatch(r.right_col) and include_tables_re.fullmatch(r.right_table)
and include_tables_re.fullmatch(r.left_col) and include_tables_re.fullmatch(r.left_table)
] ]
def check_column(name): def check_column(name):
......
...@@ -113,10 +113,11 @@ class Column(Drawable): ...@@ -113,10 +113,11 @@ class Column(Drawable):
def to_dot(self) -> str: def to_dot(self) -> str:
base = ROW_TAGS.format( base = ROW_TAGS.format(
' ALIGN="LEFT"', ' ALIGN="LEFT" {port}',
"{key_opening}{col_name}{key_closing} {type}{null}", "{key_opening}{col_name}{key_closing} {type}{null}",
) )
return base.format( return base.format(
port=f'PORT="{self.name}"' if self.name else "",
key_opening="<u>" if self.is_key else "", key_opening="<u>" if self.is_key else "",
key_closing="</u>" if self.is_key else "", key_closing="</u>" if self.is_key else "",
col_name=FONT_TAGS.format(self.name), col_name=FONT_TAGS.format(self.name),
...@@ -129,7 +130,19 @@ class Relation(Drawable): ...@@ -129,7 +130,19 @@ class Relation(Drawable):
"""Represents a Relation in the intermediaty syntax.""" """Represents a Relation in the intermediaty syntax."""
RE = re.compile( RE = re.compile(
r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)", r"""
(?P<left_table>[^\s]+?)
(?:\.\"(?P<left_column>.+)\")?
\s*
(?P<left_cardinality>[*?+1])
--
(?P<right_cardinality>[*?+1])
\s*
(?P<right_table>[^\s]+?)
(?:\.\"(?P<right_column>.+)\")?
\s*$
""",
re.VERBOSE,
) )
cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None} cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None}
cardinalities_mermaid = { cardinalities_mermaid = {
...@@ -145,41 +158,47 @@ class Relation(Drawable): ...@@ -145,41 +158,47 @@ class Relation(Drawable):
@staticmethod @staticmethod
def make_from_match(match: re.Match) -> Relation: def make_from_match(match: re.Match) -> Relation:
return Relation( return Relation(**match.groupdict())
right_col=match.group("right_name"),
left_col=match.group("left_name"),
right_cardinality=match.group("right_cardinality"),
left_cardinality=match.group("left_cardinality"),
)
def __init__( def __init__(
self, self,
right_col, right_table,
left_col, left_table,
right_cardinality=None, right_cardinality=None,
left_cardinality=None, left_cardinality=None,
right_column=None,
left_column=None,
): ):
if ( if (
right_cardinality not in self.cardinalities.keys() right_cardinality not in self.cardinalities.keys()
or left_cardinality not in self.cardinalities.keys() or left_cardinality not in self.cardinalities.keys()
): ):
raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}") raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}")
self.right_col = right_col self.right_table = right_table
self.left_col = left_col self.right_column = right_column or ""
self.left_table = left_table
self.left_column = left_column or ""
self.right_cardinality = right_cardinality self.right_cardinality = right_cardinality
self.left_cardinality = left_cardinality self.left_cardinality = left_cardinality
def to_markdown(self) -> str: def to_markdown(self) -> str:
return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}" return "{}{} {}--{} {}{}".format(
self.left_table,
"" if not self.left_column else f'."{self.left_column}"',
self.left_cardinality,
self.right_cardinality,
self.right_table,
"" if not self.right_column else f'."{self.right_column}"',
)
def to_mermaid(self) -> str: def to_mermaid(self) -> str:
normalized = ( normalized = (
Relation.cardinalities_mermaid.get(k, k) Relation.cardinalities_mermaid.get(k, k)
for k in ( for k in (
sanitize_mermaid(self.left_col), sanitize_mermaid(self.left_table),
self.left_cardinality, self.left_cardinality,
self.right_cardinality, self.right_cardinality,
sanitize_mermaid(self.right_col), sanitize_mermaid(self.right_table),
) )
) )
return '{} "{}" -- "{}" {}'.format(*normalized) return '{} "{}" -- "{}" {}'.format(*normalized)
...@@ -187,15 +206,13 @@ class Relation(Drawable): ...@@ -187,15 +206,13 @@ class Relation(Drawable):
def to_mermaid_er(self) -> str: def to_mermaid_er(self) -> str:
left = Relation.cardinalities_crowfoot.get( left = Relation.cardinalities_crowfoot.get(
self.left_cardinality, self.left_cardinality,
self.left_cardinality,
) )
right = Relation.cardinalities_crowfoot.get( right = Relation.cardinalities_crowfoot.get(
self.right_cardinality, self.right_cardinality,
self.right_cardinality,
) )
left_col = sanitize_mermaid(self.left_col, is_er=True) left_col = sanitize_mermaid(self.left_table, is_er=True)
right_col = sanitize_mermaid(self.right_col, is_er=True) right_col = sanitize_mermaid(self.right_table, is_er=True)
return f"{left_col} {left}--{right} {right_col} : has" return f"{left_col} {left}--{right} {right_col} : has"
def graphviz_cardinalities(self, card) -> str: def graphviz_cardinalities(self, card) -> str:
...@@ -211,10 +228,10 @@ class Relation(Drawable): ...@@ -211,10 +228,10 @@ class Relation(Drawable):
cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality)) cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality))
if self.right_cardinality != "": if self.right_cardinality != "":
cards.append("head" + self.graphviz_cardinalities(self.right_cardinality)) cards.append("head" + self.graphviz_cardinalities(self.right_cardinality))
return '"{}" -- "{}" [{}];'.format( left_col = f':"{self.left_column}"' if self.left_column else ""
self.left_col, right_col = f':"{self.right_column}"' if self.right_column else ""
self.right_col, return (
",".join(cards), f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];'
) )
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
...@@ -223,8 +240,10 @@ class Relation(Drawable): ...@@ -223,8 +240,10 @@ class Relation(Drawable):
if not isinstance(other, Relation): if not isinstance(other, Relation):
return False return False
other_inversed = Relation( other_inversed = Relation(
right_col=other.left_col, right_table=other.left_table,
left_col=other.right_col, right_column=other.left_column,
left_table=other.right_table,
left_column=other.right_column,
right_cardinality=other.left_cardinality, right_cardinality=other.left_cardinality,
left_cardinality=other.right_cardinality, left_cardinality=other.right_cardinality,
) )
......
...@@ -124,8 +124,8 @@ def update_models( ...@@ -124,8 +124,8 @@ def update_models(
if isinstance(new_obj, Relation): if isinstance(new_obj, Relation):
tables_names = [t.name for t in tables] tables_names = [t.name for t in tables]
_check_colname_in_lst(new_obj.right_col, tables_names) _check_colname_in_lst(new_obj.right_table, tables_names)
_check_colname_in_lst(new_obj.left_col, tables_names) _check_colname_in_lst(new_obj.left_table, tables_names)
return current_table, tables, relations + [new_obj] return current_table, tables, relations + [new_obj]
if isinstance(new_obj, Column): if isinstance(new_obj, Column):
......
...@@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation: ...@@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation:
# if this is the case, we are not optional and must be unique # if this is the case, we are not optional and must be unique
right_cardinality = "1" if check_all_compound_same_parent(fk) else "*" right_cardinality = "1" if check_all_compound_same_parent(fk) else "*"
return Relation( return Relation(
right_col=format_name(fk.parent.table.fullname), right_table=format_name(fk.parent.table.fullname),
left_col=format_name(fk.column.table.fullname), right_column=format_name(fk.parent.name),
left_table=format_name(fk.column.table.fullname),
left_column=format_name(fk.column.name),
right_cardinality=right_cardinality, right_cardinality=right_cardinality,
left_cardinality="?" if fk.parent.nullable else "1", left_cardinality="?" if fk.parent.nullable else "1",
) )
......
...@@ -68,10 +68,12 @@ child_parent_id = ERColumn( ...@@ -68,10 +68,12 @@ child_parent_id = ERColumn(
) )
relation = Relation( relation = Relation(
right_col="parent", left_table="child",
left_col="child", left_column="parent_id",
right_cardinality="?", right_table="parent",
right_column="id",
left_cardinality="*", left_cardinality="*",
right_cardinality="?",
) )
exclude_id = ERColumn(name="id", type="INTEGER", is_key=True) exclude_id = ERColumn(name="id", type="INTEGER", is_key=True)
...@@ -82,8 +84,10 @@ exclude_parent_id = ERColumn( ...@@ -82,8 +84,10 @@ exclude_parent_id = ERColumn(
) )
exclude_relation = Relation( exclude_relation = Relation(
right_col="parent", right_table="parent",
left_col="exclude", right_column="id",
left_table="exclude",
left_column="parent_id",
right_cardinality="?", right_cardinality="?",
left_cardinality="*", left_cardinality="*",
) )
...@@ -117,8 +121,8 @@ markdown = """ ...@@ -117,8 +121,8 @@ markdown = """
[exclude] [exclude]
*id {label:"INTEGER"} *id {label:"INTEGER"}
parent_id {label:"INTEGER"} parent_id {label:"INTEGER"}
parent ?--* child child."parent_id" *--? parent."id"
parent ?--* exclude exclude."parent_id" *--? parent."id"
""" """
......
...@@ -5,7 +5,6 @@ from multiprocessing import Process ...@@ -5,7 +5,6 @@ from multiprocessing import Process
import pytest import pytest
from pygraphviz import AGraph from pygraphviz import AGraph
from eralchemy.cst import DOT_GRAPH_BEGINNING
from eralchemy.main import _intermediary_to_dot from eralchemy.main import _intermediary_to_dot
from tests.common import ( from tests.common import (
child, child,
...@@ -17,17 +16,13 @@ from tests.common import ( ...@@ -17,17 +16,13 @@ from tests.common import (
relation, relation,
) )
GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }" column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>")
column_re = re.compile(
'\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>',
)
header_re = re.compile( header_re = re.compile(
'\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)' r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>"
"\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>",
) )
column_inside = re.compile( column_inside = re.compile(
"(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>" r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>"
"(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>", r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>"
) )
...@@ -91,14 +86,16 @@ def test_column_is_dot_format(): ...@@ -91,14 +86,16 @@ def test_column_is_dot_format():
def test_relation(): def test_relation():
relation_re = re.compile( relation_re = re.compile(
'\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ ' r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ "
"\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>" r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>"
"\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;", r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;"
) )
dot = relation.to_dot() dot = relation.to_dot()
r = relation_re.match(dot) r = relation_re.match(dot)
assert r.group("l_name") == "child" assert r.group("l_table") == "child"
assert r.group("r_name") == "parent" assert r.group("l_column") == "parent_id"
assert r.group("r_table") == "parent"
assert r.group("r_column") == "id"
assert r.group("l_card") == "0..N" assert r.group("l_card") == "0..N"
assert r.group("r_card") == "{0,1}" assert r.group("r_card") == "{0,1}"
......
...@@ -38,7 +38,10 @@ def test_column_to_er(): ...@@ -38,7 +38,10 @@ def test_column_to_er():
def test_relation(): def test_relation():
assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"] assert relation.to_markdown() in [
'parent."id" *--? child."parent_id"',
'child."parent_id" *--? parent."id"',
]
def assert_table_well_rendered_to_er(table): def assert_table_well_rendered_to_er(table):
......
...@@ -73,9 +73,11 @@ def test_parse_line(): ...@@ -73,9 +73,11 @@ def test_parse_line():
assert isinstance(rv, Column) assert isinstance(rv, Column)
for s in relations_lst: for s in relations_lst:
rv = parse_line(s) rv = parse_line(s) # type: Relation
assert rv.right_col == s[16:].strip() assert rv.right_table == s[16:].strip()
assert rv.left_col == s[:12].strip() assert rv.right_column == ""
assert rv.left_table == s[:12].strip()
assert rv.left_column == ""
assert rv.right_cardinality == s[15] assert rv.right_cardinality == s[15]
assert rv.left_cardinality == s[12] assert rv.left_cardinality == s[12]
assert isinstance(rv, Relation) assert isinstance(rv, Relation)
......
...@@ -115,24 +115,26 @@ def test_flask_sqlalchemy(): ...@@ -115,24 +115,26 @@ def test_flask_sqlalchemy():
check_intermediary_representation_simple_all_table(tables, relationships) check_intermediary_representation_simple_all_table(tables, relationships)
@pytest.mark.external_db
def test_table_names_in_relationships(pg_db_uri): def test_table_names_in_relationships(pg_db_uri):
tables, relationships = database_to_intermediary(pg_db_uri) tables, relationships = database_to_intermediary(pg_db_uri)
table_names = [t.name for t in tables] table_names = [t.name for t in tables]
# Assert column names are table names # Assert column names are table names
assert all(r.right_col in table_names for r in relationships) assert all(r.right_table in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships) assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names # Assert column names match table names
for r in relationships: for r in relationships:
r_name = table_names[table_names.index(r.right_col)] r_name = table_names[table_names.index(r.right_table)]
l_name = table_names[table_names.index(r.left_col)] l_name = table_names[table_names.index(r.left_table)]
# Table name in relationship should *NOT* have a schema # Table name in relationship should *NOT* have a schema
assert r_name.find(".") == -1 assert r_name.find(".") == -1
assert l_name.find(".") == -1 assert l_name.find(".") == -1
@pytest.mark.external_db
def test_table_names_in_relationships_with_schema(pg_db_uri): def test_table_names_in_relationships_with_schema(pg_db_uri):
schema_name = "test" schema_name = "test"
matcher = re.compile(rf"{schema_name}\.[\S+]", re.I) matcher = re.compile(rf"{schema_name}\.[\S+]", re.I)
...@@ -140,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri): ...@@ -140,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri):
table_names = [t.name for t in tables] table_names = [t.name for t in tables]
# Assert column names match table names, including schema # Assert column names match table names, including schema
assert all(r.right_col in table_names for r in relationships) assert all(r.right_table in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships) assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names, including schema # Assert column names match table names, including schema
for r in relationships: for r in relationships:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment