Skip to content
Snippets Groups Projects
Select Git revision
  • 2cba78a344c061320070b765df75282db8ec17ca
  • master default protected
2 results

stop.py

Blame
  • GeneratorManager.py 8.08 KiB
    #! /usr/bin/python3
    from __future__ import annotations
    
    import inspect
    import os
    import importlib
    import importlib.util
    import subprocess
    import typing
    from pathlib import Path
    
    # for printing a nice progress bar
    import tqdm
    
    from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
    from scripts.Infrastructure.Template import TemplateManager
    from scripts.Infrastructure.Variables import featurelist
    
    # Number of digits used to zero-pad the per-name running index in generated
    # filenames (used by GeneratorManager.get_filename, e.g. 3 -> "<name>-001.c").
    digits_to_use: int = 3
    
    
    def import_module(root, file):
        """
        Private function: load the Python source file `file` located in the
        directory `root`, execute it, and return the resulting module object.

        The module is named after `file` with its last three characters (the
        ".py" suffix) removed. It is not registered in `sys.modules`.
        """
        module_name = file[:-3]  # strip the ".py" suffix
        location = os.path.join(root, file)
        module_spec = importlib.util.spec_from_file_location(module_name, location)
        loaded = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded
    
    
    class GeneratorManager:
        """
        Discovers ErrorGenerator subclasses in a directory tree and uses them to
        write the generated test cases to disk.
        """

        def __init__(self, path, print_discovery=True, skip_invalid=False):
            """
            Instantiates a GeneratorManager and discovers ErrorGenerator classes from Python files in the specified path.

            Parameters:
                - `path` (str): The path to search for Python files containing ErrorGenerator classes.
                - `print_discovery` (bool, optional): Whether to print the discovered generators. Defaults to True.
                - `skip_invalid` (bool, optional): Whether to skip invalid generators. Defaults to False.

            Returns:
                None

            Raises:
                AssertionError: If an ErrorGenerator class is found with unknown features and `skip_invalid` is False.

            Note:
                - Discovers the generators in Python files with the '.py' extension in the specified path and its subdirectories.
                - Every discovered '.py' file is imported, i.e. its top-level code is executed; only point `path` at trusted code.
            """
            # instantiated generators, in discovery order
            self.generators = []
            # maps a test case name to the last running number handed out by get_filename()
            self.case_names = {}
            # discover all Error Generators
            self.discover_generators(path, print_discovery, skip_invalid)

        def get_filename(self, case_name, suffix=".c"):
            """
            Private Function: Helps to generate filenames for the generated testcases.

            Returns `case_name` + "-" + a per-name running number (starting at 1,
            zero-padded to `digits_to_use` digits) + `suffix`, e.g. "foo-001.c".
            Every call advances the counter for `case_name`.
            """
            num = self.case_names.get(case_name, 0) + 1
            self.case_names[case_name] = num
            return case_name + "-" + str(num).zfill(digits_to_use) + suffix

        def generate(self, outpath: str | Path | os.PathLike[str], filterlist_: typing.Sequence[str] | None = None,
                     print_progress_bar: bool = True, overwrite: bool = True, generate_full_set: bool = False,
                     try_compile: bool = False, max_mpi_version: str = "4.0", use_clang_format: bool = True):
            """
            Generates test cases based on the specified parameters.

            Parameters:
                - `outpath` (str): The path where the generated test cases will be saved.
                - `filterlist_` (list, optional): A list of features to filter the generators. Defaults to None (no filters).
                - `print_progress_bar` (bool, optional): Whether to print a progress bar. Defaults to True.
                - `overwrite` (bool, optional): Whether to overwrite existing files. Defaults to True.
                - `generate_full_set` (bool, optional): Whether to generate the full (extended) set of errors. Defaults to False.
                - `try_compile` (bool, optional): Whether to try compiling the generated test cases. Defaults to False.
                - `max_mpi_version` (str, optional): The maximum MPI version allowed for generated test cases,
                  compared numerically as a float. Defaults to "4.0".
                - `use_clang_format` (bool, optional): Whether to format the generated test cases. Defaults to True.

            Returns:
                None

            Raises:
                AssertionError: If the environment variable 'MPICC' is not set when `try_compile` is True,
                    or if a target file already exists and `overwrite` is False.
                CalledProcessError: If compilation fails during the try_compile process or clang-format fails.

            Note:
                - The progress bar is printed using the tqdm module.
                - If `try_compile` is True, it uses the compiler specified via 'MPICC' environment variable to attempt compilation.
                - The features of a test, and whether it belongs to the base or the full test set, are defined by the respective generators.
                - Generators can raise CorrectTestcase, therefore the number of discovered testcases may not match the number of generated cases.
            """
            filterlist = featurelist if filterlist_ is None else filterlist_

            mpicc = None
            if try_compile:
                mpicc = os.environ.get('MPICC')
                # explicit raise: `assert x and "msg"` never showed the message and
                # is stripped entirely under `python -O`
                if not mpicc:
                    raise AssertionError("Environment var MPICC not set")

            # use generator if at least one feature of the generator matches the filterlist
            generators_to_use = [g for g in self.generators if any(elem in filterlist for elem in g.get_feature())]

            print("Generate Testcases using %d generators" % len(generators_to_use))

            max_version = float(max_mpi_version)  # loop-invariant
            # prints a nice progress bar
            progress_bar = tqdm.tqdm(total=len(generators_to_use)) if print_progress_bar else None
            cases_generated = 0

            try:
                for generator in generators_to_use:
                    # use first feature as category if generator has multiple
                    category_path = os.path.join(outpath, generator.get_feature()[0])
                    os.makedirs(category_path, exist_ok=True)

                    for result_error in generator.generate(generate_full_set):
                        assert isinstance(result_error, TemplateManager)

                        # skip cases that require a newer MPI version than allowed
                        if float(result_error.get_version()) > max_version:
                            continue

                        fname = self.get_filename(result_error.get_short_descr())
                        full_name = os.path.join(category_path, fname)

                        if not overwrite and os.path.isfile(full_name):
                            raise AssertionError("File Already exists: %s" % full_name)

                        result_str = str(result_error)
                        if use_clang_format:
                            result_str = subprocess.check_output(["clang-format"], text=True,
                                                                 input=result_str)

                        with open(full_name, "w", encoding="utf-8") as text_file:
                            text_file.write(result_str)
                        cases_generated += 1

                        if try_compile:
                            # raises CalledProcessError if code does not compile
                            # NOTE(review): no -o/-c flag is passed, so the compiler's default
                            # output file is produced in the current working directory
                            subprocess.check_call([mpicc, full_name])

                    if progress_bar is not None:
                        progress_bar.update(1)
            finally:
                # close the bar even if clang-format / compilation raised
                if progress_bar is not None:
                    progress_bar.close()

            print("Finished. Generated %i cases" % cases_generated)

        def discover_generators(self, path, print_discovery=True, skip_invalid=False):
            """
            Private Function. see Documentation for __init__()

            Walks `path` recursively (in sorted order), imports every '.py' file and
            instantiates each concrete class in it that derives from ErrorGenerator
            (excluding ErrorGenerator itself). A generator is appended to
            `self.generators` only if all of its features are in `featurelist`.
            """
            if print_discovery:
                print("Discover Generators:")
            for root, dirs, files in os.walk(path):
                # sort in place so os.walk descends deterministically; the generator order
                # determines the numbering handed out by get_filename()
                dirs.sort()
                for file in sorted(files):
                    if not file.endswith('.py'):
                        continue
                    module = import_module(root, file)
                    for name, obj in inspect.getmembers(module):
                        # a concrete class derived from ErrorGenerator (and not the interface class itself);
                        # abstract subclasses would fail to instantiate
                        if not (inspect.isclass(obj) and issubclass(obj, ErrorGenerator)
                                and obj is not ErrorGenerator and not inspect.isabstract(obj)):
                            continue
                        if print_discovery:
                            print("Found Generator %s" % name)
                        # instantiate the object
                        generator = obj()
                        unknown = [feature for feature in generator.get_feature() if feature not in featurelist]
                        if print_discovery:
                            for feature in unknown:
                                print("Generator has unknown feature: %s" % feature)
                        if not unknown:
                            self.generators.append(generator)
                        elif not skip_invalid:
                            # explicit raise: `assert False and "msg"` dropped the message and
                            # silently skipped the generator under `python -O`
                            raise AssertionError("Invalid Generator %s: unknown feature(s) %r" % (name, unknown))
                        # otherwise: just skip the invalid generator