Skip to content
Snippets Groups Projects
Commit 17a2122b authored by Leah Tacke genannt Unterberg's avatar Leah Tacke genannt Unterberg
Browse files

maybe fixed some isinstance check

parent a25ec003
No related branches found
No related tags found
No related merge requests found
...@@ -5,13 +5,12 @@ from contextlib import contextmanager ...@@ -5,13 +5,12 @@ from contextlib import contextmanager
import sqlalchemy import sqlalchemy
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import AnyUrl from pydantic import AnyUrl
from sqlalchemy import Engine from sqlalchemy import Engine, Connection
from typing import Type, Generator from typing import Type, Generator
from sqlalchemy.orm.session import Session
from sqlalchemy.orm import Session EngineOrConnection = Engine | Connection
AnyDBBind = EngineOrConnection | Session
EngineOrConnection = sa.Engine | sa.Connection
AnyDBBind = EngineOrConnection | sqlalchemy.orm.Session
def qualify(*, table: str, schema: str | None = None, column: str | None = None): def qualify(*, table: str, schema: str | None = None, column: str | None = None):
res = table res = table
...@@ -45,23 +44,25 @@ def dialect_cls_from_url(url: AnyUrl) -> Type[sa.engine.Dialect]: ...@@ -45,23 +44,25 @@ def dialect_cls_from_url(url: AnyUrl) -> Type[sa.engine.Dialect]:
@contextmanager @contextmanager
def use_nested_conn(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: def use_nested_conn(bind: AnyDBBind) -> Generator[Connection, None, None]:
if isinstance(bind, sa.Engine): if isinstance(bind, Engine):
yield bind.connect() yield bind.connect()
elif isinstance(bind, sa.Connection): elif isinstance(bind, Connection):
with bind.begin_nested(): with bind.begin_nested():
yield bind yield bind
elif isinstance(bind, Session): elif isinstance(bind, Session):
with bind.begin_nested(): with bind.begin_nested():
yield bind.connection() yield bind.connection()
else:
raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}")
@contextmanager @contextmanager
def use_db_bind(bind: AnyDBBind) -> Generator[sa.Connection, None, None]: def use_db_bind(bind: AnyDBBind) -> Generator[Connection, None, None]:
if isinstance(bind, Session): if isinstance(bind, Session):
yield bind.connection() yield bind.connection()
if isinstance(bind, sa.Connection): elif isinstance(bind, Connection):
yield bind yield bind
elif isinstance(bind, sa.Engine): elif isinstance(bind, Engine):
yield bind.connect() yield bind.connect()
else: else:
raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}") raise TypeError(f"Expected Engine, Connection or Session, got {type(bind)}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment