Skip to content
Snippets Groups Projects
Unverified Commit c83eda95 authored by Maxime Beauchemin's avatar Maxime Beauchemin Committed by GitHub
Browse files

feat: add latest partition support for BigQuery (#30760)

parent a36e636a
Branches
No related tags found
No related merge requests found
...@@ -1630,6 +1630,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ...@@ -1630,6 +1630,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return: SqlAlchemy query with additional where clause referencing the latest :return: SqlAlchemy query with additional where clause referencing the latest
partition partition
""" """
# TODO: Fix circular import caused by importing Database, TableColumn
return None return None
@classmethod @classmethod
......
...@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict ...@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict
import pandas as pd import pandas as pd
from apispec import APISpec from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask_babel import gettext as __ from flask_babel import gettext as __
from marshmallow import fields, Schema from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from sqlalchemy import column, types from sqlalchemy import column, func, types
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes from sqlalchemy.sql import column as sql_column, select, sqltypes
from sqlalchemy.sql.expression import table as sql_table
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.databases.schemas import encrypted_field_properties, EncryptedString from superset.databases.schemas import encrypted_field_properties, EncryptedString
...@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType ...@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils, json from superset.utils import core as utils, json
from superset.utils.hashing import md5_sha_from_str from superset.utils.hashing import md5_sha_from_str
if TYPE_CHECKING:
from sqlalchemy.sql.expression import Select
logger = logging.getLogger(__name__)
try: try:
import google.auth import google.auth
from google.cloud import bigquery from google.cloud import bigquery
...@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met ...@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return "_" + md5_sha_from_str(label) return "_" + md5_sha_from_str(label)
@classmethod @classmethod
@deprecated(deprecated_in="3.0") def where_latest_partition(
def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]: cls,
""" database: Database,
Normalizes indexes for more consistency across db engines table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
if partition_column := cls.get_time_partition_column(database, table):
max_partition_id = cls.get_max_partition_id(database, table)
query = query.where(
column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id)
)
:param indexes: Raw indexes as returned by SQLAlchemy return query
:return: cleaner, more aligned index definition
"""
normalized_idxs = []
# Fixing a bug/behavior observed in pybigquery==0.4.15 where
# the index's `column_names` == [None]
# Here we're returning only non-None indexes
for ix in indexes:
column_names = ix.get("column_names") or []
ix["column_names"] = [col for col in column_names if col is not None]
if ix["column_names"]:
normalized_idxs.append(ix)
return normalized_idxs
@classmethod @classmethod
def get_indexes( def get_max_partition_id(
cls, cls,
database: Database, database: Database,
inspector: Inspector,
table: Table, table: Table,
) -> list[dict[str, Any]]: ) -> Select | None:
""" # Compose schema from catalog and schema
Get the indexes associated with the specified schema/table. schema_parts = []
if table.catalog:
schema_parts.append(table.catalog)
if table.schema:
schema_parts.append(table.schema)
schema_parts.append("INFORMATION_SCHEMA")
schema = ".".join(schema_parts)
# Define a virtual table reference to INFORMATION_SCHEMA.PARTITIONS
partitions_table = sql_table(
"PARTITIONS",
sql_column("partition_id"),
sql_column("table_name"),
schema=schema,
)
:param database: The database to inspect # Build the query
:param inspector: The SQLAlchemy inspector query = select(
:param table: The table instance to inspect func.max(partitions_table.c.partition_id).label("max_partition_id")
:returns: The indexes ).where(partitions_table.c.table_name == table.table)
"""
# Compile to BigQuery SQL
compiled_query = query.compile(
dialect=database.get_dialect(),
compile_kwargs={"literal_binds": True},
)
# Run the query and handle result
with database.get_raw_connection(
catalog=table.catalog,
schema=table.schema,
) as conn:
cursor = conn.cursor()
cursor.execute(str(compiled_query))
if row := cursor.fetchone():
return row[0]
return None
return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema)) @classmethod
def get_time_partition_column(
cls,
database: Database,
table: Table,
) -> str | None:
with cls.get_engine(
database, catalog=table.catalog, schema=table.schema
) as engine:
client = cls._get_client(engine, database)
bq_table = client.get_table(f"{table.schema}.{table.table}")
if bq_table.time_partitioning:
return bq_table.time_partitioning.field
return None
@classmethod @classmethod
def get_extra_table_metadata( def get_extra_table_metadata(
...@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met ...@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
database: Database, database: Database,
table: Table, table: Table,
) -> dict[str, Any]: ) -> dict[str, Any]:
indexes = database.get_indexes(table) payload = {}
if not indexes: partition_column = cls.get_time_partition_column(database, table)
return {} with cls.get_engine(
partitions_columns = [ database, catalog=table.catalog, schema=table.schema
index.get("column_names", []) ) as engine:
for index in indexes if partition_column:
if index.get("name") == "partition" max_partition_id = cls.get_max_partition_id(database, table)
] sql = cls.select_star(
cluster_columns = [ database,
index.get("column_names", []) table,
for index in indexes engine,
if index.get("name") == "clustering" indent=False,
] show_cols=False,
return { latest_partition=True,
"partitions": {"cols": partitions_columns}, )
"clustering": {"cols": cluster_columns}, payload.update(
{
"partitions": {
"cols": [partition_column],
"latest": {partition_column: max_partition_id},
"partitionQuery": sql,
},
"indexes": [
{
"name": "partitioned",
"cols": [partition_column],
"type": "partitioned",
} }
],
}
)
return payload
@classmethod @classmethod
def epoch_to_dttm(cls) -> str: def epoch_to_dttm(cls) -> str:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import unittest.mock as mock import unittest.mock as mock
from contextlib import contextmanager
import pytest import pytest
from pandas import DataFrame from pandas import DataFrame
...@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import ( ...@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
) )
@contextmanager
def mock_engine_with_credentials(*args, **kwargs):
    """Yield a mocked engine whose dialect exposes BigQuery credentials info."""
    fake_engine = mock.Mock()
    # Mimic the BigQuery SQLAlchemy dialect, which carries the service-account
    # payload on ``dialect.credentials_info``.
    fake_engine.dialect.credentials_info = {"key": "value"}
    yield fake_engine
