Skip to content
Snippets Groups Projects
Unverified Commit 37f6e5b1 authored by Rambaud Pierrick's avatar Rambaud Pierrick Committed by GitHub
Browse files

Merge branch 'main' into single-test

parents bdf47e75 25f1b532
Branches
No related tags found
No related merge requests found
......@@ -20,6 +20,21 @@ To install eralchemy, just do:
$ pip install eralchemy
### Graph library flavors
To create Pictures and PDFs, eralchemy relies on either graphviz or pygraphviz.
You can use either
$ pip install eralchemy[graphviz]
or
$ pip install eralchemy[pygraphviz]
to retrieve the correct dependencies.
The `graphviz` library is the default if both are installed.
`eralchemy` requires [GraphViz](http://www.graphviz.org/download) to generate the graphs and Python. Both are available for Windows, Mac and Linux.
For Debian based systems, run:
......
import argparse
import base64
import copy
import logging
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from pygraphviz.agraph import AGraph
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError
......@@ -22,17 +22,33 @@ from .sqla import (
metadata_to_intermediary,
)
USE_PYGRAPHVIZ = True
GRAPHVIZ_AVAILABLE = True
try:
from pygraphviz.agraph import AGraph
logging.debug("using pygraphviz")
except ImportError:
USE_PYGRAPHVIZ = False
try:
from graphviz import Source
logging.debug("using graphviz")
except ImportError:
logging.error("either pygraphviz or graphviz should be installed")
GRAPHVIZ_AVAILABLE = False
try:
__version__ = version(__package__)
except PackageNotFoundError:
__version__ = "na"
def cli() -> None:
def cli(args=None) -> None:
"""Entry point for the application script."""
parser = get_argparser()
args = parser.parse_args()
args = parser.parse_args(args)
check_args(args)
if args.v:
print(f"eralchemy version {__version__}.")
......@@ -138,11 +154,18 @@ def intermediary_to_dot(tables, relationships, output, title=""):
def intermediary_to_schema(tables, relationships, output, title=""):
"""Transforms and save the intermediary representation to the file chosen."""
if not GRAPHVIZ_AVAILABLE:
raise Exception("neither graphviz or pygraphviz are available. Install either library!")
dot_file = _intermediary_to_dot(tables, relationships, title)
extension = output.split(".")[-1]
if USE_PYGRAPHVIZ:
graph = AGraph()
graph = graph.from_string(dot_file)
extension = output.split(".")[-1]
graph.draw(path=output, prog="dot", format=extension)
else:
graph = Source(dot_file, engine="dot", format=extension)
graph.render(outfile=output, cleanup=True)
return graph
def _intermediary_to_markdown(tables, relationships):
......@@ -309,10 +332,10 @@ def filter_resources(
_relationships = [
r
for r in _relationships
if not exclude_tables_re.fullmatch(r.right_col)
and not exclude_tables_re.fullmatch(r.left_col)
and include_tables_re.fullmatch(r.right_col)
and include_tables_re.fullmatch(r.left_col)
if not exclude_tables_re.fullmatch(r.right_table)
and not exclude_tables_re.fullmatch(r.left_table)
and include_tables_re.fullmatch(r.right_table)
and include_tables_re.fullmatch(r.left_table)
]
def check_column(name):
......@@ -372,7 +395,7 @@ def render_er(
exclude_columns=exclude_columns,
)
intermediary_to_output = get_output_mode(output, mode)
intermediary_to_output(tables, relationships, output, title)
return intermediary_to_output(tables, relationships, output, title)
except ImportError as e:
module_name = e.message.split()[-1]
print(f'Please install {module_name} using "pip install {module_name}".')
......@@ -383,4 +406,5 @@ def render_er(
if __name__ == "__main__":
# cli("-i example/forum.er -o test.dot".split(" "))
cli()
......@@ -113,10 +113,11 @@ class Column(Drawable):
def to_dot(self) -> str:
base = ROW_TAGS.format(
' ALIGN="LEFT"',
' ALIGN="LEFT" {port}',
"{key_opening}{col_name}{key_closing} {type}{null}",
)
return base.format(
port=f'PORT="{self.name}"' if self.name else "",
key_opening="<u>" if self.is_key else "",
key_closing="</u>" if self.is_key else "",
col_name=FONT_TAGS.format(self.name),
......@@ -129,7 +130,19 @@ class Relation(Drawable):
"""Represents a Relation in the intermediaty syntax."""
RE = re.compile(
r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)",
r"""
(?P<left_table>[^\s]+?)
(?:\.\"(?P<left_column>.+)\")?
\s*
(?P<left_cardinality>[*?+1])
--
(?P<right_cardinality>[*?+1])
\s*
(?P<right_table>[^\s]+?)
(?:\.\"(?P<right_column>.+)\")?
\s*$
""",
re.VERBOSE,
)
cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None}
cardinalities_mermaid = {
......@@ -145,41 +158,47 @@ class Relation(Drawable):
@staticmethod
def make_from_match(match: re.Match) -> Relation:
return Relation(
right_col=match.group("right_name"),
left_col=match.group("left_name"),
right_cardinality=match.group("right_cardinality"),
left_cardinality=match.group("left_cardinality"),
)
return Relation(**match.groupdict())
def __init__(
self,
right_col,
left_col,
right_table,
left_table,
right_cardinality=None,
left_cardinality=None,
right_column=None,
left_column=None,
):
if (
right_cardinality not in self.cardinalities.keys()
or left_cardinality not in self.cardinalities.keys()
):
raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}")
self.right_col = right_col
self.left_col = left_col
self.right_table = right_table
self.right_column = right_column or ""
self.left_table = left_table
self.left_column = left_column or ""
self.right_cardinality = right_cardinality
self.left_cardinality = left_cardinality
def to_markdown(self) -> str:
return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}"
return "{}{} {}--{} {}{}".format(
self.left_table,
"" if not self.left_column else f'."{self.left_column}"',
self.left_cardinality,
self.right_cardinality,
self.right_table,
"" if not self.right_column else f'."{self.right_column}"',
)
def to_mermaid(self) -> str:
normalized = (
Relation.cardinalities_mermaid.get(k, k)
for k in (
sanitize_mermaid(self.left_col),
sanitize_mermaid(self.left_table),
self.left_cardinality,
self.right_cardinality,
sanitize_mermaid(self.right_col),
sanitize_mermaid(self.right_table),
)
)
return '{} "{}" -- "{}" {}'.format(*normalized)
......@@ -187,15 +206,13 @@ class Relation(Drawable):
def to_mermaid_er(self) -> str:
left = Relation.cardinalities_crowfoot.get(
self.left_cardinality,
self.left_cardinality,
)
right = Relation.cardinalities_crowfoot.get(
self.right_cardinality,
self.right_cardinality,
)
left_col = sanitize_mermaid(self.left_col, is_er=True)
right_col = sanitize_mermaid(self.right_col, is_er=True)
left_col = sanitize_mermaid(self.left_table, is_er=True)
right_col = sanitize_mermaid(self.right_table, is_er=True)
return f"{left_col} {left}--{right} {right_col} : has"
def graphviz_cardinalities(self, card) -> str:
......@@ -211,10 +228,10 @@ class Relation(Drawable):
cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality))
if self.right_cardinality != "":
cards.append("head" + self.graphviz_cardinalities(self.right_cardinality))
return '"{}" -- "{}" [{}];'.format(
self.left_col,
self.right_col,
",".join(cards),
left_col = f':"{self.left_column}"' if self.left_column else ""
right_col = f':"{self.right_column}"' if self.right_column else ""
return (
f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];'
)
def __eq__(self, other: object) -> bool:
......@@ -223,8 +240,10 @@ class Relation(Drawable):
if not isinstance(other, Relation):
return False
other_inversed = Relation(
right_col=other.left_col,
left_col=other.right_col,
right_table=other.left_table,
right_column=other.left_column,
left_table=other.right_table,
left_column=other.right_column,
right_cardinality=other.left_cardinality,
left_cardinality=other.right_cardinality,
)
......
......@@ -124,8 +124,8 @@ def update_models(
if isinstance(new_obj, Relation):
tables_names = [t.name for t in tables]
_check_colname_in_lst(new_obj.right_col, tables_names)
_check_colname_in_lst(new_obj.left_col, tables_names)
_check_colname_in_lst(new_obj.right_table, tables_names)
_check_colname_in_lst(new_obj.left_table, tables_names)
return current_table, tables, relations + [new_obj]
if isinstance(new_obj, Column):
......
......@@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation:
# if this is the case, we are not optional and must be unique
right_cardinality = "1" if check_all_compound_same_parent(fk) else "*"
return Relation(
right_col=format_name(fk.parent.table.fullname),
left_col=format_name(fk.column.table.fullname),
right_table=format_name(fk.parent.table.fullname),
right_column=format_name(fk.parent.name),
left_table=format_name(fk.column.table.fullname),
left_column=format_name(fk.column.name),
right_cardinality=right_cardinality,
left_cardinality="?" if fk.parent.nullable else "1",
)
......
......@@ -26,8 +26,7 @@ classifiers=[
]
requires-python = ">=3.8"
dependencies = [
"sqlalchemy >= 1.4",
"pygraphviz >= 1.9",
"sqlalchemy >= 1.4"
]
[project.urls]
......@@ -40,6 +39,14 @@ test = [
"psycopg2 >= 2.9.3",
"pytest >= 7.4.3",
"pytest-cov",
graphviz = ["graphviz >= 0.20.3"]
pygraphviz = ["pygraphviz >= 1.9"]
ci = [
"flask-sqlalchemy >= 2.5.1",
"psycopg2 >= 2.9.3",
"pytest >= 7.4.3",
"pygraphviz >= 1.9",
"graphviz >= 0.20.3",
]
dev = [
"nox",
......@@ -75,7 +82,7 @@ show_error_codes = true
pretty = true
[[tool.mypy.overrides]]
module = ["pygraphviz.*", "sqlalchemy.*"]
module = ["graphviz.*", "pygraphviz.*", "sqlalchemy.*"]
ignore_missing_imports = true
......
......@@ -68,10 +68,12 @@ child_parent_id = ERColumn(
)
relation = Relation(
right_col="parent",
left_col="child",
right_cardinality="?",
left_table="child",
left_column="parent_id",
right_table="parent",
right_column="id",
left_cardinality="*",
right_cardinality="?",
)
exclude_id = ERColumn(name="id", type="INTEGER", is_key=True)
......@@ -82,8 +84,10 @@ exclude_parent_id = ERColumn(
)
exclude_relation = Relation(
right_col="parent",
left_col="exclude",
right_table="parent",
right_column="id",
left_table="exclude",
left_column="parent_id",
right_cardinality="?",
left_cardinality="*",
)
......@@ -117,8 +121,8 @@ markdown = """
[exclude]
*id {label:"INTEGER"}
parent_id {label:"INTEGER"}
parent ?--* child
parent ?--* exclude
child."parent_id" *--? parent."id"
exclude."parent_id" *--? parent."id"
"""
......
......@@ -5,7 +5,6 @@ from multiprocessing import Process
import pytest
from pygraphviz import AGraph
from eralchemy.cst import DOT_GRAPH_BEGINNING
from eralchemy.main import _intermediary_to_dot
from tests.common import (
child,
......@@ -17,20 +16,17 @@ from tests.common import (
relation,
)
GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }"
column_re = re.compile(
'\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>',
)
column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>")
header_re = re.compile(
'\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)'
"\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>",
r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>"
)
column_inside = re.compile(
"(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>"
"(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>",
r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>"
r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>"
)
# This test needs fixing with move to graphviz
def assert_is_dot_format(dot):
"""Checks that the dot is usable by graphviz."""
......@@ -91,14 +87,16 @@ def test_column_is_dot_format():
def test_relation():
relation_re = re.compile(
'\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ '
"\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>"
"\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;",
r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ "
r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>"
r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;"
)
dot = relation.to_dot()
r = relation_re.match(dot)
assert r.group("l_name") == "child"
assert r.group("r_name") == "parent"
assert r.group("l_table") == "child"
assert r.group("l_column") == "parent_id"
assert r.group("r_table") == "parent"
assert r.group("r_column") == "id"
assert r.group("l_card") == "0..N"
assert r.group("r_card") == "{0,1}"
......
......@@ -38,7 +38,10 @@ def test_column_to_er():
def test_relation():
assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"]
assert relation.to_markdown() in [
'parent."id" *--? child."parent_id"',
'child."parent_id" *--? parent."id"',
]
def assert_table_well_rendered_to_er(table):
......
......@@ -73,9 +73,11 @@ def test_parse_line():
assert isinstance(rv, Column)
for s in relations_lst:
rv = parse_line(s)
assert rv.right_col == s[16:].strip()
assert rv.left_col == s[:12].strip()
rv = parse_line(s) # type: Relation
assert rv.right_table == s[16:].strip()
assert rv.right_column == ""
assert rv.left_table == s[:12].strip()
assert rv.left_column == ""
assert rv.right_cardinality == s[15]
assert rv.left_cardinality == s[12]
assert isinstance(rv, Relation)
......
......@@ -115,24 +115,26 @@ def test_flask_sqlalchemy():
check_intermediary_representation_simple_all_table(tables, relationships)
@pytest.mark.external_db
def test_table_names_in_relationships(pg_db_uri):
tables, relationships = database_to_intermediary(pg_db_uri)
table_names = [t.name for t in tables]
# Assert column names are table names
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names
for r in relationships:
r_name = table_names[table_names.index(r.right_col)]
l_name = table_names[table_names.index(r.left_col)]
r_name = table_names[table_names.index(r.right_table)]
l_name = table_names[table_names.index(r.left_table)]
# Table name in relationship should *NOT* have a schema
assert r_name.find(".") == -1
assert l_name.find(".") == -1
@pytest.mark.external_db
def test_table_names_in_relationships_with_schema(pg_db_uri):
schema_name = "test"
matcher = re.compile(rf"{schema_name}\.[\S+]", re.I)
......@@ -140,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri):
table_names = [t.name for t in tables]
# Assert column names match table names, including schema
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)
# Assert column names match table names, including schema
for r in relationships:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment