diff --git a/mitm_tooling/utilities/sql_utils.py b/mitm_tooling/utilities/sql_utils.py index 17679580f23ec99a2d59a871b57af537298e88dd..8c7f0ed51fad17b5b0ac0c1f9d01e8340bb35b6e 100644 --- a/mitm_tooling/utilities/sql_utils.py +++ b/mitm_tooling/utilities/sql_utils.py @@ -5,13 +5,12 @@ from contextlib import contextmanager import sqlalchemy import sqlalchemy as sa from pydantic import AnyUrl -from sqlalchemy import Engine +from sqlalchemy import Engine, Connection from typing import Type, Generator +from sqlalchemy.orm.session import Session -from sqlalchemy.orm import Session - -EngineOrConnection = sa.Engine | sa.Connection -AnyDBBind = EngineOrConnection | sqlalchemy.orm.Session +EngineOrConnection = Engine | Connection +AnyDBBind = EngineOrConnection | Session def qualify(*, table: str, schema: str | None = None, column: str | None = None): res = table @@ -45,23 +44,25 @@ def dialect_cls_from_url(url: AnyUrl) -> Type[sa.engine.Dialect]: @contextmanager -def use_nested_conn(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: - if isinstance(bind, sa.Engine): +def use_nested_conn(bind: AnyDBBind) -> Generator[Connection, None, None]: + if isinstance(bind, Engine): yield bind.connect() - elif isinstance(bind, sa.Connection): + elif isinstance(bind, Connection): with bind.begin_nested(): yield bind elif isinstance(bind, Session): with bind.begin_nested(): yield bind.connection() + else: + raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}") @contextmanager -def use_db_bind(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: +def use_db_bind(bind: AnyDBBind) -> Generator[Connection, None, None]: if isinstance(bind, Session): yield bind.connection() - if isinstance(bind, sa.Connection): + elif isinstance(bind, Connection): yield bind - elif isinstance(bind, sa.Engine): + elif isinstance(bind, Engine): yield bind.connect() else: raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}") \ No newline at end of file