class TestBigQueryDbEngineSpec(TestDbEngineSpec): class TestBigQueryDbEngineSpec(TestDbEngineSpec):
def test_bigquery_sqla_column_label(self): def test_bigquery_sqla_column_label(self):
""" """
...@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec): ...@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
result = BigQueryEngineSpec.fetch_data(None, 0) result = BigQueryEngineSpec.fetch_data(None, 0)
assert result == [1, 2] assert result == [1, 2]
def test_get_extra_table_metadata(self): @mock.patch.object(
BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
)
@mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
@mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
@mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
def test_get_extra_table_metadata(
self,
mock_quote_table,
mock_get_max_partition_id,
mock_get_time_partition_column,
mock_get_engine,
):
""" """
DB Eng Specs (bigquery): Test extra table metadata DB Eng Specs (bigquery): Test extra table metadata
""" """
database = mock.Mock() database = mock.Mock()
sql = "SELECT * FROM `table_name`"
database.compile_sqla_query.return_value = sql
tbl = Table("some_table", "some_schema")
# Test no indexes # Test no indexes
database.get_indexes = mock.MagicMock(return_value=None) mock_get_time_partition_column.return_value = None
result = BigQueryEngineSpec.get_extra_table_metadata( mock_get_max_partition_id.return_value = None
database, result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
Table("some_table", "some_schema"),
)
assert result == {} assert result == {}
index_metadata = [ mock_get_time_partition_column.return_value = "ds"
{ mock_get_max_partition_id.return_value = "19690101"
"name": "clustering", result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
"column_names": ["c_col1", "c_col2", "c_col3"], print(result)
}, assert result == {
{ "indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
"name": "partition", "partitions": {
"column_names": ["p_col1", "p_col2", "p_col3"], "cols": ["ds"],
"latest": {"ds": "19690101"},
"partitionQuery": sql,
}, },
]
expected_result = {
"partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
"clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
}
database.get_indexes = mock.MagicMock(return_value=index_metadata)
result = BigQueryEngineSpec.get_extra_table_metadata(
database,
Table("some_table", "some_schema"),
)
assert result == expected_result
def test_get_indexes(self):
database = mock.Mock()
inspector = mock.Mock()
schema = "foo"
table_name = "bar"
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": [None],
"unique": False,
}
]
)
assert (
BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
)
== []
)
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
} }
]
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm", None],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine") @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq") @mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment