Select Git revision
string_tools.cpp
forwardable_request.py 10.72 KiB
from __future__ import annotations
import io
import os
import shutil
import tempfile
import urllib.parse
import uuid
from abc import abstractmethod, ABC
from typing import Any, Self, Generator, Iterator, Iterable, Callable, Literal, IO
import pydantic
import requests
import werkzeug.utils
from flask import Request
from pydantic import ConfigDict
from superset.utils.cache import generate_cache_key
from werkzeug.datastructures import Headers
from superset.commands.mitm.caching.utils import collect_chunks, chunked_raw_io
class ChunkedStream:
def __init__(self, chunk_gen: Iterable[bytes]):
self.chunk_gen = chunk_gen
self.iterator = iter(chunk_gen)
def __iter__(self):
return self.iterator
def read(self, size=-1) -> bytes:
try:
return next(self.iterator)
except StopIteration:
return b""
# Key under which a payload is stored in the cache
# (produced by cache_raw_io, consumed by stream_cached_bytes).
CacheKey = str
# NOTE(review): not used in this file; presumably (cache key, size) — confirm.
CachedFile = tuple[CacheKey, int]
def stream_cached_bytes(base_cache_key: str) -> Generator[bytes, None, None]:
    """Lazily yield the byte chunks cached under *base_cache_key*.

    The cache entry is read with ``delete_after=True``, so it can be consumed
    only once. The import is deferred until the generator is first advanced.
    """
    from superset.commands.mitm.caching.cache_data import ReadStreamedCacheCommand

    reader = ReadStreamedCacheCommand(base_cache_key)
    with reader.run(delete_after=True) as cached_chunks:
        for piece in cached_chunks:
            yield piece
def cache_raw_io(base_cache_key: str, raw_io: io.IOBase | IO[bytes]) -> CacheKey:
    """Stream the contents of *raw_io* into the cache and return the key used.

    The data is chunked via ``chunked_raw_io`` and stored with an expiry
    timeout of 180 (as passed to ``StreamIntoCacheCommand``).

    :return: *base_cache_key*, unchanged.
    """
    from superset.commands.mitm.caching.cache_data import StreamIntoCacheCommand

    def chunk_source():
        return chunked_raw_io(raw_io)

    command = StreamIntoCacheCommand(base_cache_key,
                                     chunk_source,
                                     expiry_timeout=180)
    command.run()
    return base_cache_key
def with_file_meta(k: str,
f: Any,
fms: dict[str, tuple[
str, str | None, dict[str, str]]] | None = None) -> tuple[
str, Any, str | None, dict[str, str]]:
if (fm := fms.get(k)) is not None:
fn, ct, h = fm
return fn, f, ct, h
else:
return k, f, None, {}
class PostActionMixin:
    """Root of a cooperative ``do_post_actions`` chain.

    Subclasses override ``do_post_actions`` to release resources and chain to
    ``super()``; this base implementation does nothing.
    """

    def do_post_actions(self):
        """Terminal no-op of the post-action chain."""
        return None
class TempDirMixin(PostActionMixin, pydantic.BaseModel):
    """Model mixin owning an optional temporary directory.

    ``do_post_actions`` runs the rest of the chain, then removes ``tempdir``
    (errors during removal are ignored).
    """
    # Directory to delete once the request has been handled; None = nothing to clean.
    tempdir: os.PathLike[str] | None = None

    def do_post_actions(self):
        super().do_post_actions()
        target = self.tempdir
        if not target:
            return
        shutil.rmtree(target, ignore_errors=True)
