From c83eda9551ebb14f0fb4d2eb2c5276f00bddc1f7 Mon Sep 17 00:00:00 2001
From: Maxime Beauchemin <maximebeauchemin@gmail.com>
Date: Tue, 1 Apr 2025 17:13:09 -0700
Subject: [PATCH] feat: add latest partition support for BigQuery (#30760)

---
 superset/db_engine_specs/base.py              |   1 +
 superset/db_engine_specs/bigquery.py          | 154 ++++++++++++------
 .../db_engine_specs/bigquery_tests.py         | 133 +++++----------
 3 files changed, 147 insertions(+), 141 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 79e0eb3bfd..eaecb74020 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1630,6 +1630,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return: SqlAlchemy query with additional where clause referencing the latest
         partition
         """
+        # TODO: Fix circular import caused by importing Database, TableColumn
         return None
 
     @classmethod
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index c6c57b1624..cf5cfaad51 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict
 import pandas as pd
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
-from deprecation import deprecated
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
 from marshmallow.exceptions import ValidationError
-from sqlalchemy import column, types
+from sqlalchemy import column, func, types
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
-from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql import column as sql_column, select, sqltypes
+from sqlalchemy.sql.expression import table as sql_table
 
 from superset.constants import TimeGrain
 from superset.databases.schemas import encrypted_field_properties, EncryptedString
@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType
 from superset.utils import core as utils, json
 from superset.utils.hashing import md5_sha_from_str
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.expression import Select
+
+logger = logging.getLogger(__name__)
+
 try:
     import google.auth
     from google.cloud import bigquery
@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         return "_" + md5_sha_from_str(label)
 
     @classmethod
-    @deprecated(deprecated_in="3.0")
-    def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        """
-        Normalizes indexes for more consistency across db engines
+    def where_latest_partition(
+        cls,
+        database: Database,
+        table: Table,
+        query: Select,
+        columns: list[ResultSetColumnType] | None = None,
+    ) -> Select | None:
+        if partition_column := cls.get_time_partition_column(database, table):
+            max_partition_id = cls.get_max_partition_id(database, table)
+            query = query.where(
+                column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id)
+            )
 
-        :param indexes: Raw indexes as returned by SQLAlchemy
-        :return: cleaner, more aligned index definition
-        """
-        normalized_idxs = []
-        # Fixing a bug/behavior observed in pybigquery==0.4.15 where
-        # the index's `column_names` == [None]
-        # Here we're returning only non-None indexes
-        for ix in indexes:
-            column_names = ix.get("column_names") or []
-            ix["column_names"] = [col for col in column_names if col is not None]
-            if ix["column_names"]:
-                normalized_idxs.append(ix)
-        return normalized_idxs
+        return query
 
     @classmethod
-    def get_indexes(
+    def get_max_partition_id(
         cls,
         database: Database,
-        inspector: Inspector,
         table: Table,
-    ) -> list[dict[str, Any]]:
-        """
-        Get the indexes associated with the specified schema/table.
+    ) -> Select | None:
+        # Compose schema from catalog and schema
+        schema_parts = []
+        if table.catalog:
+            schema_parts.append(table.catalog)
+        if table.schema:
+            schema_parts.append(table.schema)
+        schema_parts.append("INFORMATION_SCHEMA")
+        schema = ".".join(schema_parts)
+        # Define a virtual table reference to INFORMATION_SCHEMA.PARTITIONS
+        partitions_table = sql_table(
+            "PARTITIONS",
+            sql_column("partition_id"),
+            sql_column("table_name"),
+            schema=schema,
+        )
 
