diff --git a/superset-frontend/src/components/Datasource/DatasourceModal.tsx b/superset-frontend/src/components/Datasource/DatasourceModal.tsx index 1055e51f4447b227b16388c9550d2358c0bcce8f..51a60e11bef7268196ac0ac157926a8582f1e0df 100644 --- a/superset-frontend/src/components/Datasource/DatasourceModal.tsx +++ b/superset-frontend/src/components/Datasource/DatasourceModal.tsx @@ -120,71 +120,83 @@ const DatasourceModal: FunctionComponent<DatasourceModalProps> = ({ const [isEditing, setIsEditing] = useState<boolean>(false); const dialog = useRef<any>(null); const [modal, contextHolder] = Modal.useModal(); - const buildPayload = (datasource: Record<string, any>) => ({ - table_name: datasource.table_name, - database_id: datasource.database?.id, - sql: datasource.sql, - filter_select_enabled: datasource.filter_select_enabled, - fetch_values_predicate: datasource.fetch_values_predicate, - schema: - datasource.tableSelector?.schema || - datasource.databaseSelector?.schema || - datasource.schema, - description: datasource.description, - main_dttm_col: datasource.main_dttm_col, - normalize_columns: datasource.normalize_columns, - always_filter_main_dttm: datasource.always_filter_main_dttm, - offset: datasource.offset, - default_endpoint: datasource.default_endpoint, - cache_timeout: - datasource.cache_timeout === '' ? null : datasource.cache_timeout, - is_sqllab_view: datasource.is_sqllab_view, - template_params: datasource.template_params, - extra: datasource.extra, - is_managed_externally: datasource.is_managed_externally, - external_url: datasource.external_url, - metrics: datasource?.metrics?.map((metric: DatasetObject['metrics'][0]) => { - const metricBody: any = { - expression: metric.expression, - description: metric.description, - metric_name: metric.metric_name, - metric_type: metric.metric_type, - d3format: metric.d3format || null, - currency: !isDefined(metric.currency) - ? null - : JSON.stringify(metric.currency), - verbose_name: metric.verbose_name, - warning_text: metric.warning_text, - uuid: metric.uuid, - extra: buildExtraJsonObject(metric), - }; - if (!Number.isNaN(Number(metric.id))) { - metricBody.id = metric.id; - } - return metricBody; - }), - columns: datasource?.columns?.map( - (column: DatasetObject['columns'][0]) => ({ - id: typeof column.id === 'number' ? column.id : undefined, - column_name: column.column_name, - type: column.type, - advanced_data_type: column.advanced_data_type, - verbose_name: column.verbose_name, - description: column.description, - expression: column.expression, - filterable: column.filterable, - groupby: column.groupby, - is_active: column.is_active, - is_dttm: column.is_dttm, - python_date_format: column.python_date_format || null, - uuid: column.uuid, - extra: buildExtraJsonObject(column), - }), - ), - owners: datasource.owners.map( - (o: Record<string, number>) => o.value || o.id, - ), - }); + const buildPayload = (datasource: Record<string, any>) => { + const payload: Record<string, any> = { + table_name: datasource.table_name, + database_id: datasource.database?.id, + sql: datasource.sql, + filter_select_enabled: datasource.filter_select_enabled, + fetch_values_predicate: datasource.fetch_values_predicate, + schema: + datasource.tableSelector?.schema || + datasource.databaseSelector?.schema || + datasource.schema, + description: datasource.description, + main_dttm_col: datasource.main_dttm_col, + normalize_columns: datasource.normalize_columns, + always_filter_main_dttm: datasource.always_filter_main_dttm, + offset: datasource.offset, + default_endpoint: datasource.default_endpoint, + cache_timeout: + datasource.cache_timeout === '' ? null : datasource.cache_timeout, + is_sqllab_view: datasource.is_sqllab_view, + template_params: datasource.template_params, + extra: datasource.extra, + is_managed_externally: datasource.is_managed_externally, + external_url: datasource.external_url, + metrics: datasource?.metrics?.map( + (metric: DatasetObject['metrics'][0]) => { + const metricBody: any = { + expression: metric.expression, + description: metric.description, + metric_name: metric.metric_name, + metric_type: metric.metric_type, + d3format: metric.d3format || null, + currency: !isDefined(metric.currency) + ? null + : JSON.stringify(metric.currency), + verbose_name: metric.verbose_name, + warning_text: metric.warning_text, + uuid: metric.uuid, + extra: buildExtraJsonObject(metric), + }; + if (!Number.isNaN(Number(metric.id))) { + metricBody.id = metric.id; + } + return metricBody; + }, + ), + columns: datasource?.columns?.map( + (column: DatasetObject['columns'][0]) => ({ + id: typeof column.id === 'number' ? column.id : undefined, + column_name: column.column_name, + type: column.type, + advanced_data_type: column.advanced_data_type, + verbose_name: column.verbose_name, + description: column.description, + expression: column.expression, + filterable: column.filterable, + groupby: column.groupby, + is_active: column.is_active, + is_dttm: column.is_dttm, + python_date_format: column.python_date_format || null, + uuid: column.uuid, + extra: buildExtraJsonObject(column), + }), + ), + owners: datasource.owners.map( + (o: Record<string, number>) => o.value || o.id, + ), + }; + // Handle catalog based on database's allow_multi_catalog setting + // If multi-catalog is disabled, don't include catalog in payload + // The backend will use the default catalog + // If multi-catalog is enabled, include the selected catalog + if (datasource.database?.allow_multi_catalog) { + payload.catalog = datasource.catalog; + } + return payload; + }; const onConfirmSave = async () => { // Pull out extra fields into the extra object setIsSaving(true); diff --git a/superset-frontend/src/features/datasets/types.ts b/superset-frontend/src/features/datasets/types.ts index d343ad153dc7b34ee9cdeb369fe5b0647d8d0d37..3e91e8effa71e508d00220f06178cc224d4e0dcf 100644 --- a/superset-frontend/src/features/datasets/types.ts +++ b/superset-frontend/src/features/datasets/types.ts @@ -62,6 +62,7 @@ export type DatasetObject = { filter_select_enabled?: boolean; fetch_values_predicate?: string; schema?: string; + catalog?: string; description: string | null; main_dttm_col: string; offset?: number; diff --git a/superset/commands/dataset/exceptions.py b/superset/commands/dataset/exceptions.py index 04afc4fc9b0e34fc522c1bc20dd700bcbdb65af5..be82b7c88f96d6f460723ab6f15e4352813bc471 100644 --- a/superset/commands/dataset/exceptions.py +++ b/superset/commands/dataset/exceptions.py @@ -33,22 +33,25 @@ def get_dataset_exist_error_msg(table: Table) -> str: return _("Dataset %(table)s already exists", table=table) -class DatabaseNotFoundValidationError(ValidationError): +class MultiCatalogDisabledValidationError(ValidationError): """ - Marshmallow validation error for database does not exist + Validation error for using a non-default catalog when multi-catalog is disabled """ def __init__(self) -> None: - super().__init__([_("Database does not exist")], field_name="database") + super().__init__( + [_("Only the default catalog is supported for this connection")], + field_name="catalog", + ) -class DatabaseChangeValidationError(ValidationError): +class DatabaseNotFoundValidationError(ValidationError): """ - Marshmallow validation error database changes are not allowed on update + Marshmallow validation error for database does not exist """ def __init__(self) -> None: - super().__init__([_("Database not allowed to change")], field_name="database") + super().__init__([_("Database does not exist")], field_name="database") class DatasetExistsValidationError(ValidationError): diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 7f6134d20a068c0372255fd1d24335814bd52a7e..0f0e8f30ecc5738afea312b1fb2e34a90c87df0c 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging from collections import Counter from functools import partial @@ -26,7 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError from superset import is_feature_enabled, security_manager from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.dataset.exceptions import ( - DatabaseChangeValidationError, + DatabaseNotFoundValidationError, DatasetColumnNotFoundValidationError, DatasetColumnsDuplicateValidationError, DatasetColumnsExistsValidationError, @@ -38,11 +40,13 @@ from superset.commands.dataset.exceptions import ( DatasetMetricsNotFoundValidationError, DatasetNotFoundError, DatasetUpdateFailedError, + MultiCatalogDisabledValidationError, ) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.daos.dataset import DatasetDAO from superset.datasets.schemas import FolderSchema from superset.exceptions import SupersetSecurityException +from superset.models.core import Database from superset.sql_parse import Table from superset.utils.decorators import on_error, transaction @@ -86,38 +90,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): if not self._model: raise DatasetNotFoundError() - # Check ownership + # Check permission to update the dataset try: security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex - database_id = self._properties.get("database") - - catalog = self._properties.get("catalog") - if not catalog: - catalog = self._properties["catalog"] = ( - self._model.database.get_default_catalog() - ) - - table = Table( - self._properties.get("table_name"), # type: ignore - self._properties.get("schema"), - catalog, - ) - - # Validate uniqueness - if not DatasetDAO.validate_update_uniqueness( - self._model.database, - table, - self._model_id, - ): - exceptions.append(DatasetExistsValidationError(table)) - - # Validate/Populate database not allowed to change - if database_id and database_id != self._model: - exceptions.append(DatabaseChangeValidationError()) - # Validate/Populate owner try: owners = self.compute_owners( @@ -128,15 +106,68 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): except ValidationError as ex: exceptions.append(ex) + self._validate_dataset_source(exceptions) self._validate_semantics(exceptions) if exceptions: raise DatasetInvalidError(exceptions=exceptions) - def _validate_semantics(self, exceptions: list[ValidationError]) -> None: + def _validate_dataset_source(self, exceptions: list[ValidationError]) -> None: # we know we have a valid model self._model = cast(SqlaTable, self._model) + database_id = self._properties.pop("database_id", None) + catalog = self._properties.get("catalog") + new_db_connection: Database | None = None + + if database_id and database_id != self._model.database.id: + if new_db_connection := DatasetDAO.get_database_by_id(database_id): + self._properties["database"] = new_db_connection + else: + exceptions.append(DatabaseNotFoundValidationError()) + db = new_db_connection or self._model.database + default_catalog = db.get_default_catalog() + + # If multi-catalog is disabled, and catalog provided is not + # the default one, fail + if ( + "catalog" in self._properties + and catalog != default_catalog + and not db.allow_multi_catalog + ): + exceptions.append(MultiCatalogDisabledValidationError()) + + # If the DB connection does not support multi-catalog, + # use the default catalog + elif not db.allow_multi_catalog: + catalog = self._properties["catalog"] = default_catalog + + # Fallback to using the previous value if not provided + elif "catalog" not in self._properties: + catalog = self._model.catalog + + schema = ( + self._properties["schema"] + if "schema" in self._properties + else self._model.schema + ) + table = Table( + self._properties.get("table_name", self._model.table_name), + schema, + catalog, + ) + + # Validate uniqueness + if not DatasetDAO.validate_update_uniqueness( + db, + table, + self._model_id, + ): + exceptions.append(DatasetExistsValidationError(table)) + + def _validate_semantics(self, exceptions: list[ValidationError]) -> None: + # we know we have a valid model + self._model = cast(SqlaTable, self._model) if columns := self._properties.get("columns"): self._validate_columns(columns, exceptions) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 9990110befefff8ae40ef648081d4d5ff7a2c1f7..8ed7ec216ba03cdb5a45dbef40607b0b0e39e0ec 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Unit tests for Superset""" +from __future__ import annotations import unittest from io import BytesIO -from typing import Optional from unittest.mock import ANY, patch from zipfile import is_zipfile, ZipFile @@ -70,14 +69,26 @@ from tests.integration_tests.fixtures.importexport import ( class TestDatasetApi(SupersetTestCase): fixture_tables_names = ("ab_permission", "ab_permission_view", "ab_view_menu") fixture_virtual_table_names = ("sql_virtual_dataset_1", "sql_virtual_dataset_2") + items_to_delete: list[SqlaTable | Database | TableColumn] = [] + + def setUp(self): + self.items_to_delete = [] + + def tearDown(self): + for item in self.items_to_delete: + db.session.delete(item) + db.session.commit() + super().tearDown() @staticmethod def insert_dataset( table_name: str, owners: list[int], database: Database, - sql: Optional[str] = None, - schema: Optional[str] = None, + sql: str | None = None, + schema: str | None = None, + catalog: str | None = None, + fetch_metadata: bool = True, ) -> SqlaTable: obj_owners = list() # noqa: C408 for owner in owners: @@ -89,10 +100,12 @@ class TestDatasetApi(SupersetTestCase): owners=obj_owners, database=database, sql=sql, + catalog=catalog, ) db.session.add(table) db.session.commit() - table.fetch_metadata() + if fetch_metadata: + table.fetch_metadata() return table def insert_default_dataset(self): @@ -100,6 +113,16 @@ class TestDatasetApi(SupersetTestCase): "ab_permission", [self.get_user("admin").id], get_main_database() ) + def insert_database(self, name: str, allow_multi_catalog: bool = False) -> Database: + db_connection = Database( + database_name=name, + sqlalchemy_uri=get_example_database().sqlalchemy_uri, + extra=('{"allow_multi_catalog": true}' if allow_multi_catalog else "{}"), + ) + db.session.add(db_connection) + db.session.commit() + return db_connection + def get_fixture_datasets(self) -> list[SqlaTable]: return ( db.session.query(SqlaTable) @@ -315,8 +338,7 @@ class TestDatasetApi(SupersetTestCase): # revert gamma permission gamma_role.permissions.remove(main_db_pvm) - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_get_dataset_related_database_gamma(self): """ @@ -480,8 +502,7 @@ class TestDatasetApi(SupersetTestCase): ], } - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_get_dataset_render_jinja_exceptions(self): """ @@ -547,8 +568,7 @@ class TestDatasetApi(SupersetTestCase): == "Unable to render expression from dataset calculated column." ) - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_get_dataset_distinct_schema(self): """ @@ -618,9 +638,7 @@ class TestDatasetApi(SupersetTestCase): }, ) - for dataset in datasets: - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = datasets def test_get_dataset_distinct_not_allowed(self): """ @@ -647,8 +665,7 @@ class TestDatasetApi(SupersetTestCase): assert response["count"] == 0 assert response["result"] == [] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_get_dataset_info(self): """ @@ -722,8 +739,7 @@ class TestDatasetApi(SupersetTestCase): ) assert columns[0].expression == "COUNT(*)" - db.session.delete(model) - db.session.commit() + self.items_to_delete = [model] def test_create_dataset_item_normalize(self): """ @@ -749,8 +765,7 @@ class TestDatasetApi(SupersetTestCase): assert model.database_id == table_data["database"] assert model.normalize_columns is True - db.session.delete(model) - db.session.commit() + self.items_to_delete = [model] def test_create_dataset_item_gamma(self): """ @@ -791,8 +806,7 @@ class TestDatasetApi(SupersetTestCase): model = db.session.query(SqlaTable).get(data.get("id")) assert admin in model.owners assert alpha in model.owners - db.session.delete(model) - db.session.commit() + self.items_to_delete = [model] def test_create_dataset_item_owners_invalid(self): """ @@ -839,8 +853,7 @@ class TestDatasetApi(SupersetTestCase): model = db.session.query(SqlaTable).get(data.get("id")) assert admin in model.owners assert alpha in model.owners - db.session.delete(model) - db.session.commit() + self.items_to_delete = [model] @unittest.skip("test is failing stochastically") def test_create_dataset_same_name_different_schema(self): @@ -991,8 +1004,7 @@ class TestDatasetApi(SupersetTestCase): model = db.session.query(SqlaTable).get(dataset.id) assert model.owners == current_owners - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_clear_owner_list(self): """ @@ -1008,8 +1020,7 @@ class TestDatasetApi(SupersetTestCase): model = db.session.query(SqlaTable).get(dataset.id) assert model.owners == [] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_populate_owner(self): """ @@ -1026,8 +1037,7 @@ class TestDatasetApi(SupersetTestCase): model = db.session.query(SqlaTable).get(dataset.id) assert model.owners == [gamma] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item(self): """ @@ -1045,8 +1055,7 @@ class TestDatasetApi(SupersetTestCase): assert model.description == dataset_data["description"] assert model.owners == current_owners - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item_w_override_columns(self): """ @@ -1082,8 +1091,7 @@ class TestDatasetApi(SupersetTestCase): col.advanced_data_type for col in columns ] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item_w_override_columns_same_columns(self): """ @@ -1130,8 +1138,7 @@ class TestDatasetApi(SupersetTestCase): columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all() assert len(columns) != prev_col_len assert len(columns) == 3 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_create_column_and_metric(self): """ @@ -1226,8 +1233,7 @@ class TestDatasetApi(SupersetTestCase): assert metrics[1].warning_text == new_metric_data["warning_text"] assert str(metrics[1].uuid) == new_metric_data["uuid"] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_delete_column(self): """ @@ -1276,8 +1282,7 @@ class TestDatasetApi(SupersetTestCase): assert columns[1].column_name == "name" assert len(columns) == 2 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_update_column(self): """ @@ -1313,8 +1318,7 @@ class TestDatasetApi(SupersetTestCase): assert columns[0].groupby is False assert columns[0].filterable is False - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_delete_metric(self): """ @@ -1357,8 +1361,7 @@ class TestDatasetApi(SupersetTestCase): metrics = metrics_query.all() assert len(metrics) == 1 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_update_column_uniqueness(self): """ @@ -1378,8 +1381,7 @@ class TestDatasetApi(SupersetTestCase): "message": {"columns": ["One or more columns already exist"]} } assert data == expected_result - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_update_metric_uniqueness(self): """ @@ -1399,8 +1401,7 @@ class TestDatasetApi(SupersetTestCase): "message": {"metrics": ["One or more metrics already exist"]} } assert data == expected_result - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_update_column_duplicate(self): """ @@ -1425,8 +1426,7 @@ class TestDatasetApi(SupersetTestCase): "message": {"columns": ["One or more columns are duplicated"]} } assert data == expected_result - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_update_metric_duplicate(self): """ @@ -1451,8 +1451,7 @@ class TestDatasetApi(SupersetTestCase): "message": {"metrics": ["One or more metrics are duplicated"]} } assert data == expected_result - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item_gamma(self): """ @@ -1465,8 +1464,7 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.client.put(uri, json=table_data) assert rv.status_code == 403 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_dataset_get_list_no_username(self): """ @@ -1491,8 +1489,7 @@ class TestDatasetApi(SupersetTestCase): assert current_dataset["description"] == "changed_description" assert "username" not in current_dataset["changed_by"].keys() - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_dataset_get_no_username(self): """ @@ -1512,8 +1509,7 @@ class TestDatasetApi(SupersetTestCase): assert res["description"] == "changed_description" assert "username" not in res["changed_by"].keys() - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item_not_owned(self): """ @@ -1526,8 +1522,7 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.put_assert_metric(uri, table_data, "put") assert rv.status_code == 403 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_update_dataset_item_owners_invalid(self): """ @@ -1540,8 +1535,7 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.put_assert_metric(uri, table_data, "put") assert rv.status_code == 422 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @patch("superset.daos.dataset.DatasetDAO.update") def test_update_dataset_sqlalchemy_error(self, mock_dao_update): @@ -1560,8 +1554,7 @@ class TestDatasetApi(SupersetTestCase): assert rv.status_code == 422 assert data == {"message": "Dataset could not be updated."} - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @with_feature_flags(DATASET_FOLDERS=True) def test_update_dataset_add_folders(self): @@ -1607,7 +1600,6 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.put_assert_metric(uri, dataset_data, "put") - print(rv.data.decode("utf-8")) assert rv.status_code == 200 model = db.session.query(SqlaTable).get(dataset.id) @@ -1643,8 +1635,229 @@ class TestDatasetApi(SupersetTestCase): }, ] - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] + + def test_update_dataset_change_db_connection_multi_catalog_disabled(self): + """ + Dataset API: Test changing the DB connection powering the dataset + to a connection with multi-catalog disabled. + """ + self.login(ADMIN_USERNAME) + + db_connection = self.insert_database("db_connection") + new_db_connection = self.insert_database("new_db_connection") + dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="test_schema", + catalog="old_default_catalog", + fetch_metadata=False, + ) + + with patch.object( + new_db_connection, "get_default_catalog", return_value="new_default_catalog" + ): + payload = {"database_id": new_db_connection.id} + uri = f"api/v1/dataset/{dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + assert rv.status_code == 200 + + model = db.session.query(SqlaTable).get(dataset.id) + assert model.database == new_db_connection + # Catalog should have been updated to new connection's default catalog + assert model.catalog == "new_default_catalog" + + self.items_to_delete = [dataset, db_connection, new_db_connection] + + def test_update_dataset_change_db_connection_multi_catalog_enabled(self): + """ + Dataset API: Test changing the DB connection powering the dataset + to a connection with multi-catalog enabled. + """ + self.login(ADMIN_USERNAME) + + db_connection = self.insert_database("db_connection") + new_db_connection = self.insert_database( + "new_db_connection", allow_multi_catalog=True + ) + dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="test_schema", + catalog="old_default_catalog", + fetch_metadata=False, + ) + + with patch.object( + new_db_connection, "get_default_catalog", return_value="default" + ): + payload = {"database_id": new_db_connection.id} + uri = f"api/v1/dataset/{dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + assert rv.status_code == 200 + + model = db.session.query(SqlaTable).get(dataset.id) + assert model.database == new_db_connection + # Catalog was not changed as not provided and multi-catalog is enabled + assert model.catalog == "old_default_catalog" + + self.items_to_delete = [dataset, db_connection, new_db_connection] + + def test_update_dataset_change_db_connection_not_found(self): + """ + Dataset API: Test changing the DB connection powering the dataset + to an invalid DB connection. + """ + self.login(ADMIN_USERNAME) + + dataset = self.insert_default_dataset() + + payload = {"database_id": 1500} + uri = f"api/v1/dataset/{dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 422 + assert response["message"] == {"database": ["Database does not exist"]} + + self.items_to_delete = [dataset] + + def test_update_dataset_change_catalog(self): + """ + Dataset API: Test changing the catalog associated with the dataset. + """ + self.login(ADMIN_USERNAME) + + db_connection = self.insert_database("db_connection", allow_multi_catalog=True) + dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="test_schema", + catalog="test_catalog", + fetch_metadata=False, + ) + + with patch.object(db_connection, "get_default_catalog", return_value="default"): + payload = {"catalog": "other_catalog"} + uri = f"api/v1/dataset/{dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + assert rv.status_code == 200 + + model = db.session.query(SqlaTable).get(dataset.id) + assert model.catalog == "other_catalog" + + self.items_to_delete = [dataset, db_connection] + + def test_update_dataset_change_catalog_not_allowed(self): + """ + Dataset API: Test changing the catalog associated with the dataset fails + when multi-catalog is disabled on the DB connection. + """ + self.login(ADMIN_USERNAME) + + db_connection = self.insert_database("db_connection") + dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="test_schema", + catalog="test_catalog", + fetch_metadata=False, + ) + + with patch.object(db_connection, "get_default_catalog", return_value="default"): + payload = {"catalog": "other_catalog"} + uri = f"api/v1/dataset/{dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 422 + assert response["message"] == { + "catalog": ["Only the default catalog is supported for this connection"] + } + + self.items_to_delete = [dataset, db_connection] + + def test_update_dataset_validate_uniqueness(self): + """ + Dataset API: Test the dataset uniqueness validation takes into + consideration the new database connection. + """ + test_db = get_main_database() + if test_db.backend == "sqlite": + # Skip this test for SQLite as it doesn't support multiple + # schemas. + return + + self.login(ADMIN_USERNAME) + + db_connection = self.insert_database("db_connection") + new_db_connection = self.insert_database("new_db_connection") + first_schema_dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="first_schema", + fetch_metadata=False, + ) + second_schema_dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=db_connection, + sql="select 1 as one", + schema="second_schema", + fetch_metadata=False, + ) + new_db_conn_dataset = self.insert_dataset( + table_name="test_dataset", + owners=[], + database=new_db_connection, + sql="select 1 as one", + schema="first_schema", + fetch_metadata=False, + ) + + with patch.object( + db_connection, + "get_default_catalog", + return_value=None, + ): + payload = {"schema": "second_schema"} + uri = f"api/v1/dataset/{first_schema_dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 422 + assert response["message"] == { + "table": ["Dataset second_schema.test_dataset already exists"] + } + + with patch.object( + new_db_connection, + "get_default_catalog", + return_value=None, + ): + payload["database_id"] = new_db_connection.id + uri = f"api/v1/dataset/{first_schema_dataset.id}" + rv = self.put_assert_metric(uri, payload, "put") + assert rv.status_code == 200 + + model = db.session.query(SqlaTable).get(first_schema_dataset.id) + assert model.database == new_db_connection + assert model.schema == "second_schema" + + self.items_to_delete = [ + first_schema_dataset, + second_schema_dataset, + new_db_conn_dataset, + new_db_connection, + db_connection, + ] def test_delete_dataset_item(self): """ @@ -1674,8 +1887,7 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.delete_assert_metric(uri, "delete") assert rv.status_code == 403 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_delete_dataset_item_not_authorized(self): """ @@ -1687,8 +1899,7 @@ class TestDatasetApi(SupersetTestCase): uri = f"api/v1/dataset/{dataset.id}" rv = self.client.delete(uri) assert rv.status_code == 403 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @patch("superset.daos.dataset.DatasetDAO.delete") def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete): @@ -1705,8 +1916,7 @@ class TestDatasetApi(SupersetTestCase): data = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 422 assert data == {"message": "Datasets could not be deleted."} - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @pytest.mark.usefixtures("create_datasets") def test_delete_dataset_column(self): @@ -1947,8 +2157,7 @@ class TestDatasetApi(SupersetTestCase): .filter_by(table_id=dataset.id, column_name="id") .one() ) - db.session.delete(id_column) - db.session.commit() + self.items_to_delete = [id_column] self.login(ADMIN_USERNAME) uri = f"api/v1/dataset/{dataset.id}/refresh" @@ -1961,8 +2170,7 @@ class TestDatasetApi(SupersetTestCase): .one() ) assert id_column is not None - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] def test_dataset_item_refresh_not_found(self): """ @@ -1987,8 +2195,7 @@ class TestDatasetApi(SupersetTestCase): rv = self.put_assert_metric(uri, {}, "refresh") assert rv.status_code == 403 - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @unittest.skip("test is failing stochastically") def test_export_dataset(self): @@ -2250,8 +2457,7 @@ class TestDatasetApi(SupersetTestCase): dataset = ( db.session.query(SqlaTable).filter_by(table_name="birth_names_2").one() ) - db.session.delete(dataset) - db.session.commit() + self.items_to_delete = [dataset] @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_dataset_overwrite(self, mock_add_permissions): @@ -2447,8 +2653,7 @@ class TestDatasetApi(SupersetTestCase): response = json.loads(rv.data.decode("utf-8")) assert response.get("count") == 1 - db.session.delete(table_w_certification) - db.session.commit() + self.items_to_delete = [table_w_certification] @pytest.mark.usefixtures("create_virtual_datasets") def test_duplicate_virtual_dataset(self): @@ -2473,8 +2678,7 @@ class TestDatasetApi(SupersetTestCase): assert len(new_dataset.columns) == 2 assert new_dataset.columns[0].column_name == "id" assert new_dataset.columns[1].column_name == "name" - db.session.delete(new_dataset) - db.session.commit() + self.items_to_delete = [new_dataset] @pytest.mark.usefixtures("create_datasets") def test_duplicate_physical_dataset(self): @@ -2604,8 +2808,7 @@ class TestDatasetApi(SupersetTestCase): assert table.template_params == '{"param": 1}' assert table.normalize_columns is False - db.session.delete(table) - db.session.commit() + self.items_to_delete = [table] with examples_db.get_sqla_engine() as engine: engine.execute("DROP TABLE test_create_sqla_table_api") diff --git a/tests/unit_tests/commands/dataset/test_update.py b/tests/unit_tests/commands/dataset/update_test.py similarity index 72% rename from tests/unit_tests/commands/dataset/test_update.py rename to tests/unit_tests/commands/dataset/update_test.py index fa2026b533b19fd17e99d4d93d5e0df95d9532d4..45a1a4160d80e64ecc3aa8db28929894c6549654 100644 --- a/tests/unit_tests/commands/dataset/test_update.py +++ b/tests/unit_tests/commands/dataset/update_test.py @@ -15,78 +15,125 @@ # specific language governing permissions and limitations # under the License. -from typing import cast -from unittest.mock import MagicMock +from typing import Any, cast import pytest from marshmallow import ValidationError from pytest_mock import MockerFixture -from superset import db -from superset.commands.dataset.exceptions import DatasetInvalidError +from superset.commands.dataset.exceptions import ( + DatabaseNotFoundValidationError, + DatasetExistsValidationError, + DatasetForbiddenError, + DatasetInvalidError, + DatasetNotFoundError, + MultiCatalogDisabledValidationError, +) from superset.commands.dataset.update import UpdateDatasetCommand, validate_folders -from superset.connectors.sqla.models import SqlaTable +from superset.commands.exceptions import OwnersNotFoundValidationError from superset.datasets.schemas import FolderSchema -from superset.models.core import Database +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetSecurityException from tests.unit_tests.conftest import with_feature_flags -@pytest.mark.usefixture("session") -def test_update_uniqueness_error(mocker: MockerFixture) -> None: +def test_update_dataset_not_found(mocker: MockerFixture) -> None: """ - Test uniqueness validation in dataset update command. + Test updating an unexisting ID raises a `DatasetNotFoundError`. """ - SqlaTable.metadata.create_all(db.session.get_bind()) - - # First, make sure session is clean - db.session.rollback() - - try: - # Set up test data - database = Database(database_name="my_db", sqlalchemy_uri="sqlite://") - bar = SqlaTable(table_name="bar", schema="foo", database=database) - baz = SqlaTable(table_name="baz", schema="qux", database=database) - db.session.add_all([database, bar, baz]) - db.session.commit() - - # Set up mocks - mock_g = mocker.patch("superset.security.manager.g") - mock_g.user = MagicMock() - mocker.patch( - "superset.views.base.security_manager.can_access_all_datasources", - return_value=True, - ) - mocker.patch( - "superset.commands.dataset.update.security_manager.raise_for_ownership", - return_value=None, - ) - mocker.patch.object(UpdateDatasetCommand, "compute_owners", return_value=[]) - - # Run the test that should fail - with pytest.raises(DatasetInvalidError): - UpdateDatasetCommand( - bar.id, - { - "table_name": "baz", - "schema": "qux", - }, - ).run() - except Exception: - db.session.rollback() - raise - finally: - # Clean up - this will run even if the test fails - try: - db.session.query(SqlaTable).filter( - SqlaTable.table_name.in_(["bar", "baz"]), - SqlaTable.schema.in_(["foo", "qux"]), - ).delete(synchronize_session=False) - db.session.query(Database).filter(Database.database_name == "my_db").delete( - synchronize_session=False + mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO") + mock_dataset_dao.find_by_id.return_value = None + + with pytest.raises(DatasetNotFoundError): + UpdateDatasetCommand(1, {"name": "test"}).run() + + +def test_update_dataset_forbidden(mocker: MockerFixture) -> None: + """ + Test try updating a dataset without permission raises a `DatasetForbiddenError`. + """ + mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO") + mock_dataset_dao.find_by_id.return_value = mocker.MagicMock() + + mocker.patch( + "superset.commands.dataset.update.security_manager.raise_for_ownership", + side_effect=SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR, + message="Sample message", + level=ErrorLevel.ERROR, ) - db.session.commit() - except Exception: - db.session.rollback() + ), + ) + + with pytest.raises(DatasetForbiddenError): + UpdateDatasetCommand(1, {"name": "test"}).run() + + +@pytest.mark.parametrize( + ("payload, exception, error_msg"), + [ + ( + {"database_id": 2}, + DatabaseNotFoundValidationError, + "Database does not exist", + ), + ( + {"catalog": "test"}, + MultiCatalogDisabledValidationError, + "Only the default catalog is supported for this connection", + ), + ( + {"table_name": "table", "schema": "schema"}, + DatasetExistsValidationError, + "Dataset catalog.schema.table already exists", + ), + ( + {"owners": [1]}, + OwnersNotFoundValidationError, + "Owners are invalid", + ), + ], +) +def test_update_validation_errors( + payload: dict[str, Any], + exception: Exception, + error_msg: str, + mocker: MockerFixture, +) -> None: + """ + Test validation errors for the `UpdateDatasetCommand`. + """ + mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO") + mocker.patch( + "superset.commands.dataset.update.security_manager.raise_for_ownership", + ) + mocker.patch("superset.commands.utils.security_manager.is_admin", return_value=True) + mocker.patch( + "superset.commands.utils.security_manager.get_user_by_id", return_value=None + ) + mock_database = mocker.MagicMock() + mock_database.id = 1 + mock_database.get_default_catalog.return_value = "catalog" + mock_database.allow_multi_catalog = False + mock_dataset = mocker.MagicMock() + mock_dataset.database = mock_database + mock_dataset.catalog = "catalog" + mock_dataset_dao.find_by_id.return_value = mock_dataset + + if exception == DatabaseNotFoundValidationError: + mock_dataset_dao.get_database_by_id.return_value = None + else: + mock_dataset_dao.get_database_by_id.return_value = mock_database + + if exception == DatasetExistsValidationError: + mock_dataset_dao.validate_update_uniqueness.return_value = False + else: + mock_dataset_dao.validate_update_uniqueness.return_value = True + + with pytest.raises(DatasetInvalidError) as excinfo: + UpdateDatasetCommand(1, payload).run() + assert any(error_msg in str(exc) for exc in excinfo.value._exceptions) @with_feature_flags(DATASET_FOLDERS=True)