class ForwardableRequestBase(PostActionMixin, pydantic.BaseModel):
    """Serializable description of an HTTP request that is re-sent with ``requests``.

    Subclasses add a body by extending ``mk_request_kwargs`` and release the
    resources they hold in ``do_post_actions``, which ``exec_request`` always
    runs after trying to send (whether or not that succeeded).
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)
    headers: list[tuple[str, str]]
    path: str
    method: str
    base_url: str | None = None
    query: str | None = None
    url_params: dict[str, Any] | None = None
    # NOTE(review): not read anywhere in this class — confirm it is still needed.
    encoding: str | None = None
    # Merged over the kwargs from ``mk_request_kwargs`` in ``build_request``.
    request_kwarg_overrides: dict[str, Any] | None = None

    @classmethod
    def from_dict(cls, dic: dict[str, Any]) -> Self:
        """Validate *dic* (e.g. the output of ``to_dict``) into an instance, in lax mode."""
        return cls.model_validate(dic, strict=False)

    def build_url(self) -> str:
        """Join ``base_url`` with ``path`` and the optional raw ``query`` string.

        Uses ``urllib.parse.urljoin``, so a ``path`` starting with ``/``
        replaces any path component of ``base_url``.
        """
        base_url = self.base_url or ''
        rel_url = (self.path or '') + (f'?{self.query}' if self.query else '')
        return urllib.parse.urljoin(base_url, rel_url)

    def to_dict(self) -> dict[str, Any]:
        """Dump the model (round-trip mode, by alias), intended for ``from_dict``."""
        return self.model_dump(mode='python', round_trip=True, by_alias=True)

    def mk_request_kwargs(self) -> dict[str, Any]:
        """Return keyword arguments for ``requests.Request`` (no body at this level)."""
        # NOTE(review): every incoming header (including Host) is forwarded
        # verbatim — confirm this is intended for the target service.
        return {
            'method': self.method,
            'url': self.build_url(),
            'headers': Headers(self.headers),
            'params': self.url_params}

    def build_request(self) -> requests.Request:
        """Build a ``requests.Request``, applying ``request_kwarg_overrides`` last."""
        kwargs = self.mk_request_kwargs()
        if self.request_kwarg_overrides:
            kwargs |= self.request_kwarg_overrides
        return requests.Request(**kwargs)

    def do_post_actions(self) -> None:
        super().do_post_actions()

    def exec_request(self,
                     stream: bool | None = None,
                     timeout: int | None = None) -> requests.Response | None:
        """Build, prepare and send the request in a fresh ``requests`` session.

        :param stream: forwarded to ``Session.send``.
        :param timeout: forwarded to ``Session.send``.
        :return: the response from ``Session.send``.
        ``do_post_actions`` runs afterwards in every case, including when
        building or sending raises.
        """
        try:
            # Built inside ``try`` so that anything acquired while building the
            # body (opened files, temp dirs) is released even if building fails.
            prepared_request = self.build_request().prepare()
            with requests.session() as s:
                return s.send(prepared_request, stream=stream, timeout=timeout)
        finally:
            self.do_post_actions()
class DatalessForwardableRequest(ForwardableRequestBase):
    """Forwardable request without a body: only method, URL, params and headers."""
class JsonRequest(ForwardableRequestBase):
    """Forwardable request whose body is sent through ``requests``' ``json=`` argument."""
    # Either a plain dict or a pydantic model that is dumped before sending.
    json_data: dict[str, Any] | pydantic.BaseModel = pydantic.Field(default_factory=dict)

    def mk_request_kwargs(self) -> dict[str, Any]:
        """Return the base request kwargs plus ``json`` set to the JSON payload."""
        json_data = self.json_data
        if isinstance(json_data, pydantic.BaseModel):
            # mode='json' turns datetimes, UUIDs, Decimals, sets, ... into
            # JSON-native values; with mode='python' they stay Python objects
            # that ``requests`` cannot json-encode.
            json_data = json_data.model_dump(mode='json', round_trip=True, by_alias=True)
        return super().mk_request_kwargs() | {'json': json_data}
