From 3bc8daf3db2c2817b9150d906b139ba11e4c4b91 Mon Sep 17 00:00:00 2001
From: Florian Maurer <f.maurer@outlook.de>
Date: Sat, 13 Jul 2024 21:40:03 +0200
Subject: [PATCH] allow people to have either pygraphviz or graphviz installed

* Remove graphviz from the required dependencies, as it is not needed when using the mermaid/markdown export
* add both graphviz and pygraphviz in CI for mypy
* add documentation of graphviz flavors
---
 README.md                          | 15 +++++++++
 eralchemy/main.py                  | 51 ++++++++++++++++++++++--------
 pyproject.toml                     |  9 ++++--
 tests/test_intermediary_to_dot.py  |  3 +-
 tests/test_sqla_to_intermediary.py |  2 ++
 5 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 2dde9ec..c16b3da 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,21 @@ To install eralchemy, just do:
 
     $ pip install eralchemy
 
+### Graph library flavors
+
+To create Pictures and PDFs, eralchemy relies on either graphviz or pygraphviz.
+
+You can use either
+
+    $ pip install eralchemy[graphviz]
+
+or
+
+    $ pip install eralchemy[pygraphviz]
+
+to retrieve the correct dependencies.
+The `graphviz` library is the default if both are installed.
+
 `eralchemy` requires [GraphViz](http://www.graphviz.org/download) to generate the graphs and Python. Both are available for Windows, Mac and Linux.
 
 For Debian based systems, run:
diff --git a/eralchemy/main.py b/eralchemy/main.py
index 30da6f1..d962617 100644
--- a/eralchemy/main.py
+++ b/eralchemy/main.py
@@ -1,12 +1,11 @@
 import argparse
 import base64
 import copy
+import logging
 import re
 import sys
 from importlib.metadata import PackageNotFoundError, version
 
-#from pygraphviz.agraph import AGraph
-from graphviz import Source
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.exc import ArgumentError
 
@@ -23,6 +22,22 @@ from .sqla import (
     metadata_to_intermediary,
 )
 
+USE_PYGRAPHVIZ = True
+GRAPHVIZ_AVAILABLE = True
+try:
+    from pygraphviz.agraph import AGraph
+
+    logging.debug("using pygraphviz")
+except ImportError:
+    USE_PYGRAPHVIZ = False
+    try:
+        from graphviz import Source
+
+        logging.debug("using graphviz")
+    except ImportError:
+        logging.error("either pygraphviz or graphviz should be installed")
+        GRAPHVIZ_AVAILABLE = False
+
 try:
     __version__ = version(__package__)
 except PackageNotFoundError:
@@ -139,13 +154,18 @@ def intermediary_to_dot(tables, relationships, output, title=""):
 
 def intermediary_to_schema(tables, relationships, output, title=""):
     """Transforms and save the intermediary representation to the file chosen."""
+    if not GRAPHVIZ_AVAILABLE:
+        raise Exception("neither graphviz or pygraphviz are available. Install either library!")
     dot_file = _intermediary_to_dot(tables, relationships, title)
-    #graph = AGraph()
-    #graph = graph.from_string(dot_file)
-    extension = output.split('.')[-1]
-    #graph.draw(path=output, prog='dot', format=extension)
-    #Source.from_file(filename, engine='dot', format=extension)
-    return Source(dot_file, engine='dot', format=extension)
+    extension = output.split(".")[-1]
+    if USE_PYGRAPHVIZ:
+        graph = AGraph()
+        graph = graph.from_string(dot_file)
+        graph.draw(path=output, prog="dot", format=extension)
+    else:
+        graph = Source(dot_file, engine="dot", format=extension)
+        graph.render(outfile=output, cleanup=True)
+    return graph
 
 
 def _intermediary_to_markdown(tables, relationships):
@@ -366,13 +386,16 @@ def render_er(
     """
     try:
         tables, relationships = all_to_intermediary(input, schema=schema)
-        if include is not None:
-            tables, relationships = filter_includes(tables, relationships, include)
-        if exclude is not None:
-            tables, relationships = filter_excludes(tables, relationships, exclude)
+        tables, relationships = filter_resources(
+            tables,
+            relationships,
+            include_tables=include_tables,
+            include_columns=include_columns,
+            exclude_tables=exclude_tables,
+            exclude_columns=exclude_columns,
+        )
         intermediary_to_output = get_output_mode(output, mode)
-        out = intermediary_to_output(tables, relationships, output)
-        return out
+        return intermediary_to_output(tables, relationships, output)
     except ImportError as e:
         module_name = e.message.split()[-1]
         print(f'Please install {module_name} using "pip install {module_name}".')
diff --git a/pyproject.toml b/pyproject.toml
index 0cd4c0a..f3a843b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,8 +26,7 @@ classifiers=[
 ]
 requires-python = ">=3.8"
 dependencies = [
-  "sqlalchemy >= 1.4",
-  "pygraphviz >= 1.9",
+  "sqlalchemy >= 1.4"
 ]
 
 [project.urls]
@@ -35,10 +34,14 @@ homepage = "https://github.com/eralchemy/eralchemy"
 repository = "https://github.com/eralchemy/eralchemy"
 
 [project.optional-dependencies]
+graphviz = ["graphviz >= 0.20.3"]
+pygraphviz = ["pygraphviz >= 1.9"]
 ci = [
   "flask-sqlalchemy >= 2.5.1",
   "psycopg2 >= 2.9.3",
   "pytest >= 7.4.3",
+  "pygraphviz >= 1.9",
+  "graphviz >= 0.20.3",
 ]
 dev = [
   "nox",
@@ -103,7 +106,7 @@ show_error_codes = true
 pretty = true
 
 [[tool.mypy.overrides]]
-module = ["pygraphviz.*", "sqlalchemy.*"]
+module = ["graphviz.*", "pygraphviz.*", "sqlalchemy.*"]
 ignore_missing_imports = true
 
 
diff --git a/tests/test_intermediary_to_dot.py b/tests/test_intermediary_to_dot.py
index e75a969..ea4261a 100644
--- a/tests/test_intermediary_to_dot.py
+++ b/tests/test_intermediary_to_dot.py
@@ -30,7 +30,8 @@ column_inside = re.compile(
     "(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>",
 )
 
-#This test needs fixing with move to graphviz
+
+# This test needs fixing with move to graphviz
 def assert_is_dot_format(dot):
     """Checks that the dot is usable by graphviz."""
 
diff --git a/tests/test_sqla_to_intermediary.py b/tests/test_sqla_to_intermediary.py
index 8d52309..8566e83 100644
--- a/tests/test_sqla_to_intermediary.py
+++ b/tests/test_sqla_to_intermediary.py
@@ -115,6 +115,7 @@ def test_flask_sqlalchemy():
     check_intermediary_representation_simple_all_table(tables, relationships)
 
 
+@pytest.mark.external_db
 def test_table_names_in_relationships(pg_db_uri):
     tables, relationships = database_to_intermediary(pg_db_uri)
     table_names = [t.name for t in tables]
@@ -133,6 +134,7 @@ def test_table_names_in_relationships(pg_db_uri):
         assert l_name.find(".") == -1
 
 
+@pytest.mark.external_db
 def test_table_names_in_relationships_with_schema(pg_db_uri):
     schema_name = "test"
     matcher = re.compile(rf"{schema_name}\.[\S+]", re.I)
-- 
GitLab