Skip to content
Snippets Groups Projects
Select Git revision
  • c01c699f16926b6fb0c9a810c20ae700a0988bb2
  • main default protected
  • YL
  • NS
4 results

evaluation.py

Blame
  • forwardable_request.py 10.72 KiB
    from __future__ import annotations
    import io
    import os
    import shutil
    import tempfile
    import urllib.parse
    import uuid
    from abc import abstractmethod, ABC
    from typing import Any, Self, Generator, Iterator, Iterable, Callable, Literal, IO
    
    import pydantic
    import requests
    import werkzeug.utils
    from flask import Request
    from pydantic import ConfigDict
    from superset.utils.cache import generate_cache_key
    from werkzeug.datastructures import Headers
    
    from superset.commands.mitm.caching.utils import collect_chunks, chunked_raw_io
    
    
    class ChunkedStream:
        """Minimal file-like adapter over an iterable of byte chunks.

        ``read`` hands out exactly one chunk per call (the *size* argument is
        accepted for file-like compatibility but ignored) and returns ``b""``
        once the chunks are exhausted. Iterating the object yields the
        remaining chunks from the same underlying iterator.
        """

        def __init__(self, chunk_gen: Iterable[bytes]):
            # Keep the source iterable, but consume it only through one shared
            # iterator so read() and iteration never replay a chunk.
            self.chunk_gen = chunk_gen
            self.iterator = iter(chunk_gen)

        def __iter__(self):
            return self.iterator

        def read(self, size=-1) -> bytes:
            # next() with a default replaces the explicit StopIteration handler.
            return next(self.iterator, b"")
    
    
    # A cached file reference expressed as a (str, int) pair.
    # NOTE(review): the meaning of the two members is not visible here and the
    # alias is not used in the visible code — confirm before relying on it.
    CachedFile = tuple[str, int]
    # Key under which streamed data is written to / read back from the cache
    # (returned by cache_raw_io, consumed by stream_cached_bytes).
    CacheKey = str
    
    
    def stream_cached_bytes(base_cache_key: str) -> Generator[bytes, None, None]:
        """Lazily yield the byte chunks stored under *base_cache_key*.

        Nothing happens until the generator is first advanced. The read command
        runs with ``delete_after=True``, so the cache entry is presumably
        removed once it has been streamed (TODO confirm in
        ``ReadStreamedCacheCommand``). The import is deferred to call time,
        presumably to avoid an import cycle.
        """
        from superset.commands.mitm.caching.cache_data import ReadStreamedCacheCommand

        reader = ReadStreamedCacheCommand(base_cache_key)
        with reader.run(delete_after=True) as chunk_iter:
            yield from chunk_iter
    
    
    def cache_raw_io(base_cache_key: str,
                     raw_io: io.IOBase | IO[bytes],
                     *,
                     expiry_timeout: int = 180) -> CacheKey:
        """Stream the contents of *raw_io* into the cache under *base_cache_key*.

        :param base_cache_key: key the streamed chunks are stored under.
        :param raw_io: readable binary stream, consumed via ``chunked_raw_io``.
        :param expiry_timeout: forwarded to ``StreamIntoCacheCommand``
            (presumably seconds — confirm); defaults to the previously
            hard-coded 180.
        :return: *base_cache_key*, so the data can later be read back, e.g. with
            ``stream_cached_bytes``.
        """
        # Deferred import, mirroring stream_cached_bytes.
        from superset.commands.mitm.caching.cache_data import StreamIntoCacheCommand

        StreamIntoCacheCommand(base_cache_key,
                               lambda: chunked_raw_io(raw_io),
                               expiry_timeout=expiry_timeout).run()
        return base_cache_key
    
    
    def with_file_meta(
            k: str,
            f: Any,
            fms: dict[str, tuple[str, str | None, dict[str, str]]] | None = None,
    ) -> tuple[str, Any, str | None, dict[str, str]]:
        """Build a multipart file-field tuple for form field *k*.

        :param k: form field name; also the filename fallback.
        :param f: file payload (bytes, file object, ...), passed through untouched.
        :param fms: optional mapping field name -> (filename, content type, headers).
        :return: ``(filename, f, content_type, headers)``. When *fms* is ``None``
            or has no entry for *k*, ``(k, f, None, {})`` is returned.
        """
        # Guard the None default: calling .get() on it would raise AttributeError.
        fm = fms.get(k) if fms else None
        if fm is None:
            return k, f, None, {}
        filename, content_type, headers = fm
        return filename, f, content_type, headers
    
    
    class PostActionMixin:
        """Root of the cooperative ``do_post_actions`` chain.

        Subclasses override ``do_post_actions`` to add cleanup steps and call
        ``super().do_post_actions()``; this root implementation does nothing.
        """

        def do_post_actions(self):
            """No-op terminator for the ``super()`` chain."""
    
    
    class TempDirMixin(PostActionMixin, pydantic.BaseModel):
        """Mixin that owns an optional temporary directory.

        After the rest of the post-action chain has run, the directory (if one
        is set) is deleted recursively; deletion errors are ignored.
        """

        # Directory to delete during post actions; None/empty means nothing to clean.
        tempdir: os.PathLike[str] | None = None

        def do_post_actions(self):
            super().do_post_actions()
            if not self.tempdir:
                return
            shutil.rmtree(self.tempdir, ignore_errors=True)
    
    
    class ForwardableRequestBase(PostActionMixin, pydantic.BaseModel):
        """Serializable description of an HTTP request to be re-sent elsewhere.

        Subclasses contribute a request body by extending ``mk_request_kwargs``
        and register cleanup by extending ``do_post_actions``, which
        ``exec_request`` always runs once the request attempt is over.
        """
        model_config = ConfigDict(arbitrary_types_allowed=True)

        # (name, value) header pairs, forwarded verbatim.
        # NOTE(review): nothing here strips hop-by-hop headers or Host from the
        # original request — confirm the target tolerates them.
        headers: list[tuple[str, str]]
        # Path resolved against base_url with urljoin semantics.
        path: str
        method: str
        base_url: str | None = None
        # Raw query string (without the leading '?') appended to path.
        query: str | None = None
        # Extra query parameters handed to requests as ``params``.
        url_params: dict[str, Any] | None = None
        # NOTE(review): not read anywhere in the visible code — presumably used elsewhere.
        encoding: str | None = None
        # Merged last into the requests.Request kwargs, overriding computed values.
        request_kwarg_overrides: dict[str, Any] | None = None

        @classmethod
        def from_dict(cls, dic: dict[str, Any]) -> Self:
            """Validate *dic* non-strictly into an instance (counterpart of ``to_dict``)."""
            return cls.model_validate(dic, strict=False)

        def build_url(self) -> str:
            """Join ``base_url`` with ``path`` and ``query`` via ``urllib.parse.urljoin``.

            Per urljoin rules, an absolute ``path`` replaces the base URL's path,
            and a relative one replaces its last segment unless the base ends in '/'.
            """
            base_url = self.base_url or ''
            rel_url = (self.path or '') + (f'?{self.query}' if self.query else '')
            return urllib.parse.urljoin(base_url, rel_url)

        def to_dict(self) -> dict[str, Any]:
            """Dump the model to Python data (round-trip mode, by alias)."""
            return self.model_dump(mode='python', round_trip=True, by_alias=True)

        def mk_request_kwargs(self) -> dict[str, Any]:
            """Return the body-less keyword arguments for ``requests.Request``."""
            return {
                'method': self.method,
                'url': self.build_url(),
                'headers': Headers(self.headers),
                'params': self.url_params}

        def build_request(self) -> requests.Request:
            """Build an unprepared request, applying ``request_kwarg_overrides`` last."""
            kwargs = self.mk_request_kwargs()
            if self.request_kwarg_overrides:
                kwargs |= self.request_kwarg_overrides
            return requests.Request(**kwargs)

        def do_post_actions(self) -> None:
            super().do_post_actions()

        def exec_request(self,
                         stream: bool | None = None,
                         timeout: int | None = None) -> requests.Response | None:
            """Build, prepare and send the request and return the response.

            ``do_post_actions`` runs exactly once afterwards, whether building,
            preparing or sending succeeds or raises.

            :param stream: passed to ``Session.send``.
            :param timeout: passed to ``Session.send``.
            """
            try:
                # Building can already fail (e.g. opening a body file), so it must
                # sit inside the try for the cleanup in finally to run.
                prepared_request = self.build_request().prepare()
                # requests.Session replaces the deprecated requests.session() factory.
                with requests.Session() as session:
                    return session.send(prepared_request, stream=stream, timeout=timeout)
            finally:
                self.do_post_actions()
    
    class DatalessForwardableRequest(ForwardableRequestBase):
        """Forwardable request without a body: only method, URL and headers."""
    
    class JsonRequest(ForwardableRequestBase):
        """Forwardable request whose body is sent through requests' ``json=`` kwarg."""

        # Body payload. A top-level JSON array is accepted as well as an object;
        # pydantic models are dumped to JSON-compatible data before sending.
        json_data: dict[str, Any] | list[Any] | pydantic.BaseModel = pydantic.Field(
            default_factory=dict)

        def mk_request_kwargs(self) -> dict[str, Any]:
            """Extend the base kwargs with ``json=<payload>``."""
            json_data = self.json_data
            if isinstance(json_data, pydantic.BaseModel):
                # mode='json' converts datetimes, UUIDs, Decimals, ... into
                # JSON-compatible values; mode='python' would keep the objects,
                # which the JSON encoder used by requests cannot serialize.
                json_data = json_data.model_dump(mode='json', round_trip=True, by_alias=True)
            return super().mk_request_kwargs() | {'json': json_data}
    
    
    class FormDataRequest(TempDirMixin, ForwardableRequestBase):
        """Forwardable ``multipart/form-data`` request.

        Plain fields come from ``form_data``. File parts come from in-memory
        bytes (``raw_files``), files on disk (``filesystem_files``) or cache
        entries (``cached_files``). ``files_meta`` supplies per-field
        (filename, content type, headers).
        """
        model_config = ConfigDict(arbitrary_types_allowed=True)

        form_data: dict[str, Any] = pydantic.Field(default_factory=dict)

        raw_files: dict[str, bytes | io.BytesIO] = pydantic.Field(repr=False, default_factory=dict)
        filesystem_files: dict[str, str | os.PathLike[str]] = pydantic.Field(default_factory=dict)
        cached_files: dict[str, CacheKey] = pydantic.Field(default_factory=dict)

        files_meta: dict[str, tuple[str, str | None, dict[str, str]]] = pydantic.Field(
            default_factory=dict)

        # True: the MultipartEncoder object itself is the request body.
        # False: the whole body is rendered to bytes up front.
        stream_data: bool = False

        def mk_request_kwargs(self) -> dict[str, Any]:
            """Extend the base kwargs with a multipart body and its Content-Type."""
            from contextlib import ExitStack

            from requests_toolbelt import MultipartEncoder

            base_kwargs = super().mk_request_kwargs()
            fd = self.form_data or {}

            with ExitStack() as open_files:
                raw_files = {k: with_file_meta(k, v, self.files_meta)
                             for k, v in self.raw_files.items()}
                # Register every handle so none leaks if a later open() or the
                # encoding step raises.
                filesystem_files = {
                    k: with_file_meta(k, open_files.enter_context(open(v, 'rb')),
                                      self.files_meta)
                    for k, v in self.filesystem_files.items()}
                # MultipartEncoder needs parts that are bytes or file-like with a
                # known length; a bare generator cannot be encoded, so cached data
                # is materialized here.
                # NOTE(review): assumes collect_chunks returns bytes (or BytesIO),
                # as RawDataRequest also relies on.
                cached_files = {
                    k: with_file_meta(k, collect_chunks(stream_cached_bytes(v)),
                                      self.files_meta)
                    for k, v in self.cached_files.items()}
                multipart_data = MultipartEncoder(
                    fields=fd | raw_files | filesystem_files | cached_files,
                )
                if self.stream_data:
                    # The encoder reads the files while the request is sent, so
                    # release the handles from the stack instead of closing them.
                    # NOTE(review): they are never closed explicitly afterwards.
                    open_files.pop_all()
                    data = multipart_data
                else:
                    data = multipart_data.to_string()

            headers = Headers(self.headers)
            # set() replaces any Content-Type copied from the original request,
            # whose multipart boundary would not match the re-encoded body.
            headers.set('Content-Type', multipart_data.content_type)
            return base_kwargs | {
                'data': data,
                'headers': headers}
    
    
    class RawDataRequest(TempDirMixin, ForwardableRequestBase):
        """Forwardable request with an opaque body.

        Exactly one body source must be set: ``raw_data`` (bytes in memory),
        ``filesystem_data`` (path to a file) or ``cached_data`` (cache key).
        With ``stream_data`` the file object / cache generator is handed to
        requests as-is; otherwise it is read fully into memory first.
        """
        raw_data: bytes | None = pydantic.Field(repr=False, default=None)
        # None means "not provided"; mk_request_kwargs relies on that.
        filesystem_data: str | os.PathLike[str] | None = pydantic.Field(default=None)
        cached_data: CacheKey | None = pydantic.Field(default=None)
        stream_data: bool = False

        def mk_request_kwargs(self) -> dict[str, Any]:
            """Extend the base kwargs with ``data=<body>`` from the single body source.

            :raises ValueError: if no body source, or more than one, is set.
            """
            provided = sum(source is not None for source in
                           (self.raw_data, self.filesystem_data, self.cached_data))
            if provided == 0:
                raise ValueError('one form of data must be provided')
            if provided > 1:
                raise ValueError('only one form of data can be provided')

            if self.filesystem_data is not None:
                if self.stream_data:
                    # Passed to requests as a file object and read while sending.
                    # NOTE(review): this handle is never closed explicitly.
                    data = open(self.filesystem_data, 'rb')
                else:
                    with open(self.filesystem_data, 'rb') as fh:
                        data = fh.read()
            elif self.cached_data is not None:
                data = stream_cached_bytes(self.cached_data)
                if not self.stream_data:
                    data = collect_chunks(data)
            else:
                data = self.raw_data

            return super().mk_request_kwargs() | {'data': data}
    
    
    def tempsave_raw_io(raw_io: io.IOBase | IO[bytes]) -> tuple[
        str, str | os.PathLike[str]]:
        """Copy *raw_io* into a file named ``raw_data`` inside a new temp directory.

        :param raw_io: readable binary stream; it is copied in chunks rather
            than read into memory at once.
        :return: ``(tempdir, file path)``. The caller owns *tempdir* and must
            remove it; if copying fails, the directory is removed here instead.
        """
        td = tempfile.mkdtemp(prefix='forwarded_request')
        path = os.path.join(td, 'raw_data')
        try:
            with open(path, 'wb') as f:
                shutil.copyfileobj(raw_io, f)
        except BaseException:
            # Don't leak the temp dir when the copy fails.
            shutil.rmtree(td, ignore_errors=True)
            raise
        return td, path
    
    
    def tempsave_request_files(flask_request: Request) -> tuple[
        str | None, dict[str, str | os.PathLike[str]]]:
        """Save every uploaded file of *flask_request* into a fresh temp directory.

        :return: ``(tempdir, {field name: saved path})``, or ``(None, {})`` when
            the request has no files. The caller owns *tempdir* and must remove
            it; if saving fails, the directory is removed here instead.
        """
        td = None
        filesystem_files = {}
        if len(flask_request.files) > 0:
            td = tempfile.mkdtemp(prefix='forwarded_request')
            try:
                for index, (n, f) in enumerate(flask_request.files.items()):
                    # secure_filename() can return '' (e.g. for '..' or purely
                    # non-ASCII names) and distinct fields may upload the same
                    # filename, so prefix a per-field index to keep every path
                    # unique and non-empty. mk_forwardable_request forwards the
                    # original filename separately via files_meta.
                    secure_name = werkzeug.utils.secure_filename(f.filename or '') or 'file'
                    path = os.path.join(td, f'{index}_{secure_name}')
                    f.save(path)
                    filesystem_files[n] = path
            except BaseException:
                # Don't leak the temp dir when saving any file fails.
                shutil.rmtree(td, ignore_errors=True)
                raise
        return td, filesystem_files
    
    
    def raw_request_files(flask_request: Request) -> dict[str, bytes]:
        """Read every uploaded file of *flask_request* fully into memory.

        :return: mapping form field name -> the file's raw bytes.
        """
        contents: dict[str, bytes] = {}
        for field_name, upload in flask_request.files.items():
            contents[field_name] = upload.stream.read()
        return contents
    
    
    def cache_request_files(base_cache_key: str, flask_request: Request) -> dict[
        str, CacheKey]:
        """Stream every uploaded file of *flask_request* into the cache.

        Each file is stored under ``'<base_cache_key>:<field name>'``.

        :return: mapping form field name -> cache key of that file's data.
        """
        return {
            field_name: cache_raw_io(f'{base_cache_key}:{field_name}', upload.stream)
            for field_name, upload in flask_request.files.items()
        }
    
    
    def get_file_meta(flask_request: Request) -> dict[
        str, tuple[str, str | None, dict[str, str]]]:
        """Collect per-field upload metadata from *flask_request*.

        :return: mapping form field name -> (client filename, content type,
            part headers as a plain dict; repeated header names keep the last value).
        """
        meta: dict[str, tuple[str, str | None, dict[str, str]]] = {}
        for field_name, upload in flask_request.files.items():
            meta[field_name] = (
                upload.filename,
                upload.content_type,
                dict(upload.headers.to_wsgi_list()),
            )
        return meta
    
    
    def mk_forwardable_request(endpoint: str,
                               flask_request: Request,
                               file_handling: Literal[
                                   'raw', 'cache', 'filesystem'] = 'raw') -> ForwardableRequestBase:
        """Capture *flask_request* as a forwardable request targeting *endpoint*.

        Headers and method are copied; the body is captured by mimetype:

        * ``multipart/form-data`` -> ``FormDataRequest`` (form fields + files)
        * ``application/json`` -> ``JsonRequest``
        * anything else -> ``RawDataRequest`` with the raw body

        :param file_handling: where file/body content is held until forwarding:
            in memory (``'raw'``), in the cache (``'cache'``) or in a temp
            directory (``'filesystem'``).
        :raises ValueError: for an unknown *file_handling* on the multipart
            and raw-body paths.
        """
        kwargs = dict(headers=list(flask_request.headers.items()),
                      method=str(flask_request.method),
                      path=endpoint)
        # Compare the mimetype (Content-Type without parameters such as the
        # boundary). content_encoding is the Content-Encoding header (gzip, ...)
        # and can never equal these content types.
        mimetype = flask_request.mimetype
        if mimetype == 'multipart/form-data':
            kwargs |= dict(files_meta=get_file_meta(flask_request))
            if file_handling == 'filesystem':
                td, filesystem_files = tempsave_request_files(flask_request)
                kwargs |= dict(tempdir=td, filesystem_files=filesystem_files)
            elif file_handling == 'cache':
                cached_files = cache_request_files(f'forwarded-request:files:{uuid.uuid4()}',
                                                   flask_request)
                kwargs |= dict(cached_files=cached_files)
            elif file_handling == 'raw':
                kwargs |= dict(raw_files=raw_request_files(flask_request))
            else:
                raise ValueError(f'unknown file handling mode: {file_handling}')

            return FormDataRequest(**kwargs, form_data=flask_request.form.to_dict())
        elif mimetype == 'application/json':
            return JsonRequest(**kwargs, json_data=flask_request.json)
        else:
            if file_handling == 'filesystem':
                td, path = tempsave_raw_io(flask_request.stream)
                # RawDataRequest takes the single body file as filesystem_data;
                # an unknown field name would be silently dropped by pydantic.
                kwargs |= dict(tempdir=td, filesystem_data=path)
            elif file_handling == 'cache':
                cached_data = cache_raw_io(f'forwarded-request:data:{uuid.uuid4()}',
                                           flask_request.stream)
                kwargs |= dict(cached_data=cached_data)
            elif file_handling == 'raw':
                kwargs |= dict(raw_data=flask_request.get_data(cache=False))
            else:
                raise ValueError(f'unknown file handling mode: {file_handling}')
            return RawDataRequest(**kwargs)