-        :param database: The database to inspect
-        :param inspector: The SQLAlchemy inspector
-        :param table: The table instance to inspect
-        :returns: The indexes
-        """
+        # Build the query
+        query = select(
+            func.max(partitions_table.c.partition_id).label("max_partition_id")
+        ).where(partitions_table.c.table_name == table.table)
 
-        return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
+        # Compile to BigQuery SQL
+        compiled_query = query.compile(
+            dialect=database.get_dialect(),
+            compile_kwargs={"literal_binds": True},
+        )
+
+        # Run the query and handle result
+        with database.get_raw_connection(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as conn:
+            cursor = conn.cursor()
+            cursor.execute(str(compiled_query))
+            if row := cursor.fetchone():
+                return row[0]
+        return None
+
+    @classmethod
+    def get_time_partition_column(
+        cls,
+        database: Database,
+        table: Table,
+    ) -> str | None:
+        with cls.get_engine(
+            database, catalog=table.catalog, schema=table.schema
+        ) as engine:
+            client = cls._get_client(engine, database)
+            bq_table = client.get_table(f"{table.schema}.{table.table}")
+
+            if bq_table.time_partitioning:
+                return bq_table.time_partitioning.field
+        return None
 
     @classmethod
     def get_extra_table_metadata(
@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         database: Database,
         table: Table,
     ) -> dict[str, Any]:
-        indexes = database.get_indexes(table)
-        if not indexes:
-            return {}
-        partitions_columns = [
-            index.get("column_names", [])
-            for index in indexes
-            if index.get("name") == "partition"
-        ]
-        cluster_columns = [
-            index.get("column_names", [])
-            for index in indexes
-            if index.get("name") == "clustering"
-        ]
-        return {
-            "partitions": {"cols": partitions_columns},
-            "clustering": {"cols": cluster_columns},
-        }
+        payload = {}
+        partition_column = cls.get_time_partition_column(database, table)
+        with cls.get_engine(
+            database, catalog=table.catalog, schema=table.schema
+        ) as engine:
+            if partition_column:
+                max_partition_id = cls.get_max_partition_id(database, table)
+                sql = cls.select_star(
+                    database,
+                    table,
+                    engine,
+                    indent=False,
+                    show_cols=False,
+                    latest_partition=True,
+                )
+                payload.update(
+                    {
+                        "partitions": {
+                            "cols": [partition_column],
+                            "latest": {partition_column: max_partition_id},
+                            "partitionQuery": sql,
+                        },
+                        "indexes": [
+                            {
+                                "name": "partitioned",
+                                "cols": [partition_column],
+                                "type": "partitioned",
+                            }
+                        ],
+                    }
+                )
+        return payload
 
     @classmethod
     def epoch_to_dttm(cls) -> str:
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index 45edb5a0d7..636fc3523a 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import unittest.mock as mock
+from contextlib import contextmanager
 
 import pytest
 from pandas import DataFrame
@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
 )
 
 
+@contextmanager
+def mock_engine_with_credentials(*args, **kwargs):
+    engine_mock = mock.Mock()
+    engine_mock.dialect.credentials_info = {
+        "key": "value"
+    }  # Add the credentials_info attribute
+    yield engine_mock
+
+
 class TestBigQueryDbEngineSpec(TestDbEngineSpec):
     def test_bigquery_sqla_column_label(self):
         """
@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
             result = BigQueryEngineSpec.fetch_data(None, 0)
         assert result == [1, 2]
 
-    def test_get_extra_table_metadata(self):
+    @mock.patch.object(
+        BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
+    )
+    @mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
+    @mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
+    @mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
+    def test_get_extra_table_metadata(
+        self,
+        mock_quote_table,
+        mock_get_max_partition_id,
+        mock_get_time_partition_column,
+        mock_get_engine,
+    ):
         """
         DB Eng Specs (bigquery): Test extra table metadata
         """
         database = mock.Mock()
+        sql = "SELECT * FROM `table_name`"
+        database.compile_sqla_query.return_value = sql
+        tbl = Table("some_table", "some_schema")
+
         # Test no indexes
-        database.get_indexes = mock.MagicMock(return_value=None)
-        result = BigQueryEngineSpec.get_extra_table_metadata(
-            database,
-            Table("some_table", "some_schema"),
-        )
+        mock_get_time_partition_column.return_value = None
+        mock_get_max_partition_id.return_value = None
+        result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
         assert result == {}
 
-        index_metadata = [
-            {
-                "name": "clustering",
-                "column_names": ["c_col1", "c_col2", "c_col3"],
-            },
-            {
-                "name": "partition",
-                "column_names": ["p_col1", "p_col2", "p_col3"],
+        mock_get_time_partition_column.return_value = "ds"
+        mock_get_max_partition_id.return_value = "19690101"
+        result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
+        print(result)
+        assert result == {
+            "indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
+            "partitions": {
+                "cols": ["ds"],
+                "latest": {"ds": "19690101"},
+                "partitionQuery": sql,
             },
-        ]
-        expected_result = {
-            "partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
-            "clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
         }
-        database.get_indexes = mock.MagicMock(return_value=index_metadata)
-        result = BigQueryEngineSpec.get_extra_table_metadata(
-            database,
-            Table("some_table", "some_schema"),
-        )
-        assert result == expected_result
-
-    def test_get_indexes(self):
-        database = mock.Mock()
-        inspector = mock.Mock()
-        schema = "foo"
-        table_name = "bar"
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": [None],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert (
-            BigQueryEngineSpec.get_indexes(
-                database,
-                inspector,
-                Table(table_name, schema),
-            )
-            == []
-        )
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": ["dttm"],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert BigQueryEngineSpec.get_indexes(
-            database,
-            inspector,
-            Table(table_name, schema),
-        ) == [
-            {
-                "name": "partition",
-                "column_names": ["dttm"],
-                "unique": False,
-            }
-        ]
-
-        inspector.get_indexes = mock.Mock(
-            return_value=[
-                {
-                    "name": "partition",
-                    "column_names": ["dttm", None],
-                    "unique": False,
-                }
-            ]
-        )
-
-        assert BigQueryEngineSpec.get_indexes(
-            database,
-            inspector,
-            Table(table_name, schema),
-        ) == [
-            {
-                "name": "partition",
-                "column_names": ["dttm"],
-                "unique": False,
-            }
-        ]
 
     @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
     @mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
-- 
GitLab