Skip to content
Snippets Groups Projects
Unverified Commit fd947a09 authored by Asher  Manangan's avatar Asher Manangan Committed by GitHub
Browse files

feat(tags): Export and Import Functionality for Superset Dashboards and Charts (#30833)

parent e1383d38
Branches
No related tags found
No related merge requests found
Showing
with 512 additions and 18 deletions
......@@ -62,7 +62,6 @@ MAPBOX_API_KEY=''
# Make sure you set this to a unique secure random value on production
SUPERSET_SECRET_KEY=TEST_NON_DEV_SECRET
ENABLE_PLAYWRIGHT=false
PUPPETEER_SKIP_CHROMIUM_DOWNLOAD=true
BUILD_SUPERSET_FRONTEND_IN_DOCKER=true
......
......@@ -1564,6 +1564,7 @@ class ImportV1ChartSchema(Schema):
dataset_uuid = fields.UUID(required=True)
is_managed_externally = fields.Boolean(allow_none=True, dump_default=False)
external_url = fields.String(allow_none=True)
tags = fields.List(fields.String(), allow_none=True)
class ChartCacheWarmUpRequestSchema(Schema):
......
......@@ -26,10 +26,13 @@ from superset.commands.chart.exceptions import ChartNotFoundError
from superset.daos.chart import ChartDAO
from superset.commands.dataset.export import ExportDatasetsCommand
from superset.commands.export.models import ExportModelsCommand
from superset.commands.tag.export import ExportTagsCommand
from superset.models.slice import Slice
from superset.tags.models import TagType
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
from superset.utils import json
from superset.extensions import feature_flag_manager
logger = logging.getLogger(__name__)
......@@ -71,9 +74,23 @@ class ExportChartsCommand(ExportModelsCommand):
if model.table:
payload["dataset_uuid"] = str(model.table.uuid)
# Fetch tags from the database if TAGGING_SYSTEM is enabled
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
tags = getattr(model, "tags", [])
payload["tags"] = [tag.name for tag in tags if tag.type == TagType.custom]
file_content = yaml.safe_dump(payload, sort_keys=False)
return file_content
_include_tags: bool = True # Default to True
@classmethod
def disable_tag_export(cls) -> None:
cls._include_tags = False
@classmethod
def enable_tag_export(cls) -> None:
cls._include_tags = True
@staticmethod
def _export(
model: Slice, export_related: bool = True
......@@ -85,3 +102,12 @@ class ExportChartsCommand(ExportModelsCommand):
if model.table and export_related:
yield from ExportDatasetsCommand([model.table.id]).run()
# Check if the calling class is ExportDashboardCommands
if (
export_related
and ExportChartsCommand._include_tags
and feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM")
):
chart_id = model.id
yield from ExportTagsCommand().export(chart_ids=[chart_id])
......@@ -14,23 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session # noqa: F401
from superset import db
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.chart.exceptions import ChartImportError
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.database.importers.v1.utils import import_database
from superset.commands.dataset.importers.v1.utils import import_dataset
from superset.commands.importers.v1 import ImportModelsCommand
from superset.commands.importers.v1.utils import import_tag
from superset.commands.utils import update_chart_config_dataset
from superset.connectors.sqla.models import SqlaTable
from superset.daos.chart import ChartDAO
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.extensions import feature_flag_manager
class ImportChartsCommand(ImportModelsCommand):
......@@ -47,7 +51,13 @@ class ImportChartsCommand(ImportModelsCommand):
import_error = ChartImportError
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # noqa: C901
# ruff: noqa: C901
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: dict[str, Any] | None = None,
) -> None:
contents = {} if contents is None else contents
# discover datasets associated with charts
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
......@@ -93,4 +103,12 @@ class ImportChartsCommand(ImportModelsCommand):
"datasource_name": dataset.table_name,
}
config = update_chart_config_dataset(config, dataset_dict)
import_chart(config, overwrite=overwrite)
chart = import_chart(config, overwrite=overwrite)
# Handle tags using import_tag function
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
if "tags" in config:
target_tag_names = config["tags"]
import_tag(
target_tag_names, contents, chart.id, "chart", db.session
)
......@@ -25,6 +25,7 @@ from collections.abc import Iterator
import yaml
from superset.commands.chart.export import ExportChartsCommand
from superset.commands.tag.export import ExportTagsCommand
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.commands.dashboard.importers.v1.utils import find_chart_uuids
from superset.daos.dashboard import DashboardDAO
......@@ -33,9 +34,11 @@ from superset.commands.dataset.export import ExportDatasetsCommand
from superset.daos.dataset import DatasetDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.models import TagType
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
from superset.utils import json
from superset.extensions import feature_flag_manager # Import the feature flag manager
logger = logging.getLogger(__name__)
......@@ -112,6 +115,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
return f"dashboards/{file_name}.yaml"
@staticmethod
# ruff: noqa: C901
def _file_content(model: Dashboard) -> str:
payload = model.export_to_dict(
recursive=False,
......@@ -159,10 +163,16 @@ class ExportDashboardsCommand(ExportModelsCommand):
payload["version"] = EXPORT_VERSION
# Check if the TAGGING_SYSTEM feature is enabled
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
tags = model.tags if hasattr(model, "tags") else []
payload["tags"] = [tag.name for tag in tags if tag.type == TagType.custom]
file_content = yaml.safe_dump(payload, sort_keys=False)
return file_content
@staticmethod
# ruff: noqa: C901
def _export(
model: Dashboard, export_related: bool = True
) -> Iterator[tuple[str, Callable[[], str]]]:
......@@ -173,7 +183,15 @@ class ExportDashboardsCommand(ExportModelsCommand):
if export_related:
chart_ids = [chart.id for chart in model.slices]
yield from ExportChartsCommand(chart_ids).run()
dashboard_ids = model.id
command = ExportChartsCommand(chart_ids)
command.disable_tag_export()
yield from command.run()
command.enable_tag_export()
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
yield from ExportTagsCommand.export(
dashboard_ids=dashboard_ids, chart_ids=chart_ids
)
payload = model.export_to_dict(
recursive=False,
......
......@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any
from marshmallow import Schema
......@@ -34,11 +36,13 @@ from superset.commands.dashboard.importers.v1.utils import (
from superset.commands.database.importers.v1.utils import import_database
from superset.commands.dataset.importers.v1.utils import import_dataset
from superset.commands.importers.v1 import ImportModelsCommand
from superset.commands.importers.v1.utils import import_tag
from superset.commands.utils import update_chart_config_dataset
from superset.daos.dashboard import DashboardDAO
from superset.dashboards.schemas import ImportV1DashboardSchema
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.extensions import feature_flag_manager
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.dashboard import Dashboard, dashboard_slices
......@@ -58,9 +62,15 @@ class ImportDashboardsCommand(ImportModelsCommand):
import_error = DashboardImportError
# TODO (betodealmeida): refactor to use code from other commands
# pylint: disable=too-many-branches, too-many-locals
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # noqa: C901
# ruff: noqa: C901
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: dict[str, Any] | None = None,
) -> None:
contents = {} if contents is None else contents
# discover charts and datasets associated with dashboards
chart_uuids: set[str] = set()
dataset_uuids: set[str] = set()
......@@ -120,6 +130,14 @@ class ImportDashboardsCommand(ImportModelsCommand):
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id
# Handle tags using import_tag function
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
if "tags" in config:
target_tag_names = config["tags"]
import_tag(
target_tag_names, contents, chart.id, "chart", db.session
)
# store the existing relationship between dashboards and charts
existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
......@@ -140,6 +158,18 @@ class ImportDashboardsCommand(ImportModelsCommand):
if (dashboard.id, chart_id) not in existing_relationships:
dashboard_chart_ids.append((dashboard.id, chart_id))
# Handle tags using import_tag function
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
if "tags" in config:
target_tag_names = config["tags"]
import_tag(
target_tag_names,
contents,
dashboard.id,
"dashboard",
db.session,
)
# set ref in the dashboard_slices table
values = [
{"dashboard_id": dashboard_id, "slice_id": chart_id}
......
......@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any
from marshmallow import Schema
......@@ -42,7 +44,11 @@ class ImportDatabasesCommand(ImportModelsCommand):
import_error = DatabaseImportError
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: dict[str, Any] | None = None,
) -> None:
# first import databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any
from typing import Any, Optional
from marshmallow import Schema
from sqlalchemy.orm import Session # noqa: F401
......@@ -42,7 +42,13 @@ class ImportDatasetsCommand(ImportModelsCommand):
import_error = DatasetImportError
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: Optional[dict[str, Any]] = None,
) -> None:
if contents is None:
contents = {}
# discover databases associated with datasets
database_uuids: set[str] = set()
for file_name, config in configs.items():
......
......@@ -53,6 +53,7 @@ class ExportAssetsCommand(BaseCommand):
ExportDashboardsCommand,
ExportSavedQueriesCommand,
]
for command in commands:
ids = [model.id for model in command.dao.find_all()]
for file_name, file_content in command(ids, export_related=False).run():
......
......@@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, Optional
from typing import Any
from marshmallow import Schema, validate # noqa: F401
from marshmallow.exceptions import ValidationError
......@@ -64,7 +67,12 @@ class ImportModelsCommand(BaseCommand):
self._configs: dict[str, Any] = {}
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# ruff: noqa: C901
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: dict[str, Any] | None = None,
) -> None:
raise NotImplementedError("Subclasses MUST implement _import")
@classmethod
......@@ -76,7 +84,7 @@ class ImportModelsCommand(BaseCommand):
self.validate()
try:
self._import(self._configs, self.overwrite)
self._import(self._configs, self.overwrite, self.contents)
except CommandException:
raise
except Exception as ex:
......@@ -87,7 +95,7 @@ class ImportModelsCommand(BaseCommand):
# verify that the metadata file is present and valid
try:
metadata: Optional[dict[str, str]] = load_metadata(self.contents)
metadata: dict[str, str] | None = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from typing import Any, Optional
from marshmallow import Schema
from sqlalchemy.exc import MultipleResultsFound
......@@ -90,6 +90,7 @@ class ImportExamplesCommand(ImportModelsCommand):
def _import( # pylint: disable=too-many-locals, too-many-branches # noqa: C901
configs: dict[str, Any],
overwrite: bool = False,
contents: Optional[dict[str, Any]] = None,
force_data: bool = False,
) -> None:
# import databases
......
......@@ -21,12 +21,17 @@ from zipfile import ZipFile
import yaml
from marshmallow import fields, Schema, validate
from marshmallow.exceptions import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from superset import db
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import feature_flag_manager
from superset.models.core import Database
from superset.tags.models import Tag, TaggedObject
from superset.utils.core import check_is_safe_zip
from superset.utils.decorators import transaction
METADATA_FILE_NAME = "metadata.yaml"
IMPORT_VERSION = "1.0.0"
......@@ -96,7 +101,8 @@ def validate_metadata_type(
# pylint: disable=too-many-locals,too-many-arguments
def load_configs( # noqa: C901
# ruff: noqa: C901
def load_configs(
contents: dict[str, str],
schemas: dict[str, Schema],
passwords: dict[str, str],
......@@ -216,6 +222,91 @@ def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
}
# pylint: disable=consider-using-transaction
# ruff: noqa: C901
@transaction()
def import_tag(
target_tag_names: list[str],
contents: dict[str, Any],
object_id: int,
object_type: str,
db_session: Session,
) -> list[int]:
"""Handles the import logic for tags for charts and dashboards"""
if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
return []
tag_descriptions = {}
new_tag_ids = []
if "tags.yaml" in contents:
try:
tags_config = yaml.safe_load(contents["tags.yaml"])
except yaml.YAMLError as err:
logger.error("Error parsing tags.yaml: %s", err)
tags_config = {}
for tag_info in tags_config.get("tags", []):
tag_name = tag_info.get("tag_name")
description = tag_info.get("description", None)
if tag_name:
tag_descriptions[tag_name] = description
existing_assocs = (
db_session.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type)
.all()
)
existing_tags = {
tag.name: tag
for tag in db_session.query(Tag).filter(Tag.name.in_(target_tag_names))
}
for tag_name in target_tag_names:
try:
tag = existing_tags.get(tag_name)
# If tag does not exist, create it
if tag is None:
description = tag_descriptions.get(tag_name, None)
tag = Tag(name=tag_name, description=description, type="custom")
db_session.add(tag)
existing_tags[tag_name] = tag # Update the existing_tags dictionary
# Ensure the association with the object
tagged_object = (
db_session.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type, tag_id=tag.id)
.first()
)
if not tagged_object:
new_tagged_object = TaggedObject(
tag_id=tag.id, object_id=object_id, object_type=object_type
)
db_session.add(new_tagged_object)
new_tag_ids.append(tag.id)
except SQLAlchemyError as err:
logger.error(
"Error processing tag '%s' for %s ID %d: %s",
tag_name,
object_type,
object_id,
err,
)
continue # No need for manual rollback, handled by transaction decorator
# Remove old tags not in the new config
for tag in existing_assocs:
if tag.tag_id not in new_tag_ids:
db_session.delete(tag)
return new_tag_ids
def get_resource_mappings_batched(
model_class: Type[Any],
batch_size: int = 1000,
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any
from typing import Any, Optional
from marshmallow import Schema
from sqlalchemy.orm import Session # noqa: F401
......@@ -43,7 +43,11 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
import_error = SavedQueryImportError
@staticmethod
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
def _import(
configs: dict[str, Any],
overwrite: bool = False,
contents: Optional[dict[str, Any]] = None,
) -> None:
# discover databases associated with saved queries
database_uuids: set[str] = set()
for file_name, config in configs.items():
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from typing import Any, Callable, List, Optional, Union
from collections.abc import Iterator
import yaml
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.extensions import feature_flag_manager
from superset.tags.models import TagType
from superset.commands.tag.exceptions import TagNotFoundError
# pylint: disable=too-few-public-methods
class ExportTagsCommand:
not_found = TagNotFoundError
@staticmethod
def _file_name() -> str:
# Use the model to determine the filename
return "tags.yaml"
@staticmethod
def _merge_tags(
dashboard_tags: List[dict[str, Any]], chart_tags: List[dict[str, Any]]
) -> List[dict[str, Any]]:
# Create a dictionary to prevent duplicates based on tag name
tags_dict = {tag["tag_name"]: tag for tag in dashboard_tags}
# Add chart tags, preserving unique tag names
for tag in chart_tags:
if tag["tag_name"] not in tags_dict:
tags_dict[tag["tag_name"]] = tag
# Return merged tags as a list
return list(tags_dict.values())
@staticmethod
def _file_content(
dashboard_ids: Optional[Union[int, List[Union[int, str]]]] = None,
chart_ids: Optional[Union[int, List[Union[int, str]]]] = None,
) -> str:
payload: dict[str, list[dict[str, Any]]] = {"tags": []}
dashboard_tags = []
chart_tags = []
# Fetch dashboard tags if provided
if dashboard_ids:
# Ensure dashboard_ids is a list
if isinstance(dashboard_ids, int):
dashboard_ids = [
dashboard_ids
] # Convert single int to list for consistency
dashboards = [
dashboard
for dashboard in (
DashboardDAO.find_by_id(dashboard_id)
for dashboard_id in dashboard_ids
)
if dashboard is not None
]
for dashboard in dashboards:
tags = dashboard.tags if hasattr(dashboard, "tags") else []
filtered_tags = [
{"tag_name": tag.name, "description": tag.description}
for tag in tags
if tag.type == TagType.custom
]
dashboard_tags.extend(filtered_tags)
# Fetch chart tags if provided
if chart_ids:
# Ensure chart_ids is a list
if isinstance(chart_ids, int):
chart_ids = [chart_ids] # Convert single int to list for consistency
charts = [
chart
for chart in (ChartDAO.find_by_id(chart_id) for chart_id in chart_ids)
if chart is not None
]
for chart in charts:
tags = chart.tags if hasattr(chart, "tags") else []
filtered_tags = [
{"tag_name": tag.name, "description": tag.description}
for tag in tags
if "type:" not in tag.name and "owner:" not in tag.name
]
chart_tags.extend(filtered_tags)
# Merge the tags from both dashboards and charts
merged_tags = ExportTagsCommand._merge_tags(dashboard_tags, chart_tags)
payload["tags"].extend(merged_tags)
# Convert to YAML format
file_content = yaml.safe_dump(payload, sort_keys=False)
return file_content
@staticmethod
def export(
dashboard_ids: Optional[Union[int, List[Union[int, str]]]] = None,
chart_ids: Optional[Union[int, List[Union[int, str]]]] = None,
) -> Iterator[tuple[str, Callable[[], str]]]:
if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
return
yield (
ExportTagsCommand._file_name(),
lambda: ExportTagsCommand._file_content(dashboard_ids, chart_ids),
)
......@@ -484,6 +484,7 @@ class ImportV1DashboardSchema(Schema):
certified_by = fields.String(allow_none=True)
certification_details = fields.String(allow_none=True)
published = fields.Boolean(allow_none=True)
tags = fields.List(fields.String(), allow_none=True)
class EmbeddedDashboardConfigSchema(Schema):
......
......@@ -18,8 +18,10 @@
import copy
from collections.abc import Generator
from unittest.mock import patch
import pytest
import yaml
from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
......@@ -27,8 +29,11 @@ from sqlalchemy.orm.session import Session
from superset import security_manager
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.exceptions import ImportFailedError
from superset.commands.importers.v1.utils import import_tag
from superset.connectors.sqla.models import Database, SqlaTable
from superset.extensions import feature_flag_manager
from superset.models.slice import Slice
from superset.tags.models import TaggedObject
from superset.utils.core import override_user
from tests.integration_tests.fixtures.importexport import chart_config
......@@ -280,3 +285,43 @@ def test_import_existing_chart_with_permission(
# Assert that the can write to chart was checked
mock_can_access.assert_called_once_with("can_write", "Chart")
mock_can_access_chart.assert_called_once_with(slice)
def test_import_tag_logic_for_charts(session_with_schema: Session):
contents = {
"tags.yaml": yaml.dump(
{"tags": [{"tag_name": "tag_1", "description": "Description for tag_1"}]}
)
}
object_id = 1
object_type = "chart"
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True):
new_tag_ids = import_tag(
["tag_1"], contents, object_id, object_type, session_with_schema
)
assert len(new_tag_ids) > 0
assert (
session_with_schema.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type)
.count()
> 0
)
session_with_schema.query(TaggedObject).filter_by(
object_id=object_id, object_type=object_type
).delete()
session_with_schema.commit()
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False):
new_tag_ids_disabled = import_tag(
["tag_1"], contents, object_id, object_type, session_with_schema
)
assert len(new_tag_ids_disabled) == 0
associated_tags = (
session_with_schema.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type)
.all()
)
assert len(associated_tags) == 0
......@@ -16,9 +16,15 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel
from unittest.mock import patch
import pytest
import yaml
from freezegun import freeze_time
from pytest_mock import MockerFixture
from superset.extensions import feature_flag_manager
def test_export_assets_command(mocker: MockerFixture) -> None:
"""
......@@ -80,7 +86,6 @@ def test_export_assets_command(mocker: MockerFixture) -> None:
with freeze_time("2022-01-01T00:00:00Z"):
command = ExportAssetsCommand()
output = [(file[0], file[1]()) for file in list(command.run())]
assert output == [
(
"metadata.yaml",
......@@ -92,3 +97,61 @@ def test_export_assets_command(mocker: MockerFixture) -> None:
("dashboards/sales.yaml", "<DASHBOARD CONTENTS>"),
("queries/example/metric.yaml", "<SAVED QUERY CONTENTS>"),
]
@pytest.fixture
def mock_export_tags_command_charts_dashboards(mocker):
export_tags = mocker.patch("superset.commands.tag.export.ExportTagsCommand")
def _mock_export(dashboard_ids=None, chart_ids=None):
if not feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
return iter([])
return [
(
"tags.yaml",
lambda: yaml.dump(
{
"tags": [
{
"tag_name": "tag_1",
"description": "Description for tag_1",
}
]
},
sort_keys=False,
),
),
("charts/pie.yaml", lambda: "tag:\n- tag_1"),
]
export_tags.return_value._export.side_effect = _mock_export
return export_tags
def test_export_tags_with_charts_dashboards(
mock_export_tags_command_charts_dashboards, mocker
):
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True):
command = mock_export_tags_command_charts_dashboards()
result = list(command._export(chart_ids=[1]))
file_name, file_content_func = result[0]
file_content = file_content_func()
assert file_name == "tags.yaml"
payload = yaml.safe_load(file_content)
assert payload["tags"] == [
{"tag_name": "tag_1", "description": "Description for tag_1"}
]
file_name, file_content_func = result[1]
file_content = file_content_func()
assert file_name == "charts/pie.yaml"
assert file_content == "tag:\n- tag_1"
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False):
command = mock_export_tags_command_charts_dashboards()
result = list(command._export(chart_ids=[1]))
assert not any(file_name == "tags.yaml" for file_name, _ in result)
assert all(
file_content_func() != "tag:\n- tag_1" for _, file_content_func in result
)
......@@ -18,8 +18,10 @@
import copy
from collections.abc import Generator
from unittest.mock import patch
import pytest
import yaml
from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
......@@ -27,7 +29,10 @@ from sqlalchemy.orm.session import Session
from superset import security_manager
from superset.commands.dashboard.importers.v1.utils import import_dashboard
from superset.commands.exceptions import ImportFailedError
from superset.commands.importers.v1.utils import import_tag
from superset.extensions import feature_flag_manager
from superset.models.dashboard import Dashboard
from superset.tags.models import TaggedObject
from superset.utils.core import override_user
from tests.integration_tests.fixtures.importexport import dashboard_config
......@@ -238,3 +243,43 @@ def test_import_existing_dashboard_with_permission(
# Assert that the can write to dashboard was checked
mock_can_access.assert_called_once_with("can_write", "Dashboard")
mock_can_access_dashboard.assert_called_once_with(dashboard)
def test_import_tag_logic_for_dashboards(session_with_schema: Session):
contents = {
"tags.yaml": yaml.dump(
{"tags": [{"tag_name": "tag_1", "description": "Description for tag_1"}]}
)
}
object_id = 1
object_type = "dashboards"
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=True):
new_tag_ids = import_tag(
["tag_1"], contents, object_id, object_type, session_with_schema
)
assert len(new_tag_ids) > 0
assert (
session_with_schema.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type)
.count()
> 0
)
session_with_schema.query(TaggedObject).filter_by(
object_id=object_id, object_type=object_type
).delete()
session_with_schema.commit()
with patch.object(feature_flag_manager, "is_feature_enabled", return_value=False):
new_tag_ids_disabled = import_tag(
["tag_1"], contents, object_id, object_type, session_with_schema
)
assert len(new_tag_ids_disabled) == 0
associated_tags = (
session_with_schema.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type)
.all()
)
assert len(associated_tags) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment