test_databricks.py
    # Licensed to the Apache Software Foundation (ASF) under one
    # or more contributor license agreements.  See the NOTICE file
    # distributed with this work for additional information
    # regarding copyright ownership.  The ASF licenses this file
    # to you under the Apache License, Version 2.0 (the
    # "License"); you may not use this file except in compliance
    # with the License.  You may obtain a copy of the License at
    #
    #   http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing,
    # software distributed under the License is distributed on an
    # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    # KIND, either express or implied.  See the License for the
    # specific language governing permissions and limitations
    # under the License.
    # pylint: disable=unused-argument, import-outside-toplevel, protected-access
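    """
    Unit tests for ``DatabricksNativeEngineSpec``: URI parameter extraction and
    building, the parameters JSON schema, extra engine parameters, error
    extraction, datetime literal conversion, and prequeries.
    """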
    
    import json
    from datetime import datetime
    from typing import Optional
    
    import pytest
    from pytest_mock import MockerFixture
    
    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
    from tests.unit_tests.fixtures.common import dttm  # noqa: F401
    
    
    def test_get_parameters_from_uri() -> None:
        """
        Test that the result from ``get_parameters_from_uri`` is JSON serializable.
        """
        from superset.db_engine_specs.databricks import (
            DatabricksNativeEngineSpec,
            DatabricksNativeParametersType,
        )
    
        parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
            "databricks+connector://token:abc12345@my_hostname:1234/test"
        )
        assert parameters == DatabricksNativeParametersType(
            {
                "access_token": "abc12345",
                "host": "my_hostname",
                "port": 1234,
                "database": "test",
                "encryption": False,
            }
        )
        assert json.loads(json.dumps(parameters)) == parameters
    
    
    def test_build_sqlalchemy_uri() -> None:
        """
        Test that the parameters can be correctly compiled into a
        SQLAlchemy URI.
        """
        from superset.db_engine_specs.databricks import (
            DatabricksNativeEngineSpec,
            DatabricksNativeParametersType,
        )
    
        parameters = DatabricksNativeParametersType(
            {
                "access_token": "abc12345",
                "host": "my_hostname",
                "port": 1234,
                "database": "test",
                "encryption": False,
            }
        )
        encrypted_extra = None
        sqlalchemy_uri = DatabricksNativeEngineSpec.build_sqlalchemy_uri(
            parameters, encrypted_extra
        )
        assert sqlalchemy_uri == (
            "databricks+connector://token:abc12345@my_hostname:1234/test"
        )
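

    # Illustrative addition (not in the original test module): a round-trip
    # sanity check combining the two methods exercised above; the URI and the
    # parameters it yields are taken verbatim from those tests.
    def test_parameters_uri_round_trip() -> None:
        """
        Sketch: building a URI from the parameters extracted from that same URI
        should reproduce it unchanged.
        """
        uri = "databricks+connector://token:abc12345@my_hostname:1234/test"
        parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(uri)
        assert (
            DatabricksNativeEngineSpec.build_sqlalchemy_uri(parameters, None) == uri
        )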
    
    
    def test_parameters_json_schema() -> None:
        """
        Test that the parameters schema can be converted to a JSON schema.
        """
        from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
    
        json_schema = DatabricksNativeEngineSpec.parameters_json_schema()
    
        assert json_schema == {
            "type": "object",
            "properties": {
                "access_token": {"type": "string"},
                "database": {"type": "string"},
                "encryption": {
                    "description": "Use an encrypted connection to the database",
                    "type": "boolean",
                },
                "host": {"type": "string"},
                "http_path": {"type": "string"},
                "port": {
                    "description": "Database port",
                    "maximum": 65536,
                    "minimum": 0,
                    "type": "integer",
                },
            },
            "required": ["access_token", "database", "host", "http_path", "port"],
        }
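

    # Illustrative addition (not in the original test module): a sketch that
    # uses the generated schema for validation. Assumes the ``jsonschema``
    # package is importable in the test environment; the payload values are
    # made up for illustration.
    def test_parameters_json_schema_validation() -> None:
        """
        Sketch: the generated JSON schema accepts a complete payload and
        rejects one missing a required field.
        """
        import jsonschema

        json_schema = DatabricksNativeEngineSpec.parameters_json_schema()

        # a payload with every required field validates cleanly
        jsonschema.validate(
            instance={
                "access_token": "abc12345",
                "database": "test",
                "host": "my_hostname",
                "http_path": "sql/protocolv1/o/1234",
                "port": 1234,
                "encryption": True,
            },
            schema=json_schema,
        )

        # dropping a required field ("http_path") is rejected
        with pytest.raises(jsonschema.ValidationError):
            jsonschema.validate(
                instance={
                    "access_token": "abc12345",
                    "database": "test",
                    "host": "my_hostname",
                    "port": 1234,
                },
                schema=json_schema,
            )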
    
    
    def test_get_extra_params(mocker: MockerFixture) -> None:
        """
        Test the ``get_extra_params`` method.
        """
        from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
    
        database = mocker.MagicMock()
    
        database.extra = {}
        assert DatabricksNativeEngineSpec.get_extra_params(database) == {
            "engine_params": {
                "connect_args": {
                    "http_headers": [("User-Agent", "Apache Superset")],
                    "_user_agent_entry": "Apache Superset",
                }
            }
        }
    
        database.extra = json.dumps(
            {
                "engine_params": {
                    "connect_args": {
                        "http_headers": [("User-Agent", "Custom user agent")],
                        "_user_agent_entry": "Custom user agent",
                        "foo": "bar",
                    }
                }
            }
        )
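        # note: the ``http_headers`` tuple above survives the JSON round trip as
        # a list, hence the list form in the expected value below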
        assert DatabricksNativeEngineSpec.get_extra_params(database) == {
            "engine_params": {
                "connect_args": {
                    "http_headers": [["User-Agent", "Custom user agent"]],
                    "_user_agent_entry": "Custom user agent",
                    "foo": "bar",
                }
            }
        }
    
        # it should also strip trailing whitespace from ``http_path``
        database.extra = json.dumps(
            {
                "engine_params": {
                    "connect_args": {
                        "http_headers": [("User-Agent", "Custom user agent")],
                        "_user_agent_entry": "Custom user agent",
                        "http_path": "/some_path_here_with_whitespace ",
                    }
                }
            }
        )
        assert DatabricksNativeEngineSpec.get_extra_params(database) == {
            "engine_params": {
                "connect_args": {
                    "http_headers": [["User-Agent", "Custom user agent"]],
                    "_user_agent_entry": "Custom user agent",
                    "http_path": "/some_path_here_with_whitespace",
                }
            }
        }
    
    
    def test_extract_errors() -> None:
        """
        Test that custom error messages are extracted correctly.
        """
    
        msg = ": mismatched input 'fromm'. Expecting: "
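        # presumably no Databricks-specific error pattern matches this message,
        # so the generic engine error (issue code 1002) is returned, as asserted
        # below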
        result = DatabricksNativeEngineSpec.extract_errors(Exception(msg))
    
        assert result == [
            SupersetError(
                message=": mismatched input 'fromm'. Expecting: ",
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
                extra={
                    "engine_name": "Databricks",
                    "issue_codes": [
                        {
                            "code": 1002,
                            "message": "Issue 1002 - The database returned an unexpected error.",
                        }
                    ],
                },
            )
        ]
    
    
    def test_extract_errors_with_context() -> None:
        """
        Test that custom error messages are extracted correctly when an error
        context is passed; the context does not change the generic result here.
        """
    
        msg = ": mismatched input 'fromm'. Expecting: "
        context = {"hostname": "foo"}
        result = DatabricksNativeEngineSpec.extract_errors(Exception(msg), context)
    
        assert result == [
            SupersetError(
                message=": mismatched input 'fromm'. Expecting: ",
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
                extra={
                    "engine_name": "Databricks",
                    "issue_codes": [
                        {
                            "code": 1002,
                            "message": "Issue 1002 - The database returned an unexpected error.",
                        }
                    ],
                },
            )
        ]
    
    
    @pytest.mark.parametrize(
        "target_type,expected_result",
        [
            ("Date", "CAST('2019-01-02' AS DATE)"),
            (
                "TimeStamp",
                "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)",
            ),
            ("UnknownType", None),
        ],
    )
    def test_convert_dttm(
        target_type: str,
        expected_result: Optional[str],
        dttm: datetime,  # noqa: F811
    ) -> None:
        from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
    
        assert_convert_dttm(spec, target_type, expected_result, dttm)
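

    # Illustrative addition (not in the original test module): the same check as
    # a direct call, assuming the standard ``convert_dttm(target_type, dttm)``
    # engine-spec signature that ``assert_convert_dttm`` wraps.
    def test_convert_dttm_direct_call(dttm: datetime) -> None:  # noqa: F811
        """
        Sketch: call ``convert_dttm`` directly for a single target type.
        """
        assert DatabricksNativeEngineSpec.convert_dttm("TIMESTAMP", dttm) == (
            "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
        )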
    
    
    def test_get_prequeries() -> None:
        """
        Test the ``get_prequeries`` method, which builds the statements run
        before a query to select the active catalog and schema.
        """
        from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
    
        assert DatabricksNativeEngineSpec.get_prequeries() == []
        assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
            "USE SCHEMA test",
        ]
        assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
            "USE CATALOG test",
        ]
        assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [
            "USE CATALOG foo",
            "USE SCHEMA bar",
        ]