    """
    Data management for FRTRG data based on pytables
    
    Idea:
    * Data and metadata are stored in HDF5 files.
    * A single file can represent any logical unit. An arbitrary number of files can be created.
    * Metadata are partially stored in a database for simple access.
    * A lock file "filename.lock" is created before writing to "filename" to prevent concurrent writes.
    
    Copyright 2022 Valentin Bruch
    License: MIT
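    
    Typical workflow (illustrative sketch; the Kondo class is provided by the
    FRTRG solver module, not by this file):
    
    >>> kondo = Kondo(...)
    >>> kondo.run()
    >>> dm = DataManager()
    >>> dm.save_h5(kondo, include='all')  # write HDF5 file and database entry
    >>> overview = dm.list(vdc=0, vac=5)  # query metadata from the database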
    """
    
    import os
    import tables as tb
    import pathlib
    from time import sleep
    from datetime import datetime, timezone
    import numpy as np
    import pandas as pd
    import sqlalchemy as db
    import random
    import settings
    import warnings
    
    # Data nodes are named by hash strings, which are not valid Python
    # identifiers and would trigger a NaturalNameWarning in pytables.
    warnings.simplefilter("ignore", tb.NaturalNameWarning)
    
    
    # Map an integer in [0, 61] to an alphanumeric character in [0-9A-Za-z].
    _to_char = lambda x: chr(x + 48 if x < 10 else x + 55 if x < 36 else x + 61)
    # Random alphanumeric string of length n, used to generate unique filenames.
    random_string = lambda n: ''.join(_to_char(random.randint(0, 61)) for _ in range(n))
    
    
    def replace_all(string: str, replacements: dict):
        """
        Apply all replacements to string
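    
        Example:
        >>> replace_all('/cluster/data/file.h5', {'/cluster/data': '/local/data'})
        '/local/data/file.h5'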
        """
        for old, new in replacements.items():
            string = string.replace(old, new)
        return string
    
    
    class KondoExport:
        """
        Class for saving Kondo object to file.
        Example usage:
        >>> kondo = Kondo(...)
        >>> kondo.run(...)
        >>> KondoExport(kondo).save_h5("data/frtrg-01.h5")
        """
        METHOD_ENUM = tb.Enum(('unknown', 'mu', 'J', 'J-compact-1', 'J-compact-2', 'mu-reference', 'J-reference', 'mu-extrap-voltage', 'J-extrap-voltage'))
        SOLVER_METHOD_ENUM = tb.Enum(('unknown', 'RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA', 'other'))
    
        def __init__(self, kondo):
            self.kondo = kondo
    
        @property
        def hash(self):
            """
            hash based on Floquet matrices in Kondo object
            """
            try:
                return self._hash
            except AttributeError:
                self._hash = self.kondo.hash()[:40]
                return self._hash
    
        @property
        def metadata(self):
            """
            dictionary of metadata
            """
            if self.kondo.unitary_transformation:
                if self.kondo.compact == 2:
                    method = 'J-compact-2'
                elif self.kondo.compact == 1:
                    method = 'J-compact-1'
                else:
                    method = 'J'
            else:
                method = 'mu'
            solver_flags = 0
            # TODO: use solver_flags
            return dict(
                    hash = self.hash,
                    omega = self.kondo.omega,
                    energy = self.kondo.energy,
                    version_major = self.kondo.version[0],
                    version_minor = self.kondo.version[1],
                    git_commit_count = self.kondo.version[2],
                    git_commit_id = self.kondo.version[3],
                    method = method,
                    timestamp = datetime.now(timezone.utc).timestamp(),
                    solver_method = getattr(self.kondo, 'solveopts', {}).get('method', 'unknown'),
                    solver_tol_abs = getattr(self.kondo, 'solveopts', {}).get('atol', -1),
                    solver_tol_rel = getattr(self.kondo, 'solveopts', {}).get('rtol', -1),
                    d = self.kondo.d,
                    vdc = self.kondo.vdc + self.kondo.omega*self.kondo.resonant_dc_shift,
                    vac = self.kondo.vac,
                    nmax = self.kondo.nmax,
                    padding = self.kondo.padding,
                    voltage_branches = self.kondo.voltage_branches,
                    resonant_dc_shift = self.kondo.resonant_dc_shift,
                    solver_flags = solver_flags,
                    )
    
        @property
        def main_results(self):
            """
            Dictionary of the main observables (DC current, DC differential
            conductance, and absolute value and phase of the AC current);
            values that cannot be computed are left as NaN.
            """
            results = dict(dc_current=np.nan, dc_conductance=np.nan, ac_current_abs=np.nan, ac_current_phase=np.nan)
            nmax = self.kondo.nmax
            try:
                results['dc_current'] = self.kondo.gammaL[nmax, nmax].real
            except Exception:
                pass
            try:
                results['dc_conductance'] = self.kondo.deltaGammaL[nmax, nmax].real
            except Exception:
                pass
            if nmax == 0:
                results['ac_current_abs'] = 0
            else:
                try:
                    results['ac_current_abs'] = np.abs(self.kondo.gammaL[nmax-1, nmax])
                    results['ac_current_phase'] = np.angle(self.kondo.gammaL[nmax-1, nmax])
                except Exception:
                    pass
            return results
    
        def data(self, include='all'):
            """
            Collect Floquet matrices in a dictionary.

            include must be one of 'all' (all Floquet matrices), 'reduced'
            (only the central voltage branch of most matrices), 'observables'
            (only gamma, gammaL and deltaGammaL), or 'minimal' (only the
            central column of gamma, gammaL and deltaGammaL).
            """
            if include == 'all':
                save = dict(
                        gamma = self.kondo.gamma.values,
                        z = self.kondo.z.values,
                        gammaL = self.kondo.gammaL.values,
                        deltaGammaL = self.kondo.deltaGammaL.values,
                        deltaGamma = self.kondo.deltaGamma.values,
                        yL = self.kondo.yL.values,
                        g2 = self.kondo.g2.to_numpy_array(),
                        g3 = self.kondo.g3.to_numpy_array(),
                        current = self.kondo.current.to_numpy_array(),
                        )
            elif include == 'reduced':
                if self.kondo.voltage_branches:
                    vb = self.kondo.voltage_branches
                    save = dict(
                            gamma = self.kondo.gamma[vb],
                            z = self.kondo.z[vb],
                            gammaL = self.kondo.gammaL.values,
                            deltaGammaL = self.kondo.deltaGammaL.values,
                            deltaGamma = self.kondo.deltaGamma[min(vb,1)],
                            g2 = self.kondo.g2.to_numpy_array()[:,:,vb],
                            g3 = self.kondo.g3.to_numpy_array()[:,:,vb],
                            current = self.kondo.current.to_numpy_array(),
                            )
                else:
                    save = dict(
                            gamma = self.kondo.gamma.values,
                            z = self.kondo.z.values,
                            gammaL = self.kondo.gammaL.values,
                            deltaGammaL = self.kondo.deltaGammaL.values,
                            deltaGamma = self.kondo.deltaGamma.values,
                            g2 = self.kondo.g2.to_numpy_array(),
                            g3 = self.kondo.g3.to_numpy_array(),
                            current = self.kondo.current.to_numpy_array(),
                            )
            elif include == 'observables':
                if self.kondo.voltage_branches:
                    vb = self.kondo.voltage_branches
                    save = dict(
                            gamma = self.kondo.gamma[vb],
                            gammaL = self.kondo.gammaL.values,
                            deltaGammaL = self.kondo.deltaGammaL.values,
                            )
                else:
                    save = dict(
                            gamma = self.kondo.gamma.values,
                            gammaL = self.kondo.gammaL.values,
                            deltaGammaL = self.kondo.deltaGammaL.values,
                            )
            elif include == 'minimal':
                nmax = self.kondo.nmax
                if self.kondo.voltage_branches:
                    vb = self.kondo.voltage_branches
                    save = dict(
                            gamma = self.kondo.gamma[vb,:,nmax],
                            gammaL = self.kondo.gammaL[:,nmax],
                            deltaGammaL = self.kondo.deltaGammaL[:,nmax],
                            )
                else:
                    save = dict(
                            gamma = self.kondo.gamma[:,nmax],
                            gammaL = self.kondo.gammaL[:,nmax],
                            deltaGammaL = self.kondo.deltaGammaL[:,nmax],
                            )
            else:
                raise ValueError("Unknown value for include: " + include)
            return save
    
        def save_npz(self, filename, include='all'):
            np.savez(filename, **self.metadata, **self.data(include))
    
        def save_h5(self, filename, include='all', overwrite=False):
            """
            Returns absolute path to filename where data have been saved.
            If overwrite is False and a file would be overwritten, append a random
            string to the end of the filename.
            """
            # A lock file prevents concurrent writes to the same HDF5 file
            # from different processes.
            os.sync()
            while os.path.exists(filename + '.lock'):
                try:
                    settings.logger.warning('File %s is locked, waiting 0.5s'%filename)
                    sleep(0.5)
                except KeyboardInterrupt:
                    answer = input('Ignore lock file? Then type "yes": ')
                    if answer.lower() == "yes":
                        break
                    answer = input('Save with filename extended by random string? (Yn): ')
                    if not answer.lower().startswith("n"):
                        return self.save_h5(filename + random_string(8) + ".h5", include, overwrite)
            pathlib.Path(filename + '.lock').touch()
            try:
                file_exists = os.path.exists(filename)
                h5file = None
                while h5file is None:
                    try:
                        h5file = tb.open_file(filename, "a")
                    except tb.exceptions.HDF5ExtError:
                        settings.logger.warning('Error opening file %s, waiting 0.5s'%filename)
                        sleep(0.5)
                try:
                    if file_exists:
                        try:
                            h5file.is_visible_node('/data/' + self.hash)
                            settings.logger.warning("Hash exists in file %s!"%filename)
                            return self.save_h5(filename + random_string(8) + ".h5", include, overwrite)
                        except tb.exceptions.NoSuchNodeError:
                            pass
                        metadata_table = h5file.get_node("/metadata/mdtable")
                    else:
                        # create new file
                        metadata_parent = h5file.create_group(h5file.root, "metadata", "Metadata")
                        metadata_table = h5file.create_table(metadata_parent,
                                'mdtable',
                                dict(
                                    idnum = tb.Int32Col(),
                                    hash = tb.StringCol(40),
                                    omega = tb.Float64Col(),
                                    energy = tb.ComplexCol(16),
                                    version_major = tb.Int16Col(),
                                    version_minor = tb.Int16Col(),
                                    git_commit_count = tb.Int16Col(),
                                    git_commit_id = tb.Int32Col(),
                                    timestamp = tb.Time64Col(),
                                    method = tb.EnumCol(KondoExport.METHOD_ENUM, 'unknown', 'int8'),
                                    solver_method = tb.EnumCol(KondoExport.SOLVER_METHOD_ENUM, 'unknown', 'int8'),
                                    solver_tol_abs = tb.Float64Col(),
                                    solver_tol_rel = tb.Float64Col(),
                                    d = tb.Float64Col(),
                                    vdc = tb.Float64Col(),
                                    vac = tb.Float64Col(),
                                    nmax = tb.Int16Col(),
                                    padding = tb.Int16Col(),
                                    voltage_branches = tb.Int16Col(),
                                    resonant_dc_shift = tb.Int16Col(),
                                    solver_flags = tb.Int16Col(),
                                )
                            )
                        h5file.create_group(h5file.root, "data", "Floquet matrices")
                        h5file.flush()
    
                    # Save metadata
                    row = metadata_table.row
                    idnum = metadata_table.shape[0]
                    row['idnum'] = idnum
                    metadata = self.metadata
                    row['method'] = KondoExport.METHOD_ENUM[metadata.pop('method')]
                    row['solver_method'] = KondoExport.SOLVER_METHOD_ENUM[metadata.pop('solver_method')]
                    for key, value in metadata.items():
                        row[key] = value
                    row.append()
    
                    # save data
                    datagroup = h5file.create_group("/data/", self.hash)
                    data = self.data(include)
                    for key, value in data.items():
                        h5file.create_array(datagroup, key, value)
                    h5file.flush()
                finally:
                    h5file.close()
            finally:
                os.remove(filename + ".lock")
            return os.path.abspath(filename)
    
    
    class KondoImport:
        """
        Class for importing Kondo objects that were saved with KondoExport.
        Example usage:
        >>> kondo, = KondoImport.read_from_h5("data/frtrg-01.h5", "94f81d2b49df15912798d95cae8e108d75c637c2")
        >>> print(kondo.gammaL[kondo.nmax, kondo.nmax])
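        To iterate over all objects stored in a file (illustrative):
        >>> for kondo in KondoImport.read_all_from_h5("data/frtrg-01.h5"):
        ...     print(kondo.vdc, kondo.vac)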
        """
        def __init__(self, metadata, datanode, h5file, owns_h5file=False):
            self.metadata = metadata
            self._datanode = datanode
            self._h5file = h5file
            self._owns_h5file = owns_h5file
    
        def __del__(self):
            if self._owns_h5file:
                settings.logger.info("closing h5file")
                self._h5file.close()
    
        @classmethod
        def read_from_h5(cls, filename, khash):
            """
            Generator yielding KondoImport objects for all entries in the
            given file that match the hash khash. If exactly one entry is
            found, it owns the open HDF5 file and closes it when it is
            garbage collected; otherwise the file must be closed manually.
            """
            h5file = tb.open_file(filename, "r")
            datanode = h5file.get_node('/data/' + khash)
            metadatatable = h5file.get_node('/metadata/mdtable')
            counter = 0
            for row in metadatatable.where(f"hash == '{khash}'"):
                metadata = {key:row[key] for key in metadatatable.colnames}
                item = cls(metadata, datanode, h5file)
                yield item
                counter += 1
            if counter == 1:
                item._owns_h5file = True
            elif counter == 0:
                h5file.close()
            else:
                settings.logger.warning("h5file will not be closed automatically")
    
        @classmethod
        def read_all_from_h5(cls, filename):
            """
            Generator yielding KondoImport objects for all entries in the
            given file. See read_from_h5 concerning ownership of the open
            HDF5 file.
            """
            h5file = tb.open_file(filename, "r")
            metadatatable = h5file.get_node('/metadata/mdtable')
            counter = 0
            for row in metadatatable:
                metadata = {key:row[key] for key in metadatatable.colnames}
                datanode = h5file.get_node('/data/' + row['hash'].decode())
                item = cls(metadata, datanode, h5file)
                yield item
                counter += 1
            if counter == 1:
                item._owns_h5file = True
            elif counter == 0:
                h5file.close()
            else:
                settings.logger.warning("h5file will not be closed automatically")
    
        def __getitem__(self, name):
            if name in self.metadata:
                return self.metadata[name]
            if name in self._datanode:
                return self._datanode[name].read()
            raise KeyError("Unknown key: %s"%name)
    
        def __getattr__(self, name):
            if name in self.metadata:
                return self.metadata[name]
            if name in self._datanode:
                return self._datanode[name].read()
            raise AttributeError("Unknown attribute name: %s"%name)
    
    
    
    class DataManager:
        '''
        Interface to the database storing metadata for all data points.

        The database contains a single table "datapoints", in which each row
        describes one data point (one Kondo object stored in an HDF5 file).
        '''
        SOLVER_FLAGS = dict(
                contains_flow = 0x001,
                reduced = 0x002,
                deleted = 0x004,
                simplified_initial_conditions = 0x008,
                )
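    
        # The flags above combine into a bitmask that is stored in the
        # solver_flags column of the database. For example, an entry with
        # reduced data that has been marked as deleted has
        # solver_flags == 0x002 | 0x004 == 0x006; test a single flag with
        # bool(row.solver_flags & DataManager.SOLVER_FLAGS['deleted']).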
    
        def __init__(self):
            self.version = settings.VERSION
            self.engine = db.create_engine(settings.DB_CONNECTION_STRING, future=True, echo=False)
    
            self.metadata = db.MetaData()
            try:
                self.table = db.Table('datapoints', self.metadata, autoload_with=self.engine)
            except db.exc.NoSuchTableError:
                with self.engine.begin() as connection:
                    settings.logger.info('Creating database table datapoints')
                    self.table = db.Table(
                            'datapoints',
                            self.metadata,
                            db.Column('id', db.INTEGER(), primary_key=True),
                            db.Column('hash', db.CHAR(40)),
                            db.Column('version_major', db.SMALLINT()),
                            db.Column('version_minor', db.SMALLINT()),
                            db.Column('git_commit_count', db.SMALLINT()),
                            db.Column('git_commit_id', db.INTEGER()),
                            db.Column('timestamp', db.TIMESTAMP()),
                            db.Column('method', db.Enum('unknown', 'mu', 'J', 'J-compact-1', 'J-compact-2', 'mu-reference', 'J-reference', 'mu-extrap-voltage', 'J-extrap-voltage')),
                            db.Column('solver_method', db.Enum('unknown', 'RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA', 'other')),
                            db.Column('solver_tol_abs', db.FLOAT()),
                            db.Column('solver_tol_rel', db.FLOAT()),
                            db.Column('omega', db.FLOAT()),
                            db.Column('d', db.FLOAT()),
                            db.Column('vdc', db.FLOAT()),
                            db.Column('vac', db.FLOAT()),
                            db.Column('energy_re', db.FLOAT()),
                            db.Column('energy_im', db.FLOAT()),
                            db.Column('dc_current', db.FLOAT()),
                            db.Column('ac_current_abs', db.FLOAT()),
                            db.Column('ac_current_phase', db.FLOAT()),
                            db.Column('dc_conductance', db.FLOAT()),
                            db.Column('nmax', db.SMALLINT()),
                            db.Column('padding', db.SMALLINT()),
                            db.Column('voltage_branches', db.SMALLINT()),
                            db.Column('resonant_dc_shift', db.SMALLINT()),
                            db.Column('solver_flags', db.SMALLINT()), # unfortunately SET is not available in SQLite
                            db.Column('dirname', db.String(256)),
                            db.Column('basename', db.String(128)),
                        )
                    self.table.create(bind=connection)
    
        def insert_from_h5file(self, filename):
            raise NotImplementedError()
            basename = os.path.basename(filename)
            dirname = os.path.dirname(filename)
            # TODO
    
        def insert_in_db(self, filename : str, kondo : KondoExport):
            '''
            Save metadata in database for data stored in filename.
            '''
            metadata = kondo.metadata
            metadata.update(kondo.main_results)
            energy = metadata.pop('energy')
            metadata.update(
                        energy_re = energy.real,
                        energy_im = energy.imag,
                        timestamp = datetime.fromtimestamp(metadata.pop("timestamp"), tz=timezone.utc).replace(tzinfo=None).isoformat().replace('T', ' '),
                        dirname = os.path.dirname(filename),
                        basename = os.path.basename(filename),
                    )
            frame = pd.DataFrame(metadata, index=[0])
            frame.to_sql(
                    'datapoints',
                    self.engine,
                    if_exists='append',
                    index=False,
                    )
            try:
                del self.df_table
            except AttributeError:
                pass
    
        def import_from_db(self, db_string, replace_base_path={}):
            """
            e.g. replace_base_path = {'/path/on/cluster/to/data':'/path/to/local/data'}
            """
            raise NotImplementedError()
            # TODO: rewrite
            import_engine = db.create_engine(db_string, future=True, echo=False)
            import_metadata = db.MetaData()
            import_table = db.Table('datapoints', import_metadata, autoload_with=import_engine)
            with import_engine.begin() as connection:
                import_df_table = pd.read_sql_table('datapoints', connection, index_col='id')
            valid_indices = []
            for idx in import_df_table.index:
                import_df_table.dirname[idx] = replace_all(import_df_table.dirname[idx], replace_base_path)
                # TODO: rewrite this
                selection = self.df_table.basename == import_df_table.basename[idx]
                if not any(self.df_table.dirname[selection] == import_df_table.dirname[idx]):
                    valid_indices.append(idx)
            settings.logger.info('Importing %d entries'%len(valid_indices))
            import_df_table.loc[valid_indices].to_sql(
                    'datapoints',
                    self.engine,
                    if_exists='append',
                    index=False,
                    )
    
        def save_h5(self, kondo : KondoExport, filename : str = None, include='all', overwrite=False):
            '''
            Save all data in the given file and store the metadata in the
            database.
            '''
            if filename is None:
                filename = os.path.join(settings.BASEPATH, settings.FILENAME)
            if not isinstance(kondo, KondoExport):
                kondo = KondoExport(kondo)
            filename = kondo.save_h5(filename, include, overwrite)
            self.insert_in_db(filename, kondo)
    
        def cache_df_table(self, min_version=(0,5,-1)):
            '''
            Load the datapoints table from the database into the pandas
            DataFrame self.df_table. Only entries that are not flagged as
            deleted, have at least the given version, and have energy == 0
            are kept.
            '''
            settings.logger.debug('DataManager: cache df_table')
            with self.engine.begin() as connection:
                df_table = pd.read_sql_table('datapoints', connection, index_col='id')
            selection = (df_table.solver_flags & DataManager.SOLVER_FLAGS['deleted']) == 0
            selection &= (df_table.version_major > min_version[0]) | ( (df_table.version_major == min_version[0]) & (df_table.version_minor >= min_version[1]) )
            selection &= df_table.energy_re == 0
            selection &= df_table.energy_im == 0
            if len(min_version) > 2 and min_version[2] > 0:
                selection &= df_table.git_commit_count >= min_version[2]
            self.df_table = df_table[selection]
    
        def __getattr__(self, name):
            if name == 'df_table':
                self.cache_df_table()
                return self.df_table
            raise AttributeError("Unknown attribute name: %s"%name)
    
        def load(self, db_id):
            '''
            db_id is the id in the database (starts counting from 1)
            '''
            raise NotImplementedError
            row = self.df_table.loc[db_id]
            path = os.path.join(row.dirname, row.basename)
            kondo = ...
            kondo.solveopts = dict(
                    method = row.solver_method,
                    rtol = row.solver_tol_rel,
                    atol = row.solver_tol_abs,
                )
            return kondo
    
        def list(self, min_version=(14,0,-1,-1), **parameters):
            '''
            Return a DataFrame with the data points that match min_version and
            the physical parameters given as keyword arguments.
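    
            Example (illustrative):
            >>> dm = DataManager()
            >>> dm.list(omega=10, vdc=0)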
            '''
            selection = (self.df_table.version_major > min_version[0]) | ((self.df_table.version_major == min_version[0]) & (self.df_table.version_minor >= min_version[1]))
            selection &= self.df_table.energy_re == 0
            selection &= self.df_table.energy_im == 0
            if len(min_version) > 2 and min_version[2] > 0:
                selection &= self.df_table.git_commit_count >= min_version[2]
            for key, value in parameters.items():
                if value is None:
                    continue
                try:
                    selection &= self.df_table[key] == value
                except KeyError:
                    settings.logger.warning("Unknown key: %s"%key)
            return self.df_table.loc[selection]
    
        def load_from_table(self, table, load_flow=False, load_old_files=True):
            '''
            Return table extended by a "solver" column containing the loaded
            data. Rows whose data files cannot be found or loaded are dropped.
            '''
            # NOTE: Solver is expected to be provided by the FRTRG solver
            # module; it is not defined or imported in this file.
            solvers = []
            reduced_table = table
            for idx, row in table.iterrows():
                old_file = load_old_files and row.version_major == 0 and row.version_minor < 6
                loader = Solver.load_old_file if old_file else Solver.load
                try:
                    solvers.append(loader(os.path.join(row.dirname, row.basename), load_flow))
                except FileNotFoundError:
                    settings.logger.exception('Could not find file: "%s" / "%s"'%(row.dirname, row.basename))
                    reduced_table = reduced_table.drop(idx)
                except AssertionError:
                    settings.logger.exception('Error while loading file: "%s" / "%s"'%(row.dirname, row.basename))
                    reduced_table = reduced_table.drop(idx)
            return reduced_table.assign(solver = solvers)
    
        def list_kondo(self, **kwargs):
            '''
            Return a DataFrame with an extra column "solver", with the
            filters from kwargs applied (see the documentation of
            DataManager.list for the filters).
            '''
            return self.load_from_table(self.list(**kwargs))
    
        def clean_database(self):
            '''
            Flag all database entries as 'deleted' for which no solver file can be found.
            Delete duplicated entries.
            '''
            raise NotImplementedError
            with self.engine.begin() as connection:
                full_df_table = pd.read_sql_table('datapoints', connection, index_col='id')
            remove_indices = []
            for idx, row in full_df_table.iterrows():
                path = os.path.join(row.dirname, row.basename)
                if not os.path.exists(path):
                    path += '.npz'
                if not os.path.exists(path):
                    settings.logger.warning('File does not exist: %s (id %s)', path, idx)
                    row.solver_flags |= DataManager.SOLVER_FLAGS['deleted']
                    stmt = db.update(self.table).where(self.table.c.id == idx).values(solver_flags = row.solver_flags)
                    with self.engine.begin() as connection:
                        connection.execute(stmt)
    
    def list_data(**kwargs):
        '''
        Print and return an overview of the stored data points (see
        DataManager.list for the filter arguments).
        '''
        table = DataManager().list(**kwargs)
        print(table[['method', 'vdc', 'vac', 'omega', 'nmax', 'voltage_branches', 'padding', 'dc_current', 'dc_conductance', 'ac_current_abs']])
        return table
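    
    
    if __name__ == '__main__':
        # Minimal usage sketch (assumes that settings.DB_CONNECTION_STRING
        # points to a valid database): print an overview of all stored data.
        list_data()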