Skip to content
Snippets Groups Projects
Unverified Commit 74d7bee9 authored by Florian Maurer's avatar Florian Maurer Committed by GitHub
Browse files

Draw column relations (#72)

parents b9b62f37 8fa193e1
No related branches found
No related tags found
No related merge requests found
......@@ -332,10 +332,10 @@ def filter_resources(
_relationships = [
r
for r in _relationships
if not exclude_tables_re.fullmatch(r.right_col)
and not exclude_tables_re.fullmatch(r.left_col)
and include_tables_re.fullmatch(r.right_col)
and include_tables_re.fullmatch(r.left_col)
if not exclude_tables_re.fullmatch(r.right_table)
and not exclude_tables_re.fullmatch(r.left_table)
and include_tables_re.fullmatch(r.right_table)
and include_tables_re.fullmatch(r.left_table)
]
def check_column(name):
......
......@@ -113,10 +113,11 @@ class Column(Drawable):
def to_dot(self) -> str:
base = ROW_TAGS.format(
' ALIGN="LEFT"',
' ALIGN="LEFT" {port}',
"{key_opening}{col_name}{key_closing} {type}{null}",
)
return base.format(
port=f'PORT="{self.name}"' if self.name else "",
key_opening="<u>" if self.is_key else "",
key_closing="</u>" if self.is_key else "",
col_name=FONT_TAGS.format(self.name),
......@@ -129,7 +130,19 @@ class Relation(Drawable):
"""Represents a Relation in the intermediaty syntax."""
RE = re.compile(
r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)",
r"""
(?P<left_table>[^\s]+?)
(?:\.\"(?P<left_column>.+)\")?
\s*
(?P<left_cardinality>[*?+1])
--
(?P<right_cardinality>[*?+1])
\s*
(?P<right_table>[^\s]+?)
(?:\.\"(?P<right_column>.+)\")?
\s*$
""",
re.VERBOSE,
)
cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None}
cardinalities_mermaid = {
......@@ -145,41 +158,47 @@ class Relation(Drawable):
@staticmethod
def make_from_match(match: re.Match) -> Relation:
return Relation(
right_col=match.group("right_name"),
left_col=match.group("left_name"),
right_cardinality=match.group("right_cardinality"),
left_cardinality=match.group("left_cardinality"),
)
return Relation(**match.groupdict())
def __init__(
self,
right_col,
left_col,
right_table,
left_table,
right_cardinality=None,
left_cardinality=None,
right_column=None,
left_column=None,
):
if (
right_cardinality not in self.cardinalities.keys()
or left_cardinality not in self.cardinalities.keys()
):
raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}")
self.right_col = right_col
self.left_col = left_col
self.right_table = right_table
self.right_column = right_column or ""
self.left_table = left_table
self.left_column = left_column or ""
self.right_cardinality = right_cardinality
self.left_cardinality = left_cardinality
def to_markdown(self) -> str:
return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}"
return "{}{} {}--{} {}{}".format(
self.left_table,
"" if not self.left_column else f'."{self.left_column}"',
self.left_cardinality,
self.right_cardinality,
self.right_table,
"" if not self.right_column else f'."{self.right_column}"',
)
def to_mermaid(self) -> str:
normalized = (
Relation.cardinalities_mermaid.get(k, k)
for k in (
sanitize_mermaid(self.left_col),
sanitize_mermaid(self.left_table),
self.left_cardinality,
self.right_cardinality,
sanitize_mermaid(self.right_col),
sanitize_mermaid(self.right_table),
)
)
return '{} "{}" -- "{}" {}'.format(*normalized)
......@@ -187,15 +206,13 @@ class Relation(Drawable):
def to_mermaid_er(self) -> str:
left = Relation.cardinalities_crowfoot.get(
self.left_cardinality,
self.left_cardinality,
)
right = Relation.cardinalities_crowfoot.get(
self.right_cardinality,
self.right_cardinality,
)
left_col = sanitize_mermaid(self.left_col, is_er=True)
right_col = sanitize_mermaid(self.right_col, is_er=True)
left_col = sanitize_mermaid(self.left_table, is_er=True)
right_col = sanitize_mermaid(self.right_table, is_er=True)
return f"{left_col} {left}--{right} {right_col} : has"
def graphviz_cardinalities(self, card) -> str:
......@@ -211,10 +228,10 @@ class Relation(Drawable):
cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality))
if self.right_cardinality != "":
cards.append("head" + self.graphviz_cardinalities(self.right_cardinality))
return '"{}" -- "{}" [{}];'.format(
self.left_col,
self.right_col,
",".join(cards),
left_col = f':"{self.left_column}"' if self.left_column else ""
right_col = f':"{self.right_column}"' if self.right_column else ""
return (
f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];'
)
def __eq__(self, other: object) -> bool:
......@@ -223,8 +240,10 @@ class Relation(Drawable):
if not isinstance(other, Relation):
return False
other_inversed = Relation(
right_col=other.left_col,
left_col=other.right_col,
right_table=other.left_table,
right_column=other.left_column,
left_table=other.right_table,
left_column=other.right_column,
right_cardinality=other.left_cardinality,
left_cardinality=other.right_cardinality,
)
......
......@@ -124,8 +124,8 @@ def update_models(
if isinstance(new_obj, Relation):
tables_names = [t.name for t in tables]
_check_colname_in_lst(new_obj.right_col, tables_names)
_check_colname_in_lst(new_obj.left_col, tables_names)
_check_colname_in_lst(new_obj.right_table, tables_names)
_check_colname_in_lst(new_obj.left_table, tables_names)
return current_table, tables, relations + [new_obj]
if isinstance(new_obj, Column):
......
......@@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation:
# if this is the case, we are not optional and must be unique
right_cardinality = "1" if check_all_compound_same_parent(fk) else "*"
return Relation(
right_col=format_name(fk.parent.table.fullname),
left_col=format_name(fk.column.table.fullname),
right_table=format_name(fk.parent.table.fullname),
right_column=format_name(fk.parent.name),
left_table=format_name(fk.column.table.fullname),
left_column=format_name(fk.column.name),
right_cardinality=right_cardinality,
left_cardinality="?" if fk.parent.nullable else "1",
)
......
......@@ -68,10 +68,12 @@ child_parent_id = ERColumn(
)
relation = Relation(
right_col="parent",
left_col="child",
right_cardinality="?",
left_table="child",
left_column="parent_id",
right_table="parent",
right_column="id",
left_cardinality="*",
right_cardinality="?",
)
exclude_id = ERColumn(name="id", type="INTEGER", is_key=True)
......@@ -82,8 +84,10 @@ exclude_parent_id = ERColumn(
)
exclude_relation = Relation(
right_col="parent",
left_col="exclude",
right_table="parent",
right_column="id",
left_table="exclude",
left_column="parent_id",
right_cardinality="?",
left_cardinality="*",
)
......@@ -117,8 +121,8 @@ markdown = """
[exclude]
*id {label:"INTEGER"}
parent_id {label:"INTEGER"}
parent ?--* child
parent ?--* exclude
child."parent_id" *--? parent."id"
exclude."parent_id" *--? parent."id"
"""
......
......@@ -5,7 +5,6 @@ from multiprocessing import Process
import pytest
from pygraphviz import AGraph
from eralchemy.cst import DOT_GRAPH_BEGINNING
from eralchemy.main import _intermediary_to_dot
from tests.common import (
child,
......@@ -17,17 +16,13 @@ from tests.common import (
relation,
)
GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }"
column_re = re.compile(
'\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>',
)
column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>")
header_re = re.compile(
'\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)'
"\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>",
r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>"
)
column_inside = re.compile(
"(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>"
"(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>",
r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>"
r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>"
)
......@@ -92,14 +87,16 @@ def test_column_is_dot_format():
def test_relation():
relation_re = re.compile(
'\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ '
"\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>"
"\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;",
r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ "
r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>"
r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;"
)
dot = relation.to_dot()
r = relation_re.match(dot)
assert r.group("l_name") == "child"
assert r.group("r_name") == "parent"
assert r.group("l_table") == "child"
assert r.group("l_column") == "parent_id"
assert r.group("r_table") == "parent"
assert r.group("r_column") == "id"
assert r.group("l_card") == "0..N"
assert r.group("r_card") == "{0,1}"
......
......@@ -38,7 +38,10 @@ def test_column_to_er():
def test_relation():
assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"]
assert relation.to_markdown() in [
'parent."id" *--? child."parent_id"',
'child."parent_id" *--? parent."id"',
]
def assert_table_well_rendered_to_er(table):
......
......@@ -73,9 +73,11 @@ def test_parse_line():
assert isinstance(rv, Column)
for s in relations_lst:
rv = parse_line(s)
assert rv.right_col == s[16:].strip()
assert rv.left_col == s[:12].strip()
rv = parse_line(s) # type: Relation
assert rv.right_table == s[16:].strip()
assert rv.right_column == ""
assert rv.left_table == s[:12].strip()
assert rv.left_column == ""
assert rv.right_cardinality == s[15]
assert rv.left_cardinality == s[12]
assert isinstance(rv, Relation)
......
......@@ -121,13 +121,13 @@ def test_table_names_in_relationships(pg_db_uri):
table_names = [t.name for t in tables]
# Assert column names are table names
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names
for r in relationships:
r_name = table_names[table_names.index(r.right_col)]
l_name = table_names[table_names.index(r.left_col)]
r_name = table_names[table_names.index(r.right_table)]
l_name = table_names[table_names.index(r.left_table)]
# Table name in relationship should *NOT* have a schema
assert r_name.find(".") == -1
......@@ -142,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri):
table_names = [t.name for t in tables]
# Assert column names match table names, including schema
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names, including schema
for r in relationships:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment