diff --git a/docs/usage.md b/docs/usage.md index 9d82f5deb32402991569ed9400ef068dfbbaa820..54689f70e9878a8354c28d68911b58cb23b0ad98 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -30,7 +30,16 @@ For the output several options are possible. ### Debugging By using the `--debug` parameter the output of detailed information of the rule evaluation is enabled. Specifics of the input data preprocessing and for each rule evaluation step the output is printed. +Additional debugging arguments can be found in the `--help` output of the `prule` tools. +### Static type analysis +You can run static type analysis with, for example, `mypy`. +Important: install and use the `mypy` package from the same Python environment that you use for `prule`. +Necessary packages: `mypy`, `types-jsonschema`. + +``` +mypy prule +``` ## Job processing daemon diff --git a/prule/__init__.py b/prule/__init__.py index 23b55156828eec3ca0de0753b92831485ad7c563..af345adbc3cfe5c188ce73269a474e3212a1b821 100644 --- a/prule/__init__.py +++ b/prule/__init__.py @@ -1,3 +1,4 @@ +import typing import sys import os.path import re @@ -35,7 +36,6 @@ job_names = ["job","hwthreads"] base_path = os.path.dirname(__file__) unit_path = os.path.join(base_path, "units.txt") -unit_registry = None try: unit_registry = pint.UnitRegistry() unit_registry.load_definitions(unit_path) @@ -47,7 +47,7 @@ except Exception as e: # Create core mask string from arrays and topology # Ff entry in mask array is >0 then the core is masked. -def core_mask_str(mask, cores_per_socket, sockets_per_node, memorydomains_per_node): +def core_mask_str(mask, cores_per_socket: int, sockets_per_node: int, memorydomains_per_node: int) -> str: mstr = "" cores_per_memorydomain = (cores_per_socket * sockets_per_node / memorydomains_per_node) socket_cores = cores_per_socket @@ -68,7 +68,7 @@ def core_mask_str(mask, cores_per_socket, sockets_per_node, memorydomains_per_no # Generates the set of defined local names and used global names -def rule_used_names(rule): +def rule_used_names(rule: dict) -> typing.Tuple[list, dict]: local_names = {} global_names = {} # iterate over all terms @@ -97,10 +97,10 @@ def rule_used_names(rule): else: attr = None local_names[term_var] = True - return (local_names.keys(), global_names) + return (list(local_names.keys()), global_names) # Consistency check for parameter, rule and cluster specification -def configuration_check(parameters, rules, clusters): +def configuration_check(parameters: dict, rules: list, clusters: list) -> None: # check cluster topology specs log.print_color(log.color.magenta, log.debug, "-"*25,"Load cluster specifications:","-"*25) @@ -112,7 +112,7 @@ def configuration_check(parameters, rules, clusters): # namespace conflicts names = {} # dictionary of all names # fill in all builtin names - for name in dir(builtin): + for name in dir(prule.builtin): names[name] = True #print(name) #TODO: fix import of names from builtin, use list of exported names or import all attributes? @@ -156,7 +156,7 @@ def configuration_check(parameters, rules, clusters): #raise Exception("In rule {} the term with index {} uses the unknown literal {}.".format(rule["name"], global_names[g], g)) log.print_color(log.color.yellow, log.debug, "In rule \"{}\" the term with index {} uses the unknown literal \"{}\".".format(rule["name"], global_names[g], g)) -job_meta_fields = [ +job_meta_fields: list = [ ] # Dummy class to be able to define attributes on objects.
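The `list(local_names.keys())` change in `rule_used_names` above is what lets the annotated return type `typing.Tuple[list, dict]` hold: `dict.keys()` returns a view object, not a `list`, so `mypy` would reject the unconverted return. A minimal sketch of the difference, using hypothetical function names that are not part of this patch:

```python
import typing

def names_as_view(names: dict) -> typing.Tuple[list, dict]:
    # dict.keys() yields a KeysView, not a list, so mypy reports an
    # incompatible return value type here (it runs fine at runtime).
    return (names.keys(), {})

def names_as_list(names: dict) -> typing.Tuple[list, dict]:
    # Converting the view to a concrete list satisfies the annotation,
    # mirroring the change made in rule_used_names.
    return (list(names.keys()), {})
```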
@@ -169,17 +169,21 @@ class JobMetadata: # The node id is identified by index of the hostname in the resource list. # For the resource hierarchy, a metric for specific scope level should make the id in the levels above identifiable. # E.g. if one knows the core id, then the memoryDomain and socket ids can be identified. -def get_scopeids(clusters, job_meta, hostname, thread=None, core=None, memoryDomain=None, socket=None, node=None, accelerator=None): - cluster = None +def get_scopeids(clusters: list, job_meta: dict, hostname: str, thread: typing.Optional[int] =None, core: typing.Optional[int] =None, memoryDomain: typing.Optional[int] =None, socket: typing.Optional[int] =None, node: typing.Optional[int] =None, accelerator: typing.Optional[int] =None) -> dict: + cluster: typing.Optional[dict] = None for c in clusters: if c["name"] == job_meta["cluster"]: cluster = c break - subCluster = None + if cluster == None: + raise Exception("Cluster {} not found in cluster input.".format(job_meta["cluster"])) + subCluster: typing.Optional[dict] = None for s in cluster["subClusters"]: if s["name"] == job_meta["subCluster"]: subCluster = s break + if subCluster == None: + raise Exception("Subcluster {} not found in cluster {} input.".format(job_meta["subCluster"], job_meta["cluster"])) topology = subCluster["topology"] scopeIds = {} if thread != None: @@ -233,7 +237,7 @@ def get_scopeids(clusters, job_meta, hostname, thread=None, core=None, memoryDom # example: gpu:a100:3(IDX:0-1,3),fpga:0 slurm_reg_res = re.compile("^([^:]+):((([^,:]*):)?([0-9]+))?(\\(IDX:([^)]+)\\))?,?") -def slurm_parse_resources(s): +def slurm_parse_resources(s: str) -> dict: r = s res = {} while r != "": @@ -270,7 +274,7 @@ def slurm_parse_resources(s): #print(slurm_parse_resources("gpu:a100:1(IDX:1)")) #print(slurm_parse_resources("")) -def slurm_seq_cpuids(s): +def slurm_seq_cpuids(s: str) -> typing.List[int]: cids = [] for term in s.split(","): mix = term.find("-") @@ -284,7 +288,7 @@ def slurm_seq_cpuids(s): cids += list( range( int(term[:mix]), int(term[mix+1:])+1 ) ) return cids -def slurm_seq_nodelist(pre, start, end): +def slurm_seq_nodelist(pre: str, start: str, end: str) -> typing.List[str]: #print(start,end) res = [] slen = len(start) @@ -295,7 +299,7 @@ def slurm_seq_nodelist(pre, start, end): return res -def slurm_expand_nodelist_num(pre, numlist): +def slurm_expand_nodelist_num(pre: str, numlist: str) -> typing.List[str]: #print(numlist) res = [] nlist = numlist @@ -316,7 +320,7 @@ def slurm_expand_nodelist_num(pre, numlist): reg_numlist=re.compile("^([0-9]+)(-[0-9]+)?,?(.*)?$") reg_nodelist=re.compile("^([^ [,]+)(\\[([^]]+)])?,?(.*)$") -def slurm_expand_nodelist(nodes): +def slurm_expand_nodelist(nodes: str) -> list: #print(nodes) res = [] nlist = nodes @@ -324,24 +328,24 @@ def slurm_expand_nodelist(nodes): re_nodelist = reg_nodelist.match(nlist) #print(nlist, re_nodelist) if re_nodelist != None: - re_nodelist = re_nodelist.groups() - pre = re_nodelist[0] - numlist = re_nodelist[2] + re_groups = re_nodelist.groups() + pre = re_groups[0] + numlist = re_groups[2] if numlist == None: res.append(pre) break else: res += slurm_expand_nodelist_num(pre, numlist) - #print(nlist, re_nodelist) - if len(re_nodelist) > 2: - nlist = re_nodelist[3] + #print(nlist, re_groups) + if len(re_groups) > 2: + nlist = re_groups[3] else: break #print(res) return res -def parse_slurminfo(info): - slurm = {} +def parse_slurminfo(info: str) -> dict: + slurm: dict = {} for l in info.split("\n"): if l == "": break # end of slurm info @@ -389,8 
+393,8 @@ def parse_slurminfo(info): slurm["JobName"] = l[l.find("=", l.find(" "))+1: ].strip() return slurm -def parse_slurm_size(size): - num = 0 +def parse_slurm_size(size: str) -> float: + num = 0.0 if size.endswith("K"): num = float(size[:-1]) * 1024 elif size.endswith("M"): @@ -409,12 +413,12 @@ def parse_slurm_size(size): # - cluster specification # - job meta data # - job measurement data -def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): +def rule_prepare_input(parameters: dict, rules: list, clusters: list, job_meta: dict, job_data: dict) -> dict: globals = {} #globals["mean"] = builtin.mean # add definitions from prule builtins - for k in builtin.public: - globals[k] = getattr(builtin, k) + for k in prule.builtin.public: + globals[k] = getattr(prule.builtin, k) for key, value in parameters.items(): if key == "job_requirements": continue @@ -427,7 +431,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): #for jn in job_names: # globals[jn] = True - def quantity_create(value, unit=None): + def quantity_create(value, unit=None) -> pint.Quantity: if unit == None: if type(value) == str: return unit_registry.parse_expression(value) @@ -443,7 +447,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): globals["quantity"] = quantity_create # prepare job metadata - job = JobMetadata() + job: typing.Any = JobMetadata() # copy all attributes from json to job object for attr in job_meta: if type(job_meta[attr]) not in [dict,list]: @@ -488,12 +492,12 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): pass #alloc_arr[0,ix] = unit_registry.Quantity(s, unit_registry.B) alloc_arr[0,ix] = s - alloc_arr = unit_registry.Quantity(alloc_arr, unit_registry.B) + alloc_arr = unit_registry.Quantity(alloc_arr, unit_registry.B) # type: ignore setattr(job, "allocated_memory", alloc_arr) # metadata conversion - setattr(job, "walltime", unit_registry.Quantity(job.walltime,unit_registry.s)) - setattr(job, "duration", unit_registry.Quantity(job.duration,unit_registry.s)) + setattr(job, "walltime", unit_registry.Quantity(job.walltime,unit_registry.s)) # type: ignore + setattr(job, "duration", unit_registry.Quantity(job.duration,unit_registry.s)) # type: ignore # SMT enabled? cluster = None @@ -528,7 +532,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): resources = [] if missing_resources == True: log.print_color(log.color.yellow, log.warn, "Warning: resources, such as used hwthreads and accelerators, not entirely specified in job metadata. 
Reconstructing from job measurements.") - host_resources = {} + host_resources: dict = {} for metric in job_data: metric_data = job_data[metric] for scope in metric_data: @@ -617,7 +621,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): # add units #job.duration = unit_registry.Quantity(job.duration, unit_registry.s) # add scope specific thread numbers - numthreads = JobMetadata() + numthreads: typing.Any = JobMetadata() # threads per node @@ -711,9 +715,9 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): elif scope == "node" and "hostname" in data["series"][0]: series_sorted = [] for h in job_meta["resources"]: - for s in data["series"]: - if s["hostname"] == h["hostname"]: - series_sorted.append(s) + for series in data["series"]: + if series["hostname"] == h["hostname"]: + series_sorted.append(series) break data["series"] = series_sorted else: @@ -801,7 +805,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): else: timedata = data["timeseries"] timeseries = np.array(timedata) - timeseries = unit_registry.Quantity(timeseries, unit_registry.parse_expression("ns")) + timeseries = unit_registry.Quantity(timeseries, unit_registry.parse_expression("ns")) # type: ignore # check sample count min_samples = sys.maxsize max_samples = 0 @@ -834,7 +838,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): b = i-1 if (v == None or np.isnan(v)) and a != None and i == len(dataseries[ix])-1: # NaN sequence at end of series b = i - if b != None: # found NaN sequence + if a != None and b != None: # found NaN sequence #print("fix sequence ", a, b, " of", metadataseries[ix]["metric"]) # debugging none_count += b-a if b != a else 1 if a == 0: # sequence at start, set to 0.0 @@ -866,8 +870,8 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): # set type to float if data_array.dtype == object: data_array = data_array.astype(float) - data_array = unit_registry.Quantity(data_array, unit) - data = builtin.Data(data_array, timeseries, metadataseries) + data_array = unit_registry.Quantity(data_array, unit) # type: ignore + data = prule.builtin.Data(data_array, timeseries, metadataseries) #TODO: sort columns by scope id, e.g. scope=="core", 1. 
column time, rest: core id @@ -892,13 +896,15 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data): return globals -def debug_rule_store_terms(rule_name, term_index, var_name, var_value, debug_settings): +def debug_rule_store_terms(rule_name, term_index, var_name, var_value, debug_settings) -> None: if debug_settings == None or "debug_log_terms_dir" not in debug_settings: return outfile = "{}_{}_{}.csv".format(rule_name, term_index, var_name) debug_write_file(outfile, var_value) -def debug_write_file(outfile, value): +def debug_write_file(outfile, value) -> None: + if prule.debug.debug_settings == None: + return outpath = os.path.join(prule.debug.debug_settings["debug_log_terms_dir"], outfile) with open (outpath, "w", newline='') as csvfile: cwriter = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) @@ -908,7 +914,7 @@ def debug_write_file(outfile, value): else: cwriter.writerow([value]) -def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_index, term, skip_prompt): +def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_index, term, skip_prompt) -> bool: if prule.debug.debug_settings == None: return True if skip_prompt == True: @@ -963,7 +969,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde return_value = eval(expr, locals, globals) print(type(return_value)) print(return_value) - if isinstance(return_value, builtin.Data): + if isinstance(return_value, prule.builtin.Data): print(return_value.array) debug_write_file(fname, return_value) except Exception as e: @@ -979,7 +985,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde return_value = eval(ex_line, locals, globals) print(type(return_value)) print(return_value) - if isinstance(return_value, builtin.Data): + if isinstance(return_value, prule.builtin.Data): print(return_value.array) if var_name != None: locals[var_name] = return_value @@ -993,7 +999,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde # Evaluate rule. # Iterate over terms and evaluate python expression. # Track local and global variables. 
-def rule_evaluate(rule, rinput, job_meta, check_requirements): +def rule_evaluate(rule: dict, rinput: dict, job_meta: dict, check_requirements: bool) -> dict: log.print_color(log.color.magenta+log.color.bold, log.debug, "#"*25,"Evaluate rule:",rule["name"],"#"*25) output = {} output["name"] = rule["name"] @@ -1026,7 +1032,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements): else: required_metrics_min_samples = min(job_meta["metric_min_sample_count"][m], required_metrics_min_samples) # prepare locals - locals_template = {} + locals_template: dict = {} locals_template["required_metrics_min_samples"] = required_metrics_min_samples # requirements @@ -1049,7 +1055,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements): # evaluate terms locals = copy.deepcopy(locals_template) globals = rinput - local_names = {} + local_names: dict = {} skip_prompt = False for tix,term in enumerate(rule["terms"]): skip_prompt = debug_prompt(rule, rinput, job_meta, locals, local_names, globals, tix, term, skip_prompt) @@ -1064,7 +1070,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements): return_value = eval(term_str, locals, globals) log.print_color(log.color.blue, log.debug, "Result for ",term_var,": ",type(return_value)) log.print_color(log.color.blue, log.debug, str(return_value)) - if isinstance(return_value, builtin.Data): + if isinstance(return_value, prule.builtin.Data): log.print_color(log.color.blue, log.debug, return_value.array) debug_rule_store_terms(rule["tag"], tix, term_var, return_value, prule.debug.debug_settings) except Exception as e: @@ -1093,7 +1099,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements): log.print(log.debug, log.color.blue + "output: " + log.color.reset, rule["output"], type(locals[rule["output"]])) log.print(log.debug, locals[rule["output"]]) output_value = None - if isinstance(locals[rule["output"]], builtin.Data): + if isinstance(locals[rule["output"]], prule.builtin.Data): log.print(log.debug, locals[rule["output"]].array) output_value = locals[rule["output"]] try: @@ -1149,7 +1155,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements): if rule["output_scalar"] in locals: log.print(log.debug, log.color.blue + "output_scalar: " + log.color.reset, rule["output_scalar"], type(locals[rule["output_scalar"]])) log.print(log.debug, locals[rule["output_scalar"]]) - if isinstance(locals[rule["output_scalar"]], builtin.Data): + if isinstance(locals[rule["output_scalar"]], prule.builtin.Data): log.print(log.debug, locals[rule["output_scalar"]].array) output["scalar"] = float(locals[rule["output_scalar"]].array) else: diff --git a/prule/__main__.py b/prule/__main__.py index 96145e949306b41c033d4bde939a701429dae49b..692471cbda4099caf077ddb5740154342b8f2f27 100644 --- a/prule/__main__.py +++ b/prule/__main__.py @@ -1,5 +1,6 @@ #!/bin/env python -u +import typing import os import os.path import sys @@ -22,10 +23,11 @@ import prule.db class Job: - def __init__(self, meta_path, data_path): + def __init__(self, meta_path: str, data_path: str): self.meta_path = meta_path self.data_path = data_path - def from_jobdir(jobdir_path): + @staticmethod + def from_jobdir(jobdir_path: str) -> 'Job': j_meta_path = os.path.join(jobdir_path, "meta.json") j_data_path = os.path.join(jobdir_path, "data.json") # test for compressed json, else swap back to normal one @@ -37,21 +39,21 @@ class Job: return j class IDict(dict): - def __init__(self, c): + def __init__(self, c: dict): self.update(c) - def __setitem__(self, k,v): + def 
__setitem__(self, k,v) -> typing.Never: raise Exception("Immutable dictionary") - def __delitem__(self, k): + def __delitem__(self, k) -> typing.Never: raise Exception("Immutable dictionary") -def make_immutable(d): +def make_immutable(d: dict) -> IDict: for k,v in d.items(): if type(v) == dict: d[v] = make_immutable(v) id = IDict(d) return id -def stdin_readjob(): +def stdin_readjob() -> typing.Union[None, Job, str]: try: line = sys.stdin.readline() except: @@ -166,8 +168,7 @@ if __name__ == "__main__": out_group.add_argument('--db-path', nargs=1, help='Path to sqlite output file.') - args = parser.parse_args() - args = vars(args) + args = vars(parser.parse_args()) if "args" in args: print(args) @@ -382,25 +383,30 @@ if __name__ == "__main__": rule_error = False output_all_jobs = [] while len(job_queue)>0 or "job_stdin" in args: - job = None + job: typing.Optional[Job] = None if "job_stdin" in args: - job = stdin_readjob() - if job == "error": + job_read = stdin_readjob() + if job_read == None: + break + elif type(job_read) == str: # "error" job_error = True print("{}") # in case of malformed input continue - if job == None: - break + elif type(job_read) == Job: + job = job_read else: job = job_queue.pop() + if job == None: + break + process_time_start = datetime.datetime.now().timestamp() error = False job_meta = None job_data = None job_id = None - job_output = {"jobId":None, "tags":[], "rules":[], "error":False, "errors":[], "rules_failed":[], "rules_evaluated":[], "rules_not_evaluated":[], "metadata":{}} + job_output: dict = {"jobId":None, "tags":[], "rules":[], "error":False, "errors":[], "rules_failed":[], "rules_evaluated":[], "rules_not_evaluated":[], "metadata":{}} # read meta file if os.path.exists(job.meta_path) == False: @@ -487,7 +493,7 @@ if __name__ == "__main__": matched_tags = [] rule_output = [] - if error == False: + if error == False and rinput != None and job_meta != None: for rule in rules_list: try: out = rule_evaluate(rule, rinput, job_meta, not args["ignore_rule_requirements"]) diff --git a/prule/builtin/__init__.py b/prule/builtin/__init__.py index e95e9850b0b6222d5f6ebc3198e9211e027c1fc8..635407471e48a1906418daf22c28e74d2a910bee 100644 --- a/prule/builtin/__init__.py +++ b/prule/builtin/__init__.py @@ -1,6 +1,7 @@ # builtin functions # applicable to matrices, scalars, boolean values +import typing import builtins import numpy as np import statistics @@ -10,11 +11,7 @@ import pint import prule import types -public = [ -#"mean", -#"max" -] - +public: list = [] # Constants for scopes, necessary? 
#SCOPE_THREAD = "thread" @@ -28,24 +25,29 @@ public = [ _str_scope = ["all", "thread", "core", "numa", "socket", "node", "accelerator", "job"] _int_scope = {"thread":1, "core":2, "numa":3, "socket":4, "node":5, "accelerator":6, "job":7} -def _lt_scope(x, y): - x = _int_scope[x] - y = _int_scope[y] - return x < y - -def _gt_scope(x, y): - x = _int_scope[x] - y = _int_scope[y] - return x > y +def _lt_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool: + if x == None or y == None: + raise KeyError(None) + x_ix = _int_scope[x] + y_ix = _int_scope[y] + return x_ix < y_ix -def _eq_scope(x, y): - x = _int_scope[x] - y = _int_scope[y] - return x == y +def _gt_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool: + if x == None or y == None: + raise KeyError(None) + x_ix = _int_scope[x] + y_ix = _int_scope[y] + return x_ix > y_ix +def _eq_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool: + if x == None or y == None: + raise KeyError(None) + x_ix = _int_scope[x] + y_ix = _int_scope[y] + return x_ix == y_ix class Data: - def __init__(self, array, time, metadata): + def __init__(self, array, time, metadata: list): self.array = array self.time = time self.metadata = metadata @@ -57,8 +59,9 @@ class Data: # Prevent accidental overwriting of object members def setattr(self, a, b): raise Exception("Data class objects are not mutable") - self.__setattr__ = types.MethodType(setattr, self) - def __getitem__(self, indices): + #self.__setattr__ = types.MethodType(setattr, self) + builtins.setattr(self, "__setattr__", types.MethodType(setattr, self)) + def __getitem__(self, indices) -> 'Data': # if `indices` is a tuple, then the tuple elements index the two respective dimensions (time, columns) # else, `indices` indexes the first dimension (time) if type(indices) != tuple: @@ -74,30 +77,28 @@ class Data: else: raise Exception("Only 1 or 2-dimensional indices are supported") return Data(array_new, time_new, metadata_new) - def __setitem__(self, indices): + def __setitem__(self, indices) -> typing.Never: raise Exception("Data class objects are not mutable") - def nbytes(self): + def nbytes(self) -> int: return self.array.nbytes - def set_writeable(self, v): + def set_writeable(self, v: bool) -> None: self.array.flags.writeable = v - def __str__(self): + def __str__(self) -> str: unit = self.array.units if isinstance(self.array, pint.UnitRegistry.Quantity) else "" return "<{} Class:{} Unit:{} DType:{} Shape:{} {}>".format(type(self), type(self.array), unit, self.array.dtype, self.array.shape, str(self.array)) - def __int__(self): + def __int__(self) -> int: return int(self.array) - def __long__(self): - return long(self.array) - def __float__(self): + def __float__(self) -> float: return float(self.array) - def __complex__(self): + def __complex__(self) -> complex: return complex(self.array) - def __oct__(self): + def __oct__(self) -> str: return oct(self.array) - def __hex__(self): + def __hex__(self) -> str: return hex(self.array) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.array) - def scope(self): + def scope(self) -> typing.Optional[str]: scope = None column_num = 1 if self.array.ndim == 1 else self.array.shape[1] for cix in range(0, column_num): @@ -106,7 +107,7 @@ class Data: if scope == None or _lt_scope(s, scope): scope = s return scope - def _scope_columns(self, scope): + def _scope_columns(self, scope: str) -> typing.Tuple[list, list]: pscope = _str_scope[_int_scope[scope] -1] unique=[] cols={} @@ -135,35 +136,35 @@ class Data: return 
(col_list, metadata_new) #TODO: Add deepcopy() calls to Data() arguments - def __add__(self, other): + def __add__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__add__(other) return Data(array_new, self.time, self.metadata) - def __sub__(self, other): + def __sub__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__sub__(other) return Data(array_new, self.time, self.metadata) - def __mul__(self, other): + def __mul__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__mul__(other) return Data(array_new, self.time, self.metadata) - def __lt__(self, other): + def __lt__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__lt__(other) return Data(array_new, self.time, self.metadata) - def __gt__(self, other): + def __gt__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__gt__(other) return Data(array_new, self.time, self.metadata) - def __rsub__(self, other): + def __rsub__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__rsub__(other) return Data(array_new, self.time, self.metadata) - def __truediv__(self, other): + def __truediv__(self, other) -> 'Data': other = other.array if isinstance(other, Data) else other array_new = self.array.__truediv__(other) return Data(array_new, self.time, self.metadata) - def mean(self, scope='time'): + def mean(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.mean(keepdims=True) new_time = self.time.mean(axis=0, keepdims=True) @@ -204,7 +205,7 @@ class Data: # Integrated with trapz: 3.5 # Timediff: 6 # Mean: 3.5 / 6 = 0.583 - def mean_int(self, scope='time'): + def mean_int(self, scope='time') -> 'Data': if scope == 'all': new_data = self.mean_int(scope='time') new_array = new_data.array.mean(keepdims=True) @@ -222,7 +223,7 @@ class Data: # mean of the timestamps new_time = self.time.mean(axis=0, keepdims=True) # integrate using the timestamps - new_array = np.trapz(self.array, x=self.time, axis=0) + new_array = np.trapz(self.array, x=self.time, axis=0) # type: ignore # compute time difference time_min = self.time.min(axis=0, keepdims=True) time_max = self.time.max(axis=0, keepdims=True) @@ -233,7 +234,7 @@ class Data: return Data(new_array, new_time, copy.deepcopy(self.metadata)) else: return self.mean(scope=scope) - def sum(self, scope='time'): + def sum(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.sum(keepdims=True) new_time = self.time.sum(axis=0, keepdims=True) @@ -257,11 +258,11 @@ class Data: new_col.append(col_sum) new_array = np.column_stack(new_col) return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata)) - def any(self, scope='time'): + def any(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.any(keepdims=True) new_time = np.expand_dims(np.array([]), axis=1) - new_metadata = [] + new_metadata: list = [] return Data(new_array, new_time, new_metadata) elif scope == 'time': new_time = self.time.any(axis=0, keepdims=True) @@ -278,11 +279,11 @@ class Data: new_col.append(col_any) new_array = np.column_stack(new_col) return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata)) - def min(self, scope='time'): + def min(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.min(keepdims=True) 
new_time = np.expand_dims(np.array([]), axis=1) - new_metadata = [] + new_metadata: list = [] return Data(new_array, new_time, new_metadata) elif scope == 'time': new_time = self.time.min(axis=0, keepdims=True) @@ -299,11 +300,11 @@ class Data: new_col.append(col_any) new_array = np.column_stack(new_col) return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata)) - def max(self, scope='time'): + def max(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.max(keepdims=True) new_time = np.expand_dims(np.array([]), axis=1) - new_metadata = [] + new_metadata: list = [] return Data(new_array, new_time, new_metadata) elif scope == 'time': new_time = self.time.max(axis=0, keepdims=True) @@ -320,11 +321,11 @@ class Data: new_col.append(col_any) new_array = np.column_stack(new_col) return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata)) - def std(self, scope='time'): + def std(self, scope='time') -> 'Data': if scope == 'all': new_array = self.array.std(keepdims=True) new_time = np.expand_dims(np.array([]), axis=1) - new_metadata = [] + new_metadata: list = [] return Data(new_array, new_time, new_metadata) elif scope == 'time': new_time = self.time.std(axis=0, keepdims=True) @@ -341,7 +342,7 @@ class Data: new_col.append(col_any) new_array = np.column_stack(new_col) return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata)) - def slope(self): + def slope(self) -> pint.Quantity: # use `magnitude` to strip the Pint unit, because `fit` will strip the unit anyway and # produce unnecessary warnings s = [] @@ -362,7 +363,7 @@ class Data: pf_conv = pf.convert() pfs.append(pf_conv) return pfs - def to_array(self): + def to_array(self) -> np.ndarray: time = self.time if type(time) == prule.unit_registry.Quantity: time = time.magnitude @@ -372,7 +373,7 @@ class Data: if len(time) == 0: return array return np.concatenate((time, array), axis=1) - def to(self, *args, **kwargs): + def to(self, *args, **kwargs) -> 'Data': array_new = copy.deepcopy(self.array).to(*args, **kwargs) md_new = copy.deepcopy(self.metadata) time_new = copy.deepcopy(self.time) diff --git a/prule/daemon/__main__.py b/prule/daemon/__main__.py index a8008fcdd9710fd20394ff0a1e9a13d988d6ba57..821f22ca4daf0c7316ab784c35ac596f9de57b25 100644 --- a/prule/daemon/__main__.py +++ b/prule/daemon/__main__.py @@ -72,6 +72,7 @@ Example state file: """ +import typing import os.path import sys import argparse @@ -100,10 +101,35 @@ import prule.debug config_keys = ["CC_URL", "CC_TOKEN", "CC_CHECK_INTERVAL", "STATE_PATH", "PRULE_PARAMETERS_FILE_PATH", "PRULE_CLUSTERS_FILE_PATH", "PRULE_RULES_FILE_PATH", "JOBARCHIVE_PATH", "OUTPUT_PATH", "API_METADATA", "API_TAG", "API_JOBARCHIVE", "CACHE_DB", "DB_PATH", "STORE_OUTPUT"] config_types = [str, str, int, str, str, str, str, str, str, bool, bool, bool, bool, str, bool, int] + +""" +Simply loads and holds a json. +""" +class Config: + def __init__(self, main_tid: int, path: str): + self.path :str = path + self.config :dict = {} + self.main_tid :int = main_tid + self.shutdown :bool = False + def load(self) -> None: + data = None + with open(self.path, "r") as f: + data = json.load(f) + for i,c in enumerate(config_keys): + if c not in data: + raise Exception("Key {} not found in configuration file loaded from {}.".format(c, self.path)) + if type(data[c]) != config_types[i]: + raise Exception("Key {} in configuration file has wrong type {}. 
It should be of type {}.".format(c, type(data[c]), config_types[i])) + config.config = data + def signal_shutdown(self) -> None: + if self.shutdown == False: + self.shutdown = True + signal.pthread_kill(config.main_tid, signal.SIGTERM) # shutdown + """ Create the message that is inserted into the ClusterCockpit UI. """ -def prepare_patho_message(config, result_json): +def prepare_patho_message(config: Config, result_json: dict) -> typing.Optional[str]: if len(result_json["tags"]) == 0: return None message = "<ul>" @@ -123,55 +149,31 @@ def prepare_patho_message(config, result_json): message += config.config["METADATA_MESSAGE"] return message - -""" -Simply loads and holds a json. -""" -class Config: - def __init__(self, main_tid, path): - self.path = path - self.config = {} - self.main_tid = main_tid - self.shutdown = False - def load(self): - data = None - with open(self.path, "r") as f: - data = json.load(f) - for i,c in enumerate(config_keys): - if c not in data: - raise Exception("Key {} not found in configuration file loaded from {}.".format(c, self.path)) - if type(data[c]) != config_types[i]: - raise Exception("Key {} in configuration file has wrong type {}. It should be of type {}.".format(c, type(data[c]), config_types[i])) - config.config = data - def signal_shutdown(self): - if self.shutdown == False: - self.shutdown = True - signal.pthread_kill(config.main_tid, signal.SIGTERM) # shutdown - class JobQueueItem: - def __init__(self, ccjobid, metadata=None, first_check=None): - self.ccjobid = ccjobid - self.metadata = metadata - self.first_check = int(time.time()) if first_check == None else first_check - self.backoff = 0 - def toDict(self): + def __init__(self, ccjobid: int, metadata: typing.Optional[dict] =None, first_check: typing.Optional[int] =None): + self.ccjobid :int = ccjobid + self.metadata :typing.Optional[dict] = metadata + self.first_check :int = int(time.time()) if first_check == None else first_check + self.backoff :int = 0 + def toDict(self) -> dict: return {"ccjobid":self.ccjobid, "metadata":self.metadata, "first_check":self.first_check, "backoff":self.backoff} - def fromDict(d): + @staticmethod + def fromDict(d: dict) -> typing.Optional['JobQueueItem']: if type(d) != dict: return None item = JobQueueItem(0) - item.ccjobid = d["ccjobid"] if "ccjobid" in d else None + if "ccjobid" not in d or type(d["ccjobid"]) != int: + return None + item.ccjobid = d["ccjobid"] item.metadata = d["metadata"] if "metadata" in d else None item.first_check = d["first_check"] if "first_check" in d else 0 item.backoff = d["backoff"] if "backoff" in d else 0 - if item.ccjobid == None: - return None return item - def __eq__(self, other): + def __eq__(self, other) -> bool: if type(other) != JobQueueItem: return False return self.ccjobid == other.ccjobid - def __str__(self): + def __str__(self) -> str: return "<JobQueueItem ccjobid={} metadata={} first_check={} backoff={}>".format(self.ccjobid, self.metadata != None, self.first_check, self.backoff) """ @@ -191,7 +193,7 @@ class JobQueue: self.queue_active = [] self.smallest_starttime = 0 self.last_check = 0 - def loadFromJson(self, path): + def loadFromJson(self, path: str) -> bool: data = None try: with open(path, "r") as f: @@ -210,7 +212,7 @@ class JobQueue: self.queue.append(i) print("Loaded state: ", data) return True - def saveToJson(self, path): + def saveToJson(self, path: str) -> bool: data = {} data["smallest_starttime"] = self.smallest_starttime data["last_check"] = self.last_check @@ -230,13 +232,13 @@ class JobQueue: print(data)
return False return True - def stop(self): + def stop(self) -> None: with self.condition: self.stopQueue = True self.condition.notify_all() - def stopped(self): + def stopped(self) -> bool: return self.stopQueue == True or (len(self.queue) == 0 and self.stop_once_empty == True) - def add(self, v, stop_once_empty=False): + def add(self, v, stop_once_empty: bool =False) -> None: if v == None and stop_once_empty == False: # prevent adding None value return with self.condition: @@ -245,7 +247,7 @@ class JobQueue: if stop_once_empty == True: self.stop_once_empty = True self.condition.notify() - def add_all(self, l, stop_once_empty=False): + def add_all(self, l: typing.Optional[typing.Sequence], stop_once_empty: bool =False) -> None: if (l == None or len(l) == 0) and stop_once_empty == False: return with self.condition: @@ -256,7 +258,7 @@ class JobQueue: if stop_once_empty == True: self.stop_once_empty == True self.condition.notify() - def clear(self): + def clear(self) -> typing.Sequence: with self.condition: values = self.queue self.queue = [] @@ -268,13 +270,13 @@ class JobQueue: if self.stopQueue == True or (len(self.queue) == 0 and self.stop_once_empty == True): return None return self.queue[0] - def empty(self): + def empty(self) -> bool: with self.condition: return len(self.queue) == 0 - def size(self): + def size(self) -> int: with self.condition: return len(self.queue) - def wait_add(self, timeout=None): + def wait_add(self, timeout: typing.Optional[float] =None) -> typing.Optional[bool]: # return True if condition is notified # return False if timeout expires # return None if queue is stopped @@ -284,11 +286,11 @@ class JobQueue: if self.stopQueue == True: return None return self.condition.wait(timeout) - def safe_fun(self, fun): + def safe_fun(self, fun: typing.Callable): with self.condition: queue_copy = copy.copy(self.queue) return fun(queue_copy) - def get_min(self, min_select): + def get_min(self, min_select: typing.Callable): # select the item with the smallest value or next item with value <= 0 # does not block if queue is empty ! # Moves item automatically into queue_active @@ -325,7 +327,7 @@ class JobQueue: self.queue = self.queue[1:] self.queue_active.append(value) return value - def deactivate_item(self, item): + def deactivate_item(self, item) -> typing.Optional[bool]: with self.condition: if self.stopQueue == True: return None @@ -342,16 +344,16 @@ class JobQueue: Checks for newly finished jobs and puts them into the queue. 
""" class CCCheckThread(threading.Thread): - def __init__(self, config, queue, check_once=False): + def __init__(self, config: Config, queue: JobQueue, check_once: bool =False): self.config = config self.queue = queue threading.Thread.__init__(self) self.stopThread = False self.stopCondition = threading.Condition() - self.requestFuture = None - self.executor = None + self.requestFuture :typing.Optional[concurrent.futures.Future] = None + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.check_once = check_once - def request_jobs(self, starttime, state, page, items_per_page): + def request_jobs(self, starttime: typing.Optional[int], state: typing.Optional[str], page: typing.Optional[int], items_per_page: int) -> typing.Optional[dict]: tempfile = None jobs = None try: @@ -373,7 +375,7 @@ class CCCheckThread(threading.Thread): #print("CCheckThread: Request {}".format(url)) req = urllib.request.Request(url, headers=headers, method="GET") - def execRequest(req): + def execRequest(req: urllib.request.Request) -> typing.Optional[dict]: try: with urllib.request.urlopen(req, timeout=10) as response: if response.status == 200: @@ -407,15 +409,15 @@ class CCCheckThread(threading.Thread): #await self.requestTask try: jobs = self.requestFuture.result() - except CancelledError as e: + except concurrent.futures.CancelledError as e: jobs = None with self.stopCondition: self.requestTask = None self.requestFuture = None return jobs - def get_jobs(self, starttime, state): - jobs = [] + def get_jobs(self, starttime: int, state: typing.Optional[str]) -> typing.Optional[list]: + jobs: typing.Optional[list] = [] items_per_page = 10000 ret = 0 page = 0 @@ -437,7 +439,7 @@ class CCCheckThread(threading.Thread): stopTime = time.time() print("Request {} jobs using {} requests in {:.3} seconds".format(len(jobs) if jobs != None else None, page, stopTime-startTime)) return jobs - def check(self): + def check(self) -> None: # current check smallest_starttime = sys.maxsize new_running_ids_set = set() @@ -495,8 +497,7 @@ class CCCheckThread(threading.Thread): if smallest_starttime < sys.maxsize: queue.smallest_starttime = smallest_starttime - def run_main(self): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + def run_main(self) -> None: while self.stopThread == False: # check CC for finished jobs and add them to queue @@ -518,7 +519,7 @@ class CCCheckThread(threading.Thread): # thread done print("CCCheckThread Done") self.executor.shutdown(wait=False) - def run(self): + def run(self) -> None: try: self.run_main() except Exception as e: @@ -526,7 +527,7 @@ class CCCheckThread(threading.Thread): print("CCCheckThread",e) if self.config.shutdown == False: self.config.signal_shutdown() - def stop(self): + def stop(self) -> None: with self.stopCondition: print("Stop CCCheckThread") self.stopThread = True @@ -539,17 +540,17 @@ class CCCheckThread(threading.Thread): Gets a job from the queue and starts the prule program. 
""" class PruleThread(threading.Thread): - def __init__(self, config, stop_on_empty, queue): + def __init__(self, config: Config, stop_on_empty: bool, queue: JobQueue): self.config = config self.queue = queue threading.Thread.__init__(self) self.stopThread = False self.stopCondition = threading.Condition() self.stop_on_empty = stop_on_empty - self.currentProcess = None + self.currentProcess: typing.Optional[subprocess.Popen] = None self.processTerminated = False - self.db_con = None - def request_job_meta(self, id): + self.db_con: typing.Optional[prule.db.ResultsDB] = None + def request_job_meta(self, id: int) -> typing.Tuple[typing.Union[None, str, dict], int]: url = config.config["CC_URL"]+"/api/jobs/{}".format(id) headers = {} headers["Access-Control-Request-Headers"] = "x-auth-token" @@ -584,7 +585,7 @@ class PruleThread(threading.Thread): print("request_job_meta",e) return (None, 999) return (None, 990) - def request_tag_job(self, id, tags): + def request_tag_job(self, id: int, tags: dict) -> typing.Tuple[bool, int]: #[{"type":"foo","name":"bar"},{"type":"asdf","name":"fdsa"}] url = config.config["CC_URL"]+"/api/jobs/tag_job/{}".format(id) headers = {} @@ -602,7 +603,7 @@ class PruleThread(threading.Thread): msg = "" try: msg = e.fp.read().decode('utf-8', 'ignore') - except Exception as e: + except: pass print("Error {} for URL {} Reason {} Msg {}".format(e.code, e.url, e.reason, msg)) if e.code == 401: @@ -617,7 +618,7 @@ class PruleThread(threading.Thread): print(e) return (False, 999) return (False, 990) - def request_jobarchive(self, id): + def request_jobarchive(self, id:int ) -> typing.Tuple[typing.Union[bool, tempfile.TemporaryDirectory, str], int]: url = config.config["CC_URL"]+"/api/jobs/{}?all-metrics=true".format(id) headers = {} headers["Access-Control-Request-Headers"] = "x-auth-token" @@ -653,7 +654,7 @@ class PruleThread(threading.Thread): print("Cleaning up {} failed".format(tdir.name)) return (False, 999) return (False, 990) - def request_metadata_upload(self, id, metadata): + def request_metadata_upload(self, id: int, metadata: dict) -> typing.Tuple[typing.Union[bool, dict], int]: url = config.config["CC_URL"]+"/api/jobs/edit_meta/{}".format(id) headers = {} headers["Access-Control-Request-Headers"] = "x-auth-token" @@ -678,7 +679,7 @@ class PruleThread(threading.Thread): print(e) return (False, 999) return (False, 990) - def prule_start(self): + def prule_start(self) -> None: params = ["python3","-u","-m","prule"] params += ["--parameters-file", config.config["PRULE_PARAMETERS_FILE_PATH"]] params += ["--clusters-file", config.config["PRULE_CLUSTERS_FILE_PATH"]] @@ -694,18 +695,21 @@ class PruleThread(threading.Thread): if self.currentProcess == None: self.currentProcess = subprocess.Popen(params, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, preexec_fn = preexec) print("Prule process {} started".format(self.currentProcess.pid)) - def prule_job(self, job): + def prule_job(self, job: dict) -> typing.Tuple[typing.Union[None, dict, bool], float, int]: if self.currentProcess == None: return (None, 0.0, 10) process_time_start = datetime.datetime.now().timestamp() tries = 0 + result = None while tries <2: try: data = json.dumps(job) + "\n" - self.currentProcess.stdin.write(data.encode("utf-8")) - self.currentProcess.stdin.flush() - line = self.currentProcess.stdout.readline() - result = json.loads(line) + if self.currentProcess.stdin != None: + self.currentProcess.stdin.write(data.encode("utf-8")) + self.currentProcess.stdin.flush() + if 
self.currentProcess.stdout != None: + line = self.currentProcess.stdout.readline() + result = json.loads(line) break except Exception as e: traceback.print_exc() @@ -723,17 +727,18 @@ if type(result) != dict or len(result) == 0: return (False, process_time, 40) return (result, process_time, 0) - def prule_restart(self): + def prule_restart(self) -> None: self.prule_stop() self.prule_start() - def prule_stop(self): + def prule_stop(self) -> typing.Optional[int]: proc = None with self.stopCondition: if self.currentProcess == None: return None proc = self.currentProcess try: - self.currentProcess.stdin.close() + if self.currentProcess.stdin != None: + self.currentProcess.stdin.close() except: pass try: @@ -747,7 +752,7 @@ return returncode # job: "id" - CC database id, not "jobId", which is the SLURM job id - def processJob(self, job): + def processJob(self, job: JobQueueItem) -> typing.Tuple[str, int]: # track process error process_result = "success" @@ -772,7 +777,11 @@ if config.config["CACHE_DB"] == True: try: with prule.debug.Timing("prulethread.db_get_result", "PRINT_TIMING" in config.config): - old_result = self.db_con.db_get_result(job.ccjobid) + if self.db_con != None: + old_result = self.db_con.db_get_result(job.ccjobid) + else: + error_code = 101000000 + return ("failure-shutdown", error_code) except: error_code = 101000000 return ("failure-shutdown", error_code) @@ -790,18 +799,23 @@ error_code += 102000000 if job_res == None: return ("failure-shutdown", error_code) - if job_res == "job-failure": - return ("failure-drop", error_code) - if job_res == "wait": - return ("failure-wait", error_code) - job_meta = job_res['Meta'] + elif type(job_res) == str: + if job_res == "job-failure": + return ("failure-drop", error_code) + if job_res == "wait": + return ("failure-wait", error_code) + elif type(job_res) == dict: + if type(job_res['Meta']) == dict: + job_meta = job_res['Meta'] + if job_meta == None: + return ("failure-drop", 102000000) job_cluster = job_meta["cluster"] job_slurmid = job_meta["jobId"] job_startTime = str(job_meta["startTime"]) if type(job_meta["startTime"]) == int else str(int(datetime.datetime.fromisoformat(job_meta['startTime']).timestamp())) # prepare job path for filesystem access or download jobarchive from API - job_path = None + job_path = "" job_tempdir = None if config.config["API_JOBARCHIVE"] == False: @@ -810,14 +824,16 @@ else: # Load job from jobarchive api and write it to tempdir with prule.debug.Timing("prulethread.request_jobarchive", "PRINT_TIMING" in config.config): - job_tempdir, error_code = self.request_jobarchive(job.ccjobid) + job_tempdir_res, error_code = self.request_jobarchive(job.ccjobid) if error_code > 0: error_code += 103000000 - if job_tempdir == False: + if type(job_tempdir_res) == bool: return ("failure-shutdown", error_code) - if job_tempdir == "wait": + elif type(job_tempdir_res) == str: # "wait" return ("failure-wait", error_code) - job_path = job_tempdir.name + elif type(job_tempdir_res) == tempfile.TemporaryDirectory: + job_path = job_tempdir_res.name + job_tempdir = job_tempdir_res print("Job path:",job_path) @@ -847,6 +863,8 @@ return ("failure-shutdown", error_code) if result_json == False: return ("failure-drop", error_code) + if type(result_json) != dict: + return ("failure-drop", 105000000) print("Process: job {}
jobId {} time {:.6f}".format(job.ccjobid, job_slurmid, process_time)) if self.processTerminated == True: @@ -904,7 +922,7 @@ class PruleThread(threading.Thread): res, error_code = self.request_metadata_upload(job.ccjobid, {"key":"issues","value":patho_message}) if error_code > 0: error_code += 107000000 - if res == False: + if type(res) == bool: print("Job {} process failed to write metadata using API_METADATA".format(job.ccjobid)) process_result = "failure-shutdown" else: @@ -955,13 +973,16 @@ class PruleThread(threading.Thread): job_tempdir.cleanup() except: print("Cleaning up {} failed".format(job_tempdir.name)) - if job_res == None: - return ("failure-shutdown", error_code) + if job_res == None: + return ("failure-shutdown", error_code) + elif type(job_res) == str: if job_res == "job-failure": return ("failure-drop", error_code) if job_res == "wait": return ("failure-wait", error_code) - job_meta = job_res["Meta"] + elif type(job_res) == dict: + if type(job_res['Meta']) == dict: + job_meta = job_res["Meta"] # overwrite metadata in job from prule results if "metadata" in result_json and job_meta != None: @@ -972,7 +993,10 @@ class PruleThread(threading.Thread): try: evaluated = "error" in result_json and result_json["error"] == False with prule.debug.Timing("prulethread.db_insert_result", "PRINT_TIMING" in config.config): - self.db_con.db_insert_result(job.ccjobid, result_json, job_meta, process_time, evaluated) + if self.db_con != None: + self.db_con.db_insert_result(job.ccjobid, result_json, job_meta, process_time, evaluated) + else: + return ("failure-shutdown", 11000000) except Exception as e: traceback.print_exc() print(e) @@ -987,7 +1011,7 @@ class PruleThread(threading.Thread): except: print("Cleaning up {} failed".format(job_tempdir.name)) return (process_result, error_code) - def run_main(self): + def run_main(self) -> None: if self.config.config["CACHE_DB"] == True: self.db_con = prule.db.ResultsDB(self.config.config["DB_PATH"]) @@ -1073,7 +1097,10 @@ class PruleThread(threading.Thread): if config.config["CACHE_DB"] == True: try: with prule.debug.Timing("prulethread.db_insert_failure", "PRINT_TIMING" in config.config): - self.db_con.db_insert_failure(job.ccjobid) + if self.db_con != None: + self.db_con.db_insert_failure(job.ccjobid) + else: + raise Exception("Failed to open sqlite database") except Exception as e: traceback.print_exc() print(e) @@ -1103,10 +1130,10 @@ class PruleThread(threading.Thread): self.prule_stop() - if self.config.config["CACHE_DB"] == True: + if self.config.config["CACHE_DB"] == True and self.db_con != None: self.db_con.close() self.db_con = None - def run(self): + def run(self) -> None: try: self.run_main() except Exception as e: @@ -1114,7 +1141,7 @@ class PruleThread(threading.Thread): print("PruleThread:",e) if self.config.shutdown == False: self.config.signal_shutdown() - def stop(self): + def stop(self) -> None: with self.stopCondition: if self.currentProcess != None: self.currentProcess.terminate() @@ -1151,8 +1178,7 @@ if __name__ == "__main__": parser.add_argument('--no-tmpdir-clean', action='store_true', help='Keep temporary directories') parser.add_argument('--print-timing', action='store_true', help='Print debug timings') - args = parser.parse_args() - args = vars(args) + args = vars(parser.parse_args()) if "args" in args: print(args) diff --git a/prule/debug/__init__.py b/prule/debug/__init__.py index 4e70b6fa84e6de109d09d3d33d77a8ea8cbae225..96104c14b042cdd9ad07975b33d96ea0047ed1fe 100644 --- a/prule/debug/__init__.py +++ 
b/prule/debug/__init__.py @@ -1,7 +1,8 @@ +import typing import sys import datetime -debug_settings = None +debug_settings: typing.Optional[dict] = None # debug_log_terms_dir: Path to directory for storing term values during evaluation class Timing:
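Many of the guards added throughout this patch (for example `if prule.debug.debug_settings == None: return` in `debug_write_file` and the `if self.db_con != None:` checks in `PruleThread`) follow from annotations like the `debug_settings: typing.Optional[dict] = None` above: once a value is declared `Optional`, `mypy` requires it to be narrowed before it is dereferenced or indexed. A minimal sketch of that pattern, using a hypothetical `settings` global that is not part of the patch:

```python
import os.path
import typing

settings: typing.Optional[dict] = None  # stays None until debugging is configured

def write_term_value(outfile: str, value: str) -> None:
    # Without this narrowing check, mypy flags the indexing below,
    # because `settings` may still be None at this point.
    if settings is None:
        return
    outpath = os.path.join(settings["debug_log_terms_dir"], outfile)
    print("would write", value, "to", outpath)
```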