From c83eda9551ebb14f0fb4d2eb2c5276f00bddc1f7 Mon Sep 17 00:00:00 2001
From: Maxime Beauchemin <maximebeauchemin@gmail.com>
Date: Tue, 1 Apr 2025 17:13:09 -0700
Subject: [PATCH] feat: add latest partition support for BigQuery (#30760)

---
 superset/db_engine_specs/base.py              |   1 +
 superset/db_engine_specs/bigquery.py          | 154 ++++++++++++------
 .../db_engine_specs/bigquery_tests.py         | 133 +++++----------
 3 files changed, 147 insertions(+), 141 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 79e0eb3bfd..eaecb74020 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1630,6 +1630,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return: SqlAlchemy query with additional where clause referencing the latest
             partition
         """
+        # TODO: Fix circular import caused by importing Database, TableColumn
         return None
 
     @classmethod
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index c6c57b1624..cf5cfaad51 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict
 import pandas as pd
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
-from deprecation import deprecated
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
 from marshmallow.exceptions import ValidationError
-from sqlalchemy import column, types
+from sqlalchemy import column, func, types
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
-from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql import column as sql_column, select, sqltypes
+from sqlalchemy.sql.expression import table as sql_table
 
 from superset.constants import TimeGrain
 from superset.databases.schemas import encrypted_field_properties, EncryptedString
@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType
 from superset.utils import core as utils, json
 from superset.utils.hashing import md5_sha_from_str
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.expression import Select
+
+logger = logging.getLogger(__name__)
+
 try:
     import google.auth
     from google.cloud import bigquery
@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         return "_" + md5_sha_from_str(label)
 
     @classmethod
-    @deprecated(deprecated_in="3.0")
-    def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        """
-        Normalizes indexes for more consistency across db engines
+    def where_latest_partition(
+        cls,
+        database: Database,
+        table: Table,
+        query: Select,
+        columns: list[ResultSetColumnType] | None = None,
+    ) -> Select | None:
+        if partition_column := cls.get_time_partition_column(database, table):
+            max_partition_id = cls.get_max_partition_id(database, table)
+            query = query.where(
+                column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id)
+            )
 
-        :param indexes: Raw indexes as returned by SQLAlchemy
-        :return: cleaner, more aligned index definition
-        """
-        normalized_idxs = []
-        # Fixing a bug/behavior observed in pybigquery==0.4.15 where
-        # the index's `column_names` == [None]
-        # Here we're returning only non-None indexes
-        for ix in indexes:
-            column_names = ix.get("column_names") or []
-            ix["column_names"] = [col for col in column_names if col is not None]
-            if ix["column_names"]:
-                normalized_idxs.append(ix)
-        return normalized_idxs
+        return query
 
     @classmethod
-    def get_indexes(
+    def get_max_partition_id(
         cls,
         database: Database,
-        inspector: Inspector,
         table: Table,
-    ) -> list[dict[str, Any]]:
-        """
-        Get the indexes associated with the specified schema/table.
+    ) -> str | None:
+        # Compose schema from catalog and schema
+        schema_parts = []
+        if table.catalog:
+            schema_parts.append(table.catalog)
+        if table.schema:
+            schema_parts.append(table.schema)
+        schema_parts.append("INFORMATION_SCHEMA")
+        schema = ".".join(schema_parts)
+        # Define a virtual table reference to INFORMATION_SCHEMA.PARTITIONS
+        partitions_table = sql_table(
+            "PARTITIONS",
+            sql_column("partition_id"),
+            sql_column("table_name"),
+            schema=schema,
+        )
 
-        :param database: The database to inspect
-        :param inspector: The SQLAlchemy inspector
-        :param table: The table instance to inspect
-        :returns: The indexes
-        """
+        # Build the query
+        query = select(
+            func.max(partitions_table.c.partition_id).label("max_partition_id")
+        ).where(partitions_table.c.table_name == table.table)
 
-        return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
+        # Compile to BigQuery SQL
+        compiled_query = query.compile(
+            dialect=database.get_dialect(),
+            compile_kwargs={"literal_binds": True},
+        )
+
+        # Run the query and handle result
+        with database.get_raw_connection(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as conn:
+            cursor = conn.cursor()
+            cursor.execute(str(compiled_query))
+            if row := cursor.fetchone():
+                return row[0]
+        return None
+
+    @classmethod
+    def get_time_partition_column(
+        cls,
+        database: Database,
+        table: Table,
+    ) -> str | None:
+        with cls.get_engine(
+            database, catalog=table.catalog, schema=table.schema
+        ) as engine:
+            client = cls._get_client(engine, database)
+            bq_table = client.get_table(f"{table.schema}.{table.table}")
+
+            if bq_table.time_partitioning:
+                return bq_table.time_partitioning.field
+        return None
 
     @classmethod
     def get_extra_table_metadata(
@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         database: Database,
         table: Table,
     ) -> dict[str, Any]:
-        indexes = database.get_indexes(table)
-        if not indexes:
-            return {}
-        partitions_columns = [
-            index.get("column_names", [])
-            for index in indexes
-            if index.get("name") == "partition"
-        ]
-        cluster_columns = [
-            index.get("column_names", [])
-            for index in indexes
-            if index.get("name") == "clustering"
-        ]
-        return {
-            "partitions": {"cols": partitions_columns},
-            "clustering": {"cols": cluster_columns},
-        }
+        payload = {}
+        partition_column = cls.get_time_partition_column(database, table)
+        with cls.get_engine(
+            database, catalog=table.catalog, schema=table.schema
+        ) as engine:
+            if partition_column:
+                max_partition_id = cls.get_max_partition_id(database, table)
+                sql = cls.select_star(
+                    database,
+                    table,
+                    engine,
+                    indent=False,
+                    show_cols=False,
+                    latest_partition=True,
+                )
+                payload.update(
+                    {
+                        "partitions": {
+                            "cols": [partition_column],
+                            "latest": {partition_column: max_partition_id},
+                            "partitionQuery": sql,
+                        },
+                        "indexes": [
+                            {
+                                "name": "partitioned",
+                                "cols": [partition_column],
+                                "type": "partitioned",
+                            }
+                        ],
+                    }
+                )
+        return payload
 
     @classmethod
     def epoch_to_dttm(cls) -> str:
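
The core machinery above is get_max_partition_id(), which asks BigQuery's
INFORMATION_SCHEMA.PARTITIONS view for the newest partition_id. Below is a
minimal standalone sketch of the construct it builds -- not part of the
patch; the my_project/my_dataset/my_table names are hypothetical and only
plain SQLAlchemy is assumed:

from sqlalchemy import column as sql_column, func, select
from sqlalchemy.sql.expression import table as sql_table

# Hypothetical fully-qualified INFORMATION_SCHEMA path, composed from
# catalog ("my_project") and schema ("my_dataset") as the patch does.
partitions_table = sql_table(
    "PARTITIONS",
    sql_column("partition_id"),
    sql_column("table_name"),
    schema="my_project.my_dataset.INFORMATION_SCHEMA",
)

query = select(
    func.max(partitions_table.c.partition_id).label("max_partition_id")
).where(partitions_table.c.table_name == "my_table")

# The patch compiles this with the BigQuery dialect; the generic dialect
# used here renders an equivalent statement, roughly:
#   SELECT max(partition_id) AS max_partition_id
#   FROM my_project.my_dataset.INFORMATION_SCHEMA.PARTITIONS
#   WHERE table_name = 'my_table'
print(query.compile(compile_kwargs={"literal_binds": True}))
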
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index 45edb5a0d7..636fc3523a 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import unittest.mock as mock
+from contextlib import contextmanager
 
 import pytest
 from pandas import DataFrame
@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
 )
 
 
+@contextmanager
+def mock_engine_with_credentials(*args, **kwargs):
+    engine_mock = mock.Mock()
+    engine_mock.dialect.credentials_info = {
+        "key": "value"
+    }  # Add the credentials_info attribute
+    yield engine_mock
+
+
 class TestBigQueryDbEngineSpec(TestDbEngineSpec):
     def test_bigquery_sqla_column_label(self):
         """
@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
         result = BigQueryEngineSpec.fetch_data(None, 0)
         assert result == [1, 2]
 
-    def test_get_extra_table_metadata(self):
+    @mock.patch.object(
+        BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
+    )
+    @mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
+    @mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
+    @mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
+    def test_get_extra_table_metadata(
+        self,
+        mock_quote_table,
+        mock_get_max_partition_id,
+        mock_get_time_partition_column,
+        mock_get_engine,
+    ):
         """
         DB Eng Specs (bigquery): Test extra table metadata
         """
         database = mock.Mock()
+        sql = "SELECT * FROM `table_name`"
+        database.compile_sqla_query.return_value = sql
+        tbl = Table("some_table", "some_schema")
+
         # Test no indexes
-        database.get_indexes = mock.MagicMock(return_value=None)
-        result = BigQueryEngineSpec.get_extra_table_metadata(
-            database,
-            Table("some_table", "some_schema"),
-        )
+        mock_get_time_partition_column.return_value = None
+        mock_get_max_partition_id.return_value = None
+        result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
         assert result == {}
 
-        index_metadata = [
-            {
-                "name": "clustering",
-                "column_names": ["c_col1", "c_col2", "c_col3"],
-            },
-            {
-                "name": "partition",
-                "column_names": ["p_col1", "p_col2", "p_col3"],
+        mock_get_time_partition_column.return_value = "ds"
+        mock_get_max_partition_id.return_value = "19690101"
+        result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
+        print(result)
+        assert result == {
+            "indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
+            "partitions": {
+                "cols": ["ds"],
+                "latest": {"ds": "19690101"},
+                "partitionQuery": sql,
             },
-        ]
-        expected_result = {
-            "partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
-            "clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
         }
-        database.get_indexes = mock.MagicMock(return_value=index_metadata)
-        result = BigQueryEngineSpec.get_extra_table_metadata(
-            database,
-            Table("some_table", "some_schema"),
-        )
-        assert result == expected_result
-
-    def test_get_indexes(self):
-        database = mock.Mock()
-        inspector = mock.Mock()
-        schema = "foo"
-        table_name = "bar"
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": [None],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert (
-            BigQueryEngineSpec.get_indexes(
-                database,
-                inspector,
-                Table(table_name, schema),
-            )
-            == []
-        )
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": ["dttm"],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert BigQueryEngineSpec.get_indexes(
-            database,
-            inspector,
-            Table(table_name, schema),
-        ) == [
-            {
-                "name": "partition",
-                "column_names": ["dttm"],
-                "unique": False,
-            }
-        ]
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": ["dttm", None],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert BigQueryEngineSpec.get_indexes(
-            database,
-            inspector,
-            Table(table_name, schema),
-        ) == [
-            {
-                "name": "partition",
-                "column_names": ["dttm"],
-                "unique": False,
-            }
-        ]
 
     @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
     @mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
--
GitLab
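
Taken together: get_time_partition_column() discovers the partitioning
column via the BigQuery client, get_max_partition_id() fetches the newest
partition_id, and where_latest_partition() applies the filter that
select_star(..., latest_partition=True) folds into the partitionQuery
asserted in the test above. A minimal sketch of the resulting predicate --
not from the patch; the table name ("events"), partition column ("ds"),
and partition id ("20250401") are made up:

from sqlalchemy import column, func, literal_column, select, table

max_partition_id = "20250401"  # stand-in for a get_max_partition_id() result

query = select(literal_column("*")).select_from(table("events"))
# Same composition as the patch: equality against PARSE_DATE('%Y%m%d', ...),
# since partition ids of daily time-partitioned tables are YYYYMMDD strings.
query = query.where(
    column("ds") == func.PARSE_DATE("%Y%m%d", max_partition_id)
)

# Renders roughly as:
#   SELECT * FROM events WHERE ds = PARSE_DATE('%Y%m%d', '20250401')
print(query.compile(compile_kwargs={"literal_binds": True}))
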