diff --git a/eralchemy/models.py b/eralchemy/models.py index 08f413c667915ef8039dc6cad5a32afd3296b9c1..87e2a74bf7e97b77442fd0dcc2ed6a73e674afd4 100644 --- a/eralchemy/models.py +++ b/eralchemy/models.py @@ -39,7 +39,7 @@ class Drawable(ABC): def sanitize_mermaid(text: str, *, is_er: bool = False): - RE = re.compile("[^0-9a-zA-Z._-]+") + RE = re.compile("[^0-9a-zA-Z_-]+") """Mermaid does not allow special characters in column names""" if not text: return text @@ -161,17 +161,17 @@ class Relation(Drawable): return Relation(**match.groupdict()) def __init__( - self, - right_table, - left_table, - right_cardinality=None, - left_cardinality=None, - right_column=None, - left_column=None, + self, + right_table, + left_table, + right_cardinality=None, + left_cardinality=None, + right_column=None, + left_column=None, ): if ( - right_cardinality not in self.cardinalities.keys() - or left_cardinality not in self.cardinalities.keys() + right_cardinality not in self.cardinalities.keys() + or left_cardinality not in self.cardinalities.keys() ): raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}") self.right_table = right_table @@ -195,11 +195,11 @@ class Relation(Drawable): normalized = ( Relation.cardinalities_mermaid.get(k, k) for k in ( - sanitize_mermaid(self.left_table), - self.left_cardinality, - self.right_cardinality, - sanitize_mermaid(self.right_table), - ) + sanitize_mermaid(self.left_table), + self.left_cardinality, + self.right_cardinality, + sanitize_mermaid(self.right_table), + ) ) return ' {} "{}" -- "{}" {}'.format(*normalized) @@ -269,12 +269,12 @@ class Table(Drawable): def to_mermaid(self) -> str: columns = [c.to_mermaid() for c in self.columns] name = sanitize_mermaid(self.name) - return f" class {name}{{\n " + "\n ".join(columns) + "\n }" + return f' class {name}["{self.name}"]{{\n ' + "\n ".join(columns) + "\n }" def to_mermaid_er(self) -> str: columns = [c.to_mermaid_er() for c in self.columns] name = sanitize_mermaid(self.name, is_er=True) - return f"{name} {{\n" + "\n ".join(columns) + "\n}" + return f'{name}["{self.name}"] {{\n' + "\n ".join(columns) + "\n}" @property def columns_sorted(self):