class FormDataRequest(TempDirMixin, ForwardableRequestBase):
    """Forwardable multipart/form-data request.

    Plain fields come from ``form_data``. Files may be supplied in memory
    (``raw_files``), as paths on disk (``filesystem_files``) or as cache keys
    (``cached_files``). ``files_meta`` maps a field name to its original
    ``(filename, content_type, headers)``.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)
    form_data: dict[str, Any] = pydantic.Field(default_factory=dict)
    raw_files: dict[str, bytes | io.BytesIO] = pydantic.Field(repr=False, default_factory=dict)
    filesystem_files: dict[str, str | os.PathLike[str]] = pydantic.Field(default_factory=dict)
    cached_files: dict[str, CacheKey] = pydantic.Field(default_factory=dict)
    files_meta: dict[str, tuple[str, str | None, dict[str, str]]] = pydantic.Field(
        default_factory=dict)
    # True: the MultipartEncoder itself is the body and is read while sending;
    # False: the body is rendered to bytes in mk_request_kwargs.
    stream_data: bool = False
    # Handles opened for ``filesystem_files``; closed once no longer needed
    # (right after rendering, or by do_post_actions when streaming).
    _open_files: list[IO[bytes]] = pydantic.PrivateAttr(default_factory=list)

    def mk_request_kwargs(self) -> dict[str, Any]:
        """Return ``requests.Request`` kwargs carrying a freshly encoded multipart body.

        The forwarded Content-Type (which carries the *original* boundary) is
        replaced by the Content-Type of the new body.
        """
        from requests_toolbelt import MultipartEncoder

        base_kwargs = super().mk_request_kwargs()
        meta = self.files_meta
        raw_files = {k: with_file_meta(k, v, meta) for k, v in self.raw_files.items()}
        filesystem_files = {}
        for k, path in self.filesystem_files.items():
            fh = open(path, 'rb')
            # Registered immediately so it gets closed even if a later step fails.
            self._open_files.append(fh)
            filesystem_files[k] = with_file_meta(k, fh, meta)
        # MultipartEncoder does not accept a bare generator as a field value
        # (it needs bytes/str or a file-like object), so cached chunks are
        # collected first. NOTE(review): assumes collect_chunks returns bytes.
        cached_files = {
            k: with_file_meta(k, collect_chunks(stream_cached_bytes(v)), meta)
            for k, v in self.cached_files.items()}
        multipart_data = MultipartEncoder(
            fields=(self.form_data or {}) | raw_files | filesystem_files | cached_files,
        )
        headers = Headers(self.headers)
        # ``set`` replaces; ``add_header`` appended a second Content-Type next
        # to the forwarded one with the stale boundary.
        headers.set('Content-Type', multipart_data.content_type)
        if self.stream_data:
            data = multipart_data
        else:
            data = multipart_data.to_string()
            # Body fully rendered: the source files are no longer needed.
            self._close_open_files()
        return base_kwargs | {'data': data, 'headers': headers}

    def _close_open_files(self) -> None:
        """Close and forget every file handle opened by mk_request_kwargs."""
        while self._open_files:
            self._open_files.pop().close()

    def do_post_actions(self) -> None:
        """Close opened files, then run the inherited post actions (temp dir removal)."""
        self._close_open_files()
        super().do_post_actions()
class RawDataRequest(TempDirMixin, ForwardableRequestBase):
    """Forwardable request whose body is an opaque byte payload.

    The body comes from at most one of ``raw_data`` (in memory),
    ``filesystem_data`` (a file path) or ``cached_data`` (a cache key); if
    none is set, the request is sent without a body.
    """
    raw_data: bytes | None = pydantic.Field(repr=False, default=None)
    # Previously ``default_factory=dict``: the ``{}`` default is neither a path
    # nor None, which defeated the source checks below, and a dumped ``{}``
    # does not validate as a path when read back with from_dict.
    filesystem_data: str | os.PathLike[str] | None = pydantic.Field(default=None)
    cached_data: CacheKey | None = pydantic.Field(default=None)
    # True: file/cache bodies are handed to requests unread (streamed);
    # False: they are read fully into memory first.
    stream_data: bool = False
    # Handle kept open for a streamed filesystem body; closed by do_post_actions.
    _open_file: IO[bytes] | None = pydantic.PrivateAttr(default=None)

    def mk_request_kwargs(self) -> dict[str, Any]:
        """Return ``requests.Request`` kwargs with ``data`` set to the body.

        :raises AssertionError: if more than one body source is set.
        """
        provided = [src for src in (self.raw_data, self.filesystem_data, self.cached_data)
                    if src is not None]
        if len(provided) > 1:
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError kept for existing callers. The old check only fired
            # when all three sources were set at once.
            raise AssertionError('only one form of data can be provided')
        data = self.raw_data
        if self.filesystem_data:
            if self.stream_data:
                self._close_open_file()
                self._open_file = open(self.filesystem_data, 'rb')
                data = self._open_file
            else:
                with open(self.filesystem_data, 'rb') as fh:
                    data = fh.read()
        if self.cached_data:
            data = stream_cached_bytes(self.cached_data)
            if not self.stream_data:
                data = collect_chunks(data)
        return super().mk_request_kwargs() | {'data': data}

    def _close_open_file(self) -> None:
        """Close and forget the streamed body file, if one is open."""
        fh, self._open_file = self._open_file, None
        if fh is not None:
            fh.close()

    def do_post_actions(self) -> None:
        """Close the streamed file, then run the inherited post actions (temp dir removal)."""
        self._close_open_file()
        super().do_post_actions()
def tempsave_raw_io(raw_io: io.IOBase | IO[bytes]) -> tuple[
str, str | os.PathLike[str]]:
td = tempfile.mkdtemp(prefix='forwarded_request')
path = os.path.join(td, 'raw_data')
with open(path, 'wb') as f:
f.write(raw_io.read())
return td, path
def tempsave_request_files(flask_request: Request) -> tuple[
        str | None, dict[str, str | os.PathLike[str]]]:
    """Save every uploaded file of *flask_request* into a fresh temporary directory.

    :return: ``(tempdir, {field name: saved path})``, or ``(None, {})`` when
        the request has no files. The caller owns *tempdir* and must remove it.
    """
    if not flask_request.files:
        return None, {}
    td = tempfile.mkdtemp(prefix='forwarded_request')
    filesystem_files: dict[str, str | os.PathLike[str]] = {}
    try:
        for idx, (n, f) in enumerate(flask_request.files.items()):
            # The field index prefix keeps two fields that upload the same
            # filename (e.g. "blob") from overwriting each other. secure_filename
            # may return '' and ``f.filename`` may be None; either would have made
            # ``path`` the temp directory itself.
            secure_name = werkzeug.utils.secure_filename(f.filename or '') or 'file'
            path = os.path.join(td, f'{idx}_{secure_name}')
            f.save(path)
            filesystem_files[n] = path
    except BaseException:
        # Don't leak the temp dir when saving fails part-way.
        shutil.rmtree(td, ignore_errors=True)
        raise
    return td, filesystem_files
def raw_request_files(flask_request: Request) -> dict[str, bytes]:
    """Read each uploaded file of *flask_request* fully into memory.

    :return: mapping of form field name to the file's bytes.
    """
    contents: dict[str, bytes] = {}
    for field_name, storage in flask_request.files.items():
        contents[field_name] = storage.stream.read()
    return contents
def cache_request_files(base_cache_key: str, flask_request: Request) -> dict[
        str, CacheKey]:
    """Stream each uploaded file into the cache under ``<base_cache_key>:<field>``.

    :return: mapping of form field name to the cache key returned by ``cache_raw_io``.
    """
    return {
        field_name: cache_raw_io(f'{base_cache_key}:{field_name}', storage.stream)
        for field_name, storage in flask_request.files.items()
    }
def get_file_meta(flask_request: Request) -> dict[
        str, tuple[str, str | None, dict[str, str]]]:
    """Collect ``(filename, content_type, headers)`` for each uploaded file.

    Part headers are flattened into a plain dict via ``to_wsgi_list()``, so a
    repeated header name keeps only its last value.
    """
    meta = {}
    for field_name, storage in flask_request.files.items():
        meta[field_name] = (
            storage.filename,
            storage.content_type,
            dict(storage.headers.to_wsgi_list()),
        )
    return meta
def mk_forwardable_request(endpoint: str,
                           flask_request: Request,
                           file_handling: Literal[
                               'raw', 'cache', 'filesystem'] = 'raw') -> ForwardableRequestBase:
    """Capture *flask_request* as a forwardable request targeting *endpoint*.

    multipart/form-data becomes a FormDataRequest, application/json a
    JsonRequest, anything else a RawDataRequest. *file_handling* selects where
    uploaded files / the raw body are held until sending: in memory ('raw'),
    in the cache ('cache') or in a temporary directory ('filesystem').

    :raises ValueError: for an unknown *file_handling* (not checked for JSON
        requests, which do not use it).
    """
    kwargs: dict[str, Any] = dict(headers=list(flask_request.headers.items()),
                                  method=str(flask_request.method),
                                  path=endpoint)
    # ``mimetype`` is the Content-Type without parameters, lower-cased. The
    # former check used ``content_encoding`` (the Content-Encoding header,
    # e.g. "gzip"), which never equals a MIME type, so the multipart and JSON
    # branches were unreachable.
    mimetype = flask_request.mimetype
    if mimetype == 'multipart/form-data':
        kwargs |= dict(files_meta=get_file_meta(flask_request))
        if file_handling == 'filesystem':
            td, filesystem_files = tempsave_request_files(flask_request)
            kwargs |= dict(tempdir=td, filesystem_files=filesystem_files)
        elif file_handling == 'cache':
            cached_files = cache_request_files(f'forwarded-request:files:{uuid.uuid4()}',
                                               flask_request)
            kwargs |= dict(cached_files=cached_files)
        elif file_handling == 'raw':
            kwargs |= dict(raw_files=raw_request_files(flask_request))
        else:
            raise ValueError(f'unknown file handling mode: {file_handling}')
        # NOTE(review): ``to_dict()`` keeps only the first value of repeated form fields.
        return FormDataRequest(**kwargs, form_data=flask_request.form.to_dict())
    if mimetype == 'application/json':
        # NOTE(review): JsonRequest.json_data is typed dict | BaseModel, so a
        # top-level JSON array will fail validation — confirm callers never send one.
        return JsonRequest(**kwargs, json_data=flask_request.json)
    if file_handling == 'filesystem':
        td, path = tempsave_raw_io(flask_request.stream)
        # RawDataRequest reads its body from ``filesystem_data``; the former
        # ``filesystem_files`` keyword is not one of its fields, so the body was dropped.
        kwargs |= dict(tempdir=td, filesystem_data=path)
    elif file_handling == 'cache':
        cached_data = cache_raw_io(f'forwarded-request:data:{uuid.uuid4()}',
                                   flask_request.stream)
        kwargs |= dict(cached_data=cached_data)
    elif file_handling == 'raw':
        kwargs |= dict(raw_data=flask_request.get_data(cache=False))
    else:
        raise ValueError(f'unknown file handling mode: {file_handling}')
    return RawDataRequest(**kwargs)