From 17a2122b9d8c995086e89369f4aa93ffed9fcb35 Mon Sep 17 00:00:00 2001 From: Leah Tacke genannt Unterberg <leah.tgu@pads.rwth-aachen.de> Date: Thu, 8 May 2025 17:17:28 +0200 Subject: [PATCH] maybe fixed some isinstance check --- mitm_tooling/utilities/sql_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mitm_tooling/utilities/sql_utils.py b/mitm_tooling/utilities/sql_utils.py index 1767958..8c7f0ed 100644 --- a/mitm_tooling/utilities/sql_utils.py +++ b/mitm_tooling/utilities/sql_utils.py @@ -5,13 +5,12 @@ from contextlib import contextmanager import sqlalchemy import sqlalchemy as sa from pydantic import AnyUrl -from sqlalchemy import Engine +from sqlalchemy import Engine, Connection from typing import Type, Generator +from sqlalchemy.orm.session import Session -from sqlalchemy.orm import Session - -EngineOrConnection = sa.Engine | sa.Connection -AnyDBBind = EngineOrConnection | sqlalchemy.orm.Session +EngineOrConnection = Engine | Connection +AnyDBBind = EngineOrConnection | Session def qualify(*, table: str, schema: str | None = None, column: str | None = None): res = table @@ -45,23 +44,25 @@ def dialect_cls_from_url(url: AnyUrl) -> Type[sa.engine.Dialect]: @contextmanager -def use_nested_conn(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: - if isinstance(bind, sa.Engine): +def use_nested_conn(bind: AnyDBBind) -> Generator[Connection, None, None]: + if isinstance(bind, Engine): yield bind.connect() - elif isinstance(bind, sa.Connection): + elif isinstance(bind, Connection): with bind.begin_nested(): yield bind elif isinstance(bind, Session): with bind.begin_nested(): yield bind.connection() + else: + raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}") @contextmanager -def use_db_bind(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: +def use_db_bind(bind: AnyDBBind) -> Generator[Connection, None, None]: if isinstance(bind, Session): yield bind.connection() - if isinstance(bind, sa.Connection): + elif isinstance(bind, Connection): yield bind - elif isinstance(bind, sa.Engine): + elif isinstance(bind, Engine): yield bind.connect() else: raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}") \ No newline at end of file -- GitLab