From f9284d809b4e494b476cca0a08b49b004e567cf7 Mon Sep 17 00:00:00 2001
From: Leah Tacke genannt Unterberg <leah.tgu@pads.rwth-aachen.de>
Date: Tue, 21 Jan 2025 14:01:26 +0100
Subject: [PATCH] more work on sql and superset representation

---
 mitm_tooling/data_types/data_types.py         | 36 +++++++-----
 mitm_tooling/extraction/__init__.py           |  4 +-
 .../intermediate_representation.py            | 24 ++++----
 .../representation/sql_representation.py      | 58 ++++++++++---------
 mitm_tooling/transformation/__init__.py       |  4 +-
 .../df/__init__.py                            |  0
 .../df/intermediate_transformation.py         |  7 +--
 .../superset/dataset_definition.py            |  2 +-
 .../superset/superset_representation.py       |  4 +-
 test/something.py                             |  6 +-
 test/test_to_df.py                            | 22 +++----
 11 files changed, 93 insertions(+), 74 deletions(-)
 rename mitm_tooling/{extraction => transformation}/df/__init__.py (100%)
 rename mitm_tooling/{extraction => transformation}/df/intermediate_transformation.py (96%)

diff --git a/mitm_tooling/data_types/data_types.py b/mitm_tooling/data_types/data_types.py
index 056114f..9936664 100644
--- a/mitm_tooling/data_types/data_types.py
+++ b/mitm_tooling/data_types/data_types.py
@@ -7,7 +7,9 @@ import pydantic
 import sqlalchemy as sa
 from sqlalchemy.sql import sqltypes
 
-SA_SQLTypeClass = type[sa.types.TypeEngine]
+SA_SQLType = sa.types.TypeEngine
+SA_SQLTypeInstanceBuilder = Callable[[], SA_SQLType]
+SA_SQLTypeClass = type[SA_SQLType]
 
 SA_SQLTypeName = str
 PandasCast = Callable[[pd.Series], pd.Series]
@@ -26,8 +28,14 @@ class MITMDataType(enum.StrEnum):
     Infer = 'infer'
 
     @property
-    def sa_sql_type(self) -> SA_SQLTypeClass | None:
-        return mitm_sql_type_map.get(self)
+    def sa_sql_type(self) -> SA_SQLType | None:
+        if pair := mitm_sql_type_map.get(self):
+            return pair[1]()
+
+    @property
+    def sa_sql_type_cls(self) -> SA_SQLTypeClass | None:
+        if pair := mitm_sql_type_map.get(self):
+            return pair[0]
 
     @property
     def pandas_cast(self) -> PandasCast | None:
@@ -35,7 +43,7 @@ class MITMDataType(enum.StrEnum):
 
     @property
     def sql_type_str(self) -> str:
-        return self.sa_sql_type.__name__
+        return self.sa_sql_type_cls.__name__
 
     def wrap(self) -> 'WrappedMITMDataType':
         return WrappedMITMDataType(mitm=self)
@@ -57,7 +65,7 @@ def sa_sql_to_mitm_type(sa_type: SA_SQLTypeClass) -> MITMDataType:
 
 
 def mitm_to_sql_type(mitm_type: MITMDataType) -> SA_SQLTypeClass | None:
-    return mitm_type.sa_sql_type
+    return mitm_type.sa_sql_type_cls
 
 
 def mitm_to_pandas(mitm_type: MITMDataType) -> PandasCast | None:
@@ -66,9 +74,9 @@ def mitm_to_pandas(mitm_type: MITMDataType) -> PandasCast | None:
 
 def get_sa_sql_type(type_name: EitherDataType | WrappedMITMDataType) -> SA_SQLTypeClass | None:
     if isinstance(type_name, MITMDataType):
-        return type_name.sa_sql_type
+        return type_name.sa_sql_type_cls
     elif isinstance(type_name, WrappedMITMDataType):
-        return type_name.mitm.sa_sql_type
+        return type_name.mitm.sa_sql_type_cls
     else:
         if type_name and (t := getattr(sqltypes, type_name, None)):
             if isinstance(t, type):
@@ -108,13 +116,13 @@ sql_mitm_type_map: dict[SA_SQLTypeClass, MITMDataType] = {
     # sqltypes.BINARY: MITMDataType.Binary,
 }
 
-mitm_sql_type_map: dict[MITMDataType, SA_SQLTypeClass] = {
-    MITMDataType.Text: sqltypes.String,
-    MITMDataType.Datetime: sqltypes.DATETIME_TIMEZONE,
-    MITMDataType.Json: sqltypes.JSON,
-    MITMDataType.Boolean: sqltypes.Boolean,
-    MITMDataType.Integer: sqltypes.Integer,
-    MITMDataType.Numeric: sqltypes.Float,
+mitm_sql_type_map: dict[MITMDataType, None | tuple[SA_SQLTypeClass, SA_SQLTypeInstanceBuilder]] = {
+    MITMDataType.Text: (sqltypes.String, sqltypes.String),
+    MITMDataType.Datetime: (sqltypes.DATETIME, lambda: sqltypes.DATETIME_TIMEZONE),
+    MITMDataType.Json: (sqltypes.JSON, sqltypes.JSON),
+    MITMDataType.Boolean: (sqltypes.Boolean, sqltypes.Boolean),
+    MITMDataType.Integer: (sqltypes.Integer, sqltypes.Integer),
+    MITMDataType.Numeric: (sqltypes.Float, sqltypes.Float),
     MITMDataType.Unknown: None,
     MITMDataType.Infer: None,
     # MITMDataType.Binary: sqltypes.LargeBinary,
diff --git a/mitm_tooling/extraction/__init__.py b/mitm_tooling/extraction/__init__.py
index 10e03c2..932d0a6 100644
--- a/mitm_tooling/extraction/__init__.py
+++ b/mitm_tooling/extraction/__init__.py
@@ -1,2 +1,4 @@
-from . import sql, df
+from . import sql
+from transformation import df
+
 __all__ = ['sql', 'df']
\ No newline at end of file
diff --git a/mitm_tooling/representation/intermediate_representation.py b/mitm_tooling/representation/intermediate_representation.py
index b39f046..54c58d4 100644
--- a/mitm_tooling/representation/intermediate_representation.py
+++ b/mitm_tooling/representation/intermediate_representation.py
@@ -4,7 +4,7 @@ import itertools
 import logging
 from collections import defaultdict
 from collections.abc import Iterator, Iterable, Sequence, Mapping
-from typing import TYPE_CHECKING, Self, Any
+from typing import TYPE_CHECKING, Self, Any, Annotated
 
 import pandas as pd
 import pydantic
@@ -63,9 +63,10 @@ class HeaderEntry(pydantic.BaseModel):
             itertools.chain(*zip(self.attributes, map(str, self.attribute_dtypes))))
 
 
+
 class Header(pydantic.BaseModel):
     mitm: MITM
-    header_entries: list[HeaderEntry] = pydantic.Field(default_factory=list)
+    header_entries: Annotated[list[HeaderEntry], pydantic.Field(default_factory=list)]
 
     @classmethod
     def from_df(cls, df: pd.DataFrame, mitm: MITM) -> Self:
@@ -95,7 +96,7 @@ class MITMData(Iterable[tuple[ConceptName, pd.DataFrame]], pydantic.BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     header: Header
-    concept_dfs: dict[ConceptName, pd.DataFrame] = pydantic.Field(default_factory=dict)
+    concept_dfs: Annotated[dict[ConceptName, pd.DataFrame], pydantic.Field(default_factory=dict)]
 
     def __iter__(self):
         return iter(self.concept_dfs.items())
@@ -103,22 +104,25 @@ class MITMData(Iterable[tuple[ConceptName, pd.DataFrame]], pydantic.BaseModel):
     def as_generalized(self) -> Self:
         mitm_def = get_mitm_def(self.header.mitm)
         dfs = defaultdict(list)
-        for c, df in self:
+        for c, df in self.concept_dfs.items():
             c = mitm_def.get_parent(c)
             dfs[c].append(df)
-        return MITMData(header=self.header, dfs=dict(dfs))
+        dfs = {c : pd.concat(dfs_, axis='rows', ignore_index=True) for c, dfs_ in dfs.items()}
+        return MITMData(header=self.header, concept_dfs=dfs)
 
     def as_specialized(self) -> Self:
         mitm_def = get_mitm_def(self.header.mitm)
-        dfs = defaultdict(list)
+        dfs = {}
         for c, df in self:
             if mitm_def.get_properties(c).is_abstract:
                 leaf_concepts = mitm_def.get_leafs(c)
-                for sub_c, idx in df.groupby('kind').groups.items():
-                    dfs[sub_c].append(df.loc[idx])
+
+                for sub_c_key, idx in df.groupby('kind').groups.items():
+                    sub_c = mitm_def.inverse_concept_key_map[str(sub_c_key)]
+                    dfs[sub_c] = df.loc[idx]
             else:
-                dfs[c].append(df)
-        return MITMData(header=self.header, dfs=dict(dfs))
+                dfs[c] = df
+        return MITMData(header=self.header, concept_dfs=dfs)
 
 
 class StreamingConceptData(pydantic.BaseModel):
diff --git a/mitm_tooling/representation/sql_representation.py b/mitm_tooling/representation/sql_representation.py
index 120998e..3a111d4 100644
--- a/mitm_tooling/representation/sql_representation.py
+++ b/mitm_tooling/representation/sql_representation.py
@@ -1,19 +1,18 @@
-from collections import defaultdict
-from collections.abc import Callable, Iterator, Generator, Mapping
+from collections.abc import Callable, Generator, Mapping
 
 import pydantic
 import sqlalchemy as sa
 import sqlalchemy.sql.schema
-from pydantic import AnyUrl
+from pydantic import AnyUrl, ConfigDict
 from mitm_tooling.data_types import MITMDataType
 from mitm_tooling.definition import MITMDefinition, ConceptProperties, OwnedRelations, ConceptName, MITM, get_mitm_def, \
-    ConceptKind, ConceptLevel, RelationName
+    RelationName
 from mitm_tooling.definition.definition_tools import map_col_groups, ColGroupMaps
-from mitm_tooling.extraction.sql.data_models import Queryable, TableName, ColumnName
-from .df_representation import MITMDataset
+from mitm_tooling.extraction.sql.data_models import Queryable, TableName
 from .intermediate_representation import Header, MITMData
 from mitm_tooling.utilities.sql_utils import create_sa_engine, qualify
 from mitm_tooling.utilities import python_utils
+from mitm_tooling.utilities.io_utils import FilePath
 
 from sqlalchemy_utils.view import create_view
 
@@ -74,7 +73,6 @@ def mk_table(meta: sa.MetaData, mitm: MITM, concept: ConceptName, table_name: Ta
                                                    created_columns,
                                                    ref_columns)
         constraints.extend(schema_items)
-        print(constraints)
 
     return sa.Table(table_name, meta, schema=SQL_REPRESENTATION_DEFAULT_SCHEMA, *columns,
                     *constraints), created_columns, ref_columns
@@ -104,6 +102,8 @@ ConceptTypeTablesDict = dict[ConceptName, dict[TableName, sa.Table]]
 
 
 class SQLRepresentationSchema(pydantic.BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     meta: sa.MetaData
     concept_tables: ConceptTablesDict
     type_tables: ConceptTypeTablesDict
@@ -141,6 +141,7 @@ def mk_db_schema(header: Header, gen_views: Callable[
                                 mitm_def.resolve_foreign_types(concept).items() for name, dt in
                                 resolved_fk.items()]
         })
+        concept_tables[concept] = t
 
     for he in header.header_entries:
         he_concept = he.concept
@@ -185,35 +186,40 @@ def mk_db_schema(header: Header, gen_views: Callable[
     return SQLRepresentationSchema(meta=meta, concept_tables=concept_tables, type_tables=type_tables, views=views)
 
 
-def insert_db_instances(engine: sa.Engine, meta: sa.MetaData, mitm_data: MITMData):
+def insert_db_instances(engine: sa.Engine, sql_rep_schema: SQLRepresentationSchema, mitm_data: MITMData):
+    from mitm_tooling.transformation.df import pack_mitm_dataset, unpack_mitm_data
+    h = mitm_data.header
+    mitm = mitm_data.header.mitm
+    mitm_def = get_mitm_def(mitm)
+    mitm_dataset = unpack_mitm_data(mitm_data)
     with engine.connect() as conn:
-        h = mitm_data.header
-        mitm = mitm_data.header.mitm
-
-        for concept, df in mitm_data.as_specialized():
-            concept_table = mk_concept_table_name(mitm, concept)
-            t_concept = meta.tables[concept_table]
-            ref_cols = pick_table_pk(mitm, concept, t_concept.columns)
-            conn.execute(t_concept.insert(), df[[c.name for c in t_concept.columns]].to_dict('records'))
-
-            if has_type_tables(mitm, concept):
-                concept_properties, concept_relations = get_mitm_def(mitm).get(concept)
-                for typ, idx in df.groupby(concept_properties.typing_concept).groups.items():
-                    type_df = df.loc[idx]
-                    t_type = meta.tables[mk_type_table_name(mitm, concept, str(typ))]
-                    conn.execute(t_type.insert(), type_df[[c.name for c in t_type.columns]].to_dict('records'))
+        for concept, typed_dfs in mitm_dataset:
+            concept_properties, concept_relations = mitm_def.get(concept)
+            for type_name, type_df in typed_dfs.items():
+
+                t_concept = sql_rep_schema.concept_tables[mitm_def.get_parent(concept)]
+                ref_cols = pick_table_pk(mitm, concept, t_concept.columns)
+                conn.execute(t_concept.insert(), type_df[[c.name for c in t_concept.columns]].to_dict('records'))
+
+                if has_type_tables(mitm, concept):
+                    #for typ, idx in df.groupby(concept_properties.typing_concept).groups.items():
+                    #    type_df = df.loc[idx]
+                    t_type = sql_rep_schema.type_tables[concept][type_name]
+                    to_dict = type_df[[c.name for c in t_type.columns]].to_dict('records')
+                    conn.execute(t_type.insert(), to_dict)
+
         conn.commit()
 
 
 def insert_mitm_data(engine: sa.Engine, mitm_data: MITMData) -> SQLRepresentationSchema:
     sql_rep_schema = mk_db_schema(mitm_data.header)
     sql_rep_schema.meta.create_all(engine)
-    insert_db_instances(engine, sql_rep_schema.meta, mitm_data)
+    insert_db_instances(engine, sql_rep_schema, mitm_data)
     return sql_rep_schema
 
 
-def mk_sqlite(mitm_data: MITMData, file_path: str | None = ':memory:') -> tuple[sa.Engine, SQLRepresentationSchema]:
-    engine = create_sa_engine(AnyUrl(f'sqlite:///{file_path}'))
+def mk_sqlite(mitm_data: MITMData, file_path: FilePath | None = ':memory:') -> tuple[sa.Engine, SQLRepresentationSchema]:
+    engine = create_sa_engine(AnyUrl(f'sqlite:///{str(file_path)}'))
     sql_rep_schema = insert_mitm_data(engine, mitm_data)
     # print([f'{t.name}: {t.columns} {t.constraints}' for ts in sql_rep_schema.type_tables.values() for t in ts.values()])
     return engine, sql_rep_schema
diff --git a/mitm_tooling/transformation/__init__.py b/mitm_tooling/transformation/__init__.py
index 81683a1..0bf5d16 100644
--- a/mitm_tooling/transformation/__init__.py
+++ b/mitm_tooling/transformation/__init__.py
@@ -1,2 +1,2 @@
-from . import superset
-__all__ = ['superset']
\ No newline at end of file
+from . import df, superset
+__all__ = ['df','superset']
\ No newline at end of file
diff --git a/mitm_tooling/extraction/df/__init__.py b/mitm_tooling/transformation/df/__init__.py
similarity index 100%
rename from mitm_tooling/extraction/df/__init__.py
rename to mitm_tooling/transformation/df/__init__.py
diff --git a/mitm_tooling/extraction/df/intermediate_transformation.py b/mitm_tooling/transformation/df/intermediate_transformation.py
similarity index 96%
rename from mitm_tooling/extraction/df/intermediate_transformation.py
rename to mitm_tooling/transformation/df/intermediate_transformation.py
index 8d67f04..7562a4b 100644
--- a/mitm_tooling/extraction/df/intermediate_transformation.py
+++ b/mitm_tooling/transformation/df/intermediate_transformation.py
@@ -1,6 +1,6 @@
 import itertools
 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Sequence, Iterable
 
 import pandas as pd
 
@@ -12,9 +12,7 @@ from mitm_tooling.representation import mk_concept_file_header
 from mitm_tooling.representation.common import guess_k_of_header_df, mk_header_file_columns
 
 
-def pack_typed_dfs_as_concept_table(mitm: MITM, concept: ConceptName, dfs: Sequence[pd.DataFrame]) -> pd.DataFrame:
-    assert len(dfs) > 0
-
+def pack_typed_dfs_as_concept_table(mitm: MITM, concept: ConceptName, dfs: Iterable[pd.DataFrame]) -> pd.DataFrame:
     normalized_dfs = []
     for df in dfs:
         base_cols, col_dts = mk_concept_file_header(mitm, concept, 0)
@@ -27,6 +25,7 @@ def pack_typed_dfs_as_concept_table(mitm: MITM, concept: ConceptName, dfs: Seque
         df.columns = squashed_form_cols
         normalized_dfs.append((df, k))
 
+    assert len(normalized_dfs) > 0
     max_k = max(normalized_dfs, key=lambda x: x[1])[1]
 
     squashed_form_cols = mk_concept_file_header(mitm, concept, max_k)[0]
diff --git a/mitm_tooling/transformation/superset/dataset_definition.py b/mitm_tooling/transformation/superset/dataset_definition.py
index ce83a6f..92190d4 100644
--- a/mitm_tooling/transformation/superset/dataset_definition.py
+++ b/mitm_tooling/transformation/superset/dataset_definition.py
@@ -71,7 +71,7 @@ class SupersetColumnDef(pydantic.BaseModel):
     expression: str | None = None
     description: str | None = None
     python_date_format: str = None
-    extra: dict[str, Any] = pydantic.Field(default_factory=dict)
+    extra: Annotated[dict[str, Any], pydantic.Field(default_factory=dict)]
 
 
 class SupersetTableDef(SupersetDefFile):
diff --git a/mitm_tooling/transformation/superset/superset_representation.py b/mitm_tooling/transformation/superset/superset_representation.py
index 4806cde..7fe29aa 100644
--- a/mitm_tooling/transformation/superset/superset_representation.py
+++ b/mitm_tooling/transformation/superset/superset_representation.py
@@ -70,13 +70,15 @@ def infer_superset_dataset_def(sqlite_file_path: FilePath) -> SupersetDef:
             cols = []
             for c in table.columns:
                 dt = table.column_properties[c].mitm_data_type
+
                 cols.append(
                     SupersetColumnDef(column_name=c,
                                       is_dttm=dt is MITMDataType.Datetime,
                                       groupby=dt not in {MITMDataType.Json,
                                                          MITMDataType.Numeric,
                                                          MITMDataType.Datetime},
-                                      type=str(dt.sa_sql_type) # .as_generic()) #.dialect_impl(sa.Dialect.get_dialect_cls(sa.URL.create(drivername='sqlite', database=':memory:'))()
+                                      type=(dt.sa_sql_type or MITMDataType.Text.sa_sql_type).compile(
+                                          dialect=engine.dialect)
                                       ))
             datasets.append(
                 SupersetTableDef(table_name=table_name, schema_name=schema_name, uuid=uuid.uuid4(), columns=cols))
diff --git a/test/something.py b/test/something.py
index 5b6b60d..069b720 100644
--- a/test/something.py
+++ b/test/something.py
@@ -1,9 +1,6 @@
 import os
 import unittest
 
-from Tools.scripts.generate_opcode_h import header
-
-from representation.sql_representation import mk_sqlite
 
 
 class MyTestCase(unittest.TestCase):
@@ -30,7 +27,7 @@ class MyTestCase(unittest.TestCase):
         print()
 
     def test_writing_sqlite(self):
-        from mitm_tooling.representation import Header, HeaderEntry, mk_db_schema, MITMData
+        from mitm_tooling.representation import Header, HeaderEntry, mk_db_schema, MITMData, mk_sqlite
         from mitm_tooling.definition import MITM
         from mitm_tooling.data_types import MITMDataType
         h = Header(mitm=MITM.MAED, header_entries=[
@@ -44,6 +41,7 @@ class MyTestCase(unittest.TestCase):
         mk_sqlite(MITMData(header=h), file_path='gendb.sqlite')
 
     def test_with_synthetic(self):
+        from mitm_tooling.representation import mk_sqlite
         from mitm_tooling.io import importing
         from mitm_tooling.definition import MITM
         syn = importing.read_zip('synthetic.maed', MITM.MAED)
diff --git a/test/test_to_df.py b/test/test_to_df.py
index 7e39ba4..9e1b9ce 100644
--- a/test/test_to_df.py
+++ b/test/test_to_df.py
@@ -1,14 +1,14 @@
-import pandas as pd
+import unittest
 
-from mitm_tooling.extraction.df import unpack_mitm_data
+from transformation.df import unpack_mitm_data
 
+class MyTestCase(unittest.TestCase):
+    def test_to_df(self):
+        from mitm_tooling.io import importing
+        from mitm_tooling.definition import MITM
+        syn = importing.read_zip('synthetic.maed', MITM.MAED)
+        mitm_dataset = unpack_mitm_data(syn)
 
-def test_to_df():
-    from mitm_tooling.io import importing
-    from mitm_tooling.definition import MITM
-    syn = importing.read_zip('synthetic.maed', MITM.MAED)
-    mitm_dataset = unpack_mitm_data(syn)
-
-    for c, typed_dfs in mitm_dataset:
-        for type_name, df in typed_dfs.items():
-            print(df.head())
+        for c, typed_dfs in mitm_dataset:
+            for type_name, df in typed_dfs.items():
+                print(df.head())
-- 
GitLab