Commit b258aad2 authored by Alex Wiens

Prule: Add basic type annotations and slight code reworks

parent 94896931
......@@ -30,7 +30,16 @@ For the output several options are possible.
### Debugging
The `--debug` parameter enables the output of detailed information about the rule evaluation.
Details of the input data preprocessing and the output of each rule evaluation step are printed.
Additional debugging arguments can be found in the `--help` output of the `prule` tools.
### Static type analysis
You can run a static type analysis with e.g. `mypy`.
Important: install and use the `mypy` package from the same Python environment that you use for `prule`.
Necessary packages: `mypy`, `types-jsonschema`.
```
mypy prule
```
## Job processing daemon
......
import typing
import sys
import os.path
import re
......@@ -35,7 +36,6 @@ job_names = ["job","hwthreads"]
base_path = os.path.dirname(__file__)
unit_path = os.path.join(base_path, "units.txt")
unit_registry = None
try:
unit_registry = pint.UnitRegistry()
unit_registry.load_definitions(unit_path)
......@@ -47,7 +47,7 @@ except Exception as e:
# Create core mask string from arrays and topology
# If an entry in the mask array is >0, then the core is masked.
def core_mask_str(mask, cores_per_socket, sockets_per_node, memorydomains_per_node):
def core_mask_str(mask, cores_per_socket: int, sockets_per_node: int, memorydomains_per_node: int) -> str:
mstr = ""
cores_per_memorydomain = (cores_per_socket * sockets_per_node / memorydomains_per_node)
socket_cores = cores_per_socket
......@@ -68,7 +68,7 @@ def core_mask_str(mask, cores_per_socket, sockets_per_node, memorydomains_per_no
# Generates the set of defined local names and used global names
def rule_used_names(rule):
def rule_used_names(rule: dict) -> typing.Tuple[list, dict]:
local_names = {}
global_names = {}
# iterate over all terms
......@@ -97,10 +97,10 @@ def rule_used_names(rule):
else:
attr = None
local_names[term_var] = True
return (local_names.keys(), global_names)
return (list(local_names.keys()), global_names)
# Consistency check for parameter, rule and cluster specification
def configuration_check(parameters, rules, clusters):
def configuration_check(parameters: dict, rules: list, clusters: list) -> None:
# check cluster topology specs
log.print_color(log.color.magenta, log.debug, "-"*25,"Load cluster specifications:","-"*25)
......@@ -112,7 +112,7 @@ def configuration_check(parameters, rules, clusters):
# namespace conflicts
names = {} # dictionary of all names
# fill in all builtin names
for name in dir(builtin):
for name in dir(prule.builtin):
names[name] = True
#print(name)
#TODO: fix import of names from builtin, use list of exported names or import all attributes?
......@@ -156,7 +156,7 @@ def configuration_check(parameters, rules, clusters):
#raise Exception("In rule {} the term with index {} uses the unknown literal {}.".format(rule["name"], global_names[g], g))
log.print_color(log.color.yellow, log.debug, "In rule \"{}\" the term with index {} uses the unknown literal \"{}\".".format(rule["name"], global_names[g], g))
job_meta_fields = [
job_meta_fields: list = [
]
# Dummy class to be able to define attributes on objects.
......@@ -169,17 +169,21 @@ class JobMetadata:
# The node id is identified by index of the hostname in the resource list.
# For the resource hierarchy, a metric at a specific scope level should make the ids at the levels above identifiable.
# E.g. if one knows the core id, then the memoryDomain and socket ids can be identified.
def get_scopeids(clusters, job_meta, hostname, thread=None, core=None, memoryDomain=None, socket=None, node=None, accelerator=None):
cluster = None
def get_scopeids(clusters: list, job_meta: dict, hostname: str, thread: typing.Optional[int] =None, core: typing.Optional[int] =None, memoryDomain: typing.Optional[int] =None, socket: typing.Optional[int] =None, node: typing.Optional[int] =None, accelerator: typing.Optional[int] =None) -> dict:
cluster: typing.Optional[dict] = None
for c in clusters:
if c["name"] == job_meta["cluster"]:
cluster = c
break
subCluster = None
if cluster == None:
raise Exception("Cluster {} not found in cluster input.".format(job_meta["cluster"]))
subCluster: typing.Optional[dict] = None
for s in cluster["subClusters"]:
if s["name"] == job_meta["subCluster"]:
subCluster = s
break
if subCluster == None:
raise Exception("Subcluster {} not found in cluster {} input.".format(job_meta["subCluster"], job_meta["cluster"]))
topology = subCluster["topology"]
scopeIds = {}
if thread != None:
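The comment above explains that an id at a lower scope level determines the ids at the levels above. As an editorial sketch, not part of this commit, and assuming a ClusterCockpit-style topology in which `topology["socket"]` and `topology["memoryDomain"]` are lists of hwthread-id lists, the lookup idea is:

```
# Editorial sketch; the topology layout is an assumption, not taken from this repository.
def find_enclosing_id(groups, hwthread_id):
    # Return the index of the group (socket, memory domain, ...) that contains the hwthread.
    for ix, members in enumerate(groups):
        if hwthread_id in members:
            return ix
    return None

topology_example = {"socket": [[0, 1, 2, 3], [4, 5, 6, 7]],
                    "memoryDomain": [[0, 1], [2, 3], [4, 5], [6, 7]]}
print(find_enclosing_id(topology_example["socket"], 5))        # -> 1
print(find_enclosing_id(topology_example["memoryDomain"], 5))  # -> 2
```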
......@@ -233,7 +237,7 @@ def get_scopeids(clusters, job_meta, hostname, thread=None, core=None, memoryDom
# example: gpu:a100:3(IDX:0-1,3),fpga:0
slurm_reg_res = re.compile("^([^:]+):((([^,:]*):)?([0-9]+))?(\\(IDX:([^)]+)\\))?,?")
def slurm_parse_resources(s):
def slurm_parse_resources(s: str) -> dict:
r = s
res = {}
while r != "":
......@@ -270,7 +274,7 @@ def slurm_parse_resources(s):
#print(slurm_parse_resources("gpu:a100:1(IDX:1)"))
#print(slurm_parse_resources(""))
def slurm_seq_cpuids(s):
def slurm_seq_cpuids(s: str) -> typing.List[int]:
cids = []
for term in s.split(","):
mix = term.find("-")
......@@ -284,7 +288,7 @@ def slurm_seq_cpuids(s):
cids += list( range( int(term[:mix]), int(term[mix+1:])+1 ) )
return cids
def slurm_seq_nodelist(pre, start, end):
def slurm_seq_nodelist(pre: str, start: str, end: str) -> typing.List[str]:
#print(start,end)
res = []
slen = len(start)
......@@ -295,7 +299,7 @@ def slurm_seq_nodelist(pre, start, end):
return res
def slurm_expand_nodelist_num(pre, numlist):
def slurm_expand_nodelist_num(pre: str, numlist: str) -> typing.List[str]:
#print(numlist)
res = []
nlist = numlist
......@@ -316,7 +320,7 @@ def slurm_expand_nodelist_num(pre, numlist):
reg_numlist=re.compile("^([0-9]+)(-[0-9]+)?,?(.*)?$")
reg_nodelist=re.compile("^([^ [,]+)(\\[([^]]+)])?,?(.*)$")
def slurm_expand_nodelist(nodes):
def slurm_expand_nodelist(nodes: str) -> list:
#print(nodes)
res = []
nlist = nodes
......@@ -324,24 +328,24 @@ def slurm_expand_nodelist(nodes):
re_nodelist = reg_nodelist.match(nlist)
#print(nlist, re_nodelist)
if re_nodelist != None:
re_nodelist = re_nodelist.groups()
pre = re_nodelist[0]
numlist = re_nodelist[2]
re_groups = re_nodelist.groups()
pre = re_groups[0]
numlist = re_groups[2]
if numlist == None:
res.append(pre)
break
else:
res += slurm_expand_nodelist_num(pre, numlist)
#print(nlist, re_nodelist)
if len(re_nodelist) > 2:
nlist = re_nodelist[3]
#print(nlist, re_groups)
if len(re_groups) > 2:
nlist = re_groups[3]
else:
break
#print(res)
return res
def parse_slurminfo(info):
slurm = {}
def parse_slurminfo(info: str) -> dict:
slurm: dict = {}
for l in info.split("\n"):
if l == "":
break # end of slurm info
......@@ -389,8 +393,8 @@ def parse_slurminfo(info):
slurm["JobName"] = l[l.find("=", l.find(" "))+1: ].strip()
return slurm
def parse_slurm_size(size):
num = 0
def parse_slurm_size(size: str) -> float:
num = 0.0
if size.endswith("K"):
num = float(size[:-1]) * 1024
elif size.endswith("M"):
......@@ -409,12 +413,12 @@ def parse_slurm_size(size):
# - cluster specification
# - job meta data
# - job measurement data
def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
def rule_prepare_input(parameters: dict, rules: list, clusters: list, job_meta: dict, job_data: dict) -> dict:
globals = {}
#globals["mean"] = builtin.mean
# add definitions from prule builtins
for k in builtin.public:
globals[k] = getattr(builtin, k)
for k in prule.builtin.public:
globals[k] = getattr(prule.builtin, k)
for key, value in parameters.items():
if key == "job_requirements":
continue
......@@ -427,7 +431,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
#for jn in job_names:
# globals[jn] = True
def quantity_create(value, unit=None):
def quantity_create(value, unit=None) -> pint.Quantity:
if unit == None:
if type(value) == str:
return unit_registry.parse_expression(value)
......@@ -443,7 +447,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
globals["quantity"] = quantity_create
# prepare job metadata
job = JobMetadata()
job: typing.Any = JobMetadata()
# copy all attributes from json to job object
for attr in job_meta:
if type(job_meta[attr]) not in [dict,list]:
......@@ -488,12 +492,12 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
pass
#alloc_arr[0,ix] = unit_registry.Quantity(s, unit_registry.B)
alloc_arr[0,ix] = s
alloc_arr = unit_registry.Quantity(alloc_arr, unit_registry.B)
alloc_arr = unit_registry.Quantity(alloc_arr, unit_registry.B) # type: ignore
setattr(job, "allocated_memory", alloc_arr)
# metadata conversion
setattr(job, "walltime", unit_registry.Quantity(job.walltime,unit_registry.s))
setattr(job, "duration", unit_registry.Quantity(job.duration,unit_registry.s))
setattr(job, "walltime", unit_registry.Quantity(job.walltime,unit_registry.s)) # type: ignore
setattr(job, "duration", unit_registry.Quantity(job.duration,unit_registry.s)) # type: ignore
# SMT enabled?
cluster = None
......@@ -528,7 +532,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
resources = []
if missing_resources == True:
log.print_color(log.color.yellow, log.warn, "Warning: resources, such as used hwthreads and accelerators, are not entirely specified in the job metadata. Reconstructing from job measurements.")
host_resources = {}
host_resources: dict = {}
for metric in job_data:
metric_data = job_data[metric]
for scope in metric_data:
......@@ -617,7 +621,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
# add units
#job.duration = unit_registry.Quantity(job.duration, unit_registry.s)
# add scope specific thread numbers
numthreads = JobMetadata()
numthreads: typing.Any = JobMetadata()
# threads per node
......@@ -711,9 +715,9 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
elif scope == "node" and "hostname" in data["series"][0]:
series_sorted = []
for h in job_meta["resources"]:
for s in data["series"]:
if s["hostname"] == h["hostname"]:
series_sorted.append(s)
for series in data["series"]:
if series["hostname"] == h["hostname"]:
series_sorted.append(series)
break
data["series"] = series_sorted
else:
......@@ -801,7 +805,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
else:
timedata = data["timeseries"]
timeseries = np.array(timedata)
timeseries = unit_registry.Quantity(timeseries, unit_registry.parse_expression("ns"))
timeseries = unit_registry.Quantity(timeseries, unit_registry.parse_expression("ns")) # type: ignore
# check sample count
min_samples = sys.maxsize
max_samples = 0
......@@ -834,7 +838,7 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
b = i-1
if (v == None or np.isnan(v)) and a != None and i == len(dataseries[ix])-1: # NaN sequence at end of series
b = i
if b != None: # found NaN sequence
if a != None and b != None: # found NaN sequence
#print("fix sequence ", a, b, " of", metadataseries[ix]["metric"]) # debugging
none_count += b-a if b != a else 1
if a == 0: # sequence at start, set to 0.0
......@@ -866,8 +870,8 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
# set type to float
if data_array.dtype == object:
data_array = data_array.astype(float)
data_array = unit_registry.Quantity(data_array, unit)
data = builtin.Data(data_array, timeseries, metadataseries)
data_array = unit_registry.Quantity(data_array, unit) # type: ignore
data = prule.builtin.Data(data_array, timeseries, metadataseries)
#TODO: sort columns by scope id, e.g. scope=="core", 1. column time, rest: core id
......@@ -892,13 +896,15 @@ def rule_prepare_input(parameters, rules, clusters, job_meta, job_data):
return globals
def debug_rule_store_terms(rule_name, term_index, var_name, var_value, debug_settings):
def debug_rule_store_terms(rule_name, term_index, var_name, var_value, debug_settings) -> None:
if debug_settings == None or "debug_log_terms_dir" not in debug_settings:
return
outfile = "{}_{}_{}.csv".format(rule_name, term_index, var_name)
debug_write_file(outfile, var_value)
def debug_write_file(outfile, value):
def debug_write_file(outfile, value) -> None:
if prule.debug.debug_settings == None:
return
outpath = os.path.join(prule.debug.debug_settings["debug_log_terms_dir"], outfile)
with open (outpath, "w", newline='') as csvfile:
cwriter = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
......@@ -908,7 +914,7 @@ def debug_write_file(outfile, value):
else:
cwriter.writerow([value])
def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_index, term, skip_prompt):
def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_index, term, skip_prompt) -> bool:
if prule.debug.debug_settings == None:
return True
if skip_prompt == True:
......@@ -963,7 +969,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde
return_value = eval(expr, locals, globals)
print(type(return_value))
print(return_value)
if isinstance(return_value, builtin.Data):
if isinstance(return_value, prule.builtin.Data):
print(return_value.array)
debug_write_file(fname, return_value)
except Exception as e:
......@@ -979,7 +985,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde
return_value = eval(ex_line, locals, globals)
print(type(return_value))
print(return_value)
if isinstance(return_value, builtin.Data):
if isinstance(return_value, prule.builtin.Data):
print(return_value.array)
if var_name != None:
locals[var_name] = return_value
......@@ -993,7 +999,7 @@ def debug_prompt(rule, rinput, job_meta, locals, local_names, globals, term_inde
# Evaluate rule.
# Iterate over terms and evaluate python expression.
# Track local and global variables.
def rule_evaluate(rule, rinput, job_meta, check_requirements):
def rule_evaluate(rule: dict, rinput: dict, job_meta: dict, check_requirements: bool) -> dict:
log.print_color(log.color.magenta+log.color.bold, log.debug, "#"*25,"Evaluate rule:",rule["name"],"#"*25)
output = {}
output["name"] = rule["name"]
......@@ -1026,7 +1032,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements):
else:
required_metrics_min_samples = min(job_meta["metric_min_sample_count"][m], required_metrics_min_samples)
# prepare locals
locals_template = {}
locals_template: dict = {}
locals_template["required_metrics_min_samples"] = required_metrics_min_samples
# requirements
......@@ -1049,7 +1055,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements):
# evaluate terms
locals = copy.deepcopy(locals_template)
globals = rinput
local_names = {}
local_names: dict = {}
skip_prompt = False
for tix,term in enumerate(rule["terms"]):
skip_prompt = debug_prompt(rule, rinput, job_meta, locals, local_names, globals, tix, term, skip_prompt)
......@@ -1064,7 +1070,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements):
return_value = eval(term_str, locals, globals)
log.print_color(log.color.blue, log.debug, "Result for ",term_var,": ",type(return_value))
log.print_color(log.color.blue, log.debug, str(return_value))
if isinstance(return_value, builtin.Data):
if isinstance(return_value, prule.builtin.Data):
log.print_color(log.color.blue, log.debug, return_value.array)
debug_rule_store_terms(rule["tag"], tix, term_var, return_value, prule.debug.debug_settings)
except Exception as e:
......@@ -1093,7 +1099,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements):
log.print(log.debug, log.color.blue + "output: " + log.color.reset, rule["output"], type(locals[rule["output"]]))
log.print(log.debug, locals[rule["output"]])
output_value = None
if isinstance(locals[rule["output"]], builtin.Data):
if isinstance(locals[rule["output"]], prule.builtin.Data):
log.print(log.debug, locals[rule["output"]].array)
output_value = locals[rule["output"]]
try:
......@@ -1149,7 +1155,7 @@ def rule_evaluate(rule, rinput, job_meta, check_requirements):
if rule["output_scalar"] in locals:
log.print(log.debug, log.color.blue + "output_scalar: " + log.color.reset, rule["output_scalar"], type(locals[rule["output_scalar"]]))
log.print(log.debug, locals[rule["output_scalar"]])
if isinstance(locals[rule["output_scalar"]], builtin.Data):
if isinstance(locals[rule["output_scalar"]], prule.builtin.Data):
log.print(log.debug, locals[rule["output_scalar"]].array)
output["scalar"] = float(locals[rule["output_scalar"]].array)
else:
......
#!/bin/env python -u
import typing
import os
import os.path
import sys
......@@ -22,10 +23,11 @@ import prule.db
class Job:
def __init__(self, meta_path, data_path):
def __init__(self, meta_path: str, data_path: str):
self.meta_path = meta_path
self.data_path = data_path
def from_jobdir(jobdir_path):
@staticmethod
def from_jobdir(jobdir_path: str) -> 'Job':
j_meta_path = os.path.join(jobdir_path, "meta.json")
j_data_path = os.path.join(jobdir_path, "data.json")
# test for compressed json, otherwise fall back to the normal file
......@@ -37,21 +39,21 @@ class Job:
return j
class IDict(dict):
def __init__(self, c):
def __init__(self, c: dict):
self.update(c)
def __setitem__(self, k,v):
def __setitem__(self, k,v) -> typing.Never:
raise Exception("Immutable dictionary")
def __delitem__(self, k):
def __delitem__(self, k) -> typing.Never:
raise Exception("Immutable dictionary")
def make_immutable(d):
def make_immutable(d: dict) -> IDict:
for k,v in d.items():
if type(v) == dict:
d[v] = make_immutable(v)
id = IDict(d)
return id
def stdin_readjob():
def stdin_readjob() -> typing.Union[None, Job, str]:
try:
line = sys.stdin.readline()
except:
......@@ -166,8 +168,7 @@ if __name__ == "__main__":
out_group.add_argument('--db-path', nargs=1,
help='Path to sqlite output file.')
args = parser.parse_args()
args = vars(args)
args = vars(parser.parse_args())
if "args" in args:
print(args)
......@@ -382,25 +383,30 @@ if __name__ == "__main__":
rule_error = False
output_all_jobs = []
while len(job_queue)>0 or "job_stdin" in args:
job = None
job: typing.Optional[Job] = None
if "job_stdin" in args:
job = stdin_readjob()
if job == "error":
job_read = stdin_readjob()
if job_read == None:
break
elif type(job_read) == str: # "error"
job_error = True
print("{}") # in case of malformed input
continue
if job == None:
break
elif type(job_read) == Job:
job = job_read
else:
job = job_queue.pop()
if job == None:
break
process_time_start = datetime.datetime.now().timestamp()
error = False
job_meta = None
job_data = None
job_id = None
job_output = {"jobId":None, "tags":[], "rules":[], "error":False, "errors":[], "rules_failed":[], "rules_evaluated":[], "rules_not_evaluated":[], "metadata":{}}
job_output: dict = {"jobId":None, "tags":[], "rules":[], "error":False, "errors":[], "rules_failed":[], "rules_evaluated":[], "rules_not_evaluated":[], "metadata":{}}
# read meta file
if os.path.exists(job.meta_path) == False:
......@@ -487,7 +493,7 @@ if __name__ == "__main__":
matched_tags = []
rule_output = []
if error == False:
if error == False and rinput != None and job_meta != None:
for rule in rules_list:
try:
out = rule_evaluate(rule, rinput, job_meta, not args["ignore_rule_requirements"])
......
# builtin functions
# applicable to matrices, scalars, boolean values
import typing
import builtins
import numpy as np
import statistics
......@@ -10,11 +11,7 @@ import pint
import prule
import types
public = [
#"mean",
#"max"
]
public: list = []
# Constants for scopes, necessary?
#SCOPE_THREAD = "thread"
......@@ -28,24 +25,29 @@ public = [
_str_scope = ["all", "thread", "core", "numa", "socket", "node", "accelerator", "job"]
_int_scope = {"thread":1, "core":2, "numa":3, "socket":4, "node":5, "accelerator":6, "job":7}
def _lt_scope(x, y):
x = _int_scope[x]
y = _int_scope[y]
return x < y
def _gt_scope(x, y):
x = _int_scope[x]
y = _int_scope[y]
return x > y
def _lt_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool:
if x == None or y == None:
raise KeyError(None)
x_ix = _int_scope[x]
y_ix = _int_scope[y]
return x_ix < y_ix
def _eq_scope(x, y):
x = _int_scope[x]
y = _int_scope[y]
return x == y
def _gt_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool:
if x == None or y == None:
raise KeyError(None)
x_ix = _int_scope[x]
y_ix = _int_scope[y]
return x_ix > y_ix
def _eq_scope(x: typing.Optional[str], y: typing.Optional[str]) -> bool:
if x == None or y == None:
raise KeyError(None)
x_ix = _int_scope[x]
y_ix = _int_scope[y]
return x_ix == y_ix
class Data:
def __init__(self, array, time, metadata):
def __init__(self, array, time, metadata: list):
self.array = array
self.time = time
self.metadata = metadata
......@@ -57,8 +59,9 @@ class Data:
# Prevent accidental overwriting of object members
def setattr(self, a, b):
raise Exception("Data class objects are not mutable")
self.__setattr__ = types.MethodType(setattr, self)
def __getitem__(self, indices):
#self.__setattr__ = types.MethodType(setattr, self)
builtins.setattr(self, "__setattr__", types.MethodType(setattr, self))
def __getitem__(self, indices) -> 'Data':
# if `indices` is a tuple, then the tuple elements index the two respective dimensions (time, columns)
# else, `indices` indexes the first dimension (time)
if type(indices) != tuple:
......@@ -74,30 +77,28 @@ class Data:
else:
raise Exception("Only 1 or 2-dimensional indices are supported")
return Data(array_new, time_new, metadata_new)
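# Editorial usage note (follows the indexing semantics described above): for a Data
# object d, d[5:10] selects time rows 5..9 across all columns, while d[5:10, 0:2]
# restricts the same rows to the first two columns.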
def __setitem__(self, indices):
def __setitem__(self, key, value) -> typing.Never:
raise Exception("Data class objects are not mutable")
def nbytes(self):
def nbytes(self) -> int:
return self.array.nbytes
def set_writeable(self, v):
def set_writeable(self, v: bool) -> None:
self.array.flags.writeable = v
def __str__(self):
def __str__(self) -> str:
unit = self.array.units if isinstance(self.array, pint.UnitRegistry.Quantity) else ""
return "<{} Class:{} Unit:{} DType:{} Shape:{} {}>".format(type(self), type(self.array), unit, self.array.dtype, self.array.shape, str(self.array))
def __int__(self):
def __int__(self) -> int:
return int(self.array)
def __long__(self):
return long(self.array)
def __float__(self):
def __float__(self) -> float:
return float(self.array)
def __complex__(self):
def __complex__(self) -> complex:
return complex(self.array)
def __oct__(self):
def __oct__(self) -> str:
return oct(self.array)
def __hex__(self):
def __hex__(self) -> str:
return hex(self.array)
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.array)
def scope(self):
def scope(self) -> typing.Optional[str]:
scope = None
column_num = 1 if self.array.ndim == 1 else self.array.shape[1]
for cix in range(0, column_num):
......@@ -106,7 +107,7 @@ class Data:
if scope == None or _lt_scope(s, scope):
scope = s
return scope
def _scope_columns(self, scope):
def _scope_columns(self, scope: str) -> typing.Tuple[list, list]:
pscope = _str_scope[_int_scope[scope] -1]
unique=[]
cols={}
......@@ -135,35 +136,35 @@ class Data:
return (col_list, metadata_new)
#TODO: Add deepcopy() calls to Data() arguments
def __add__(self, other):
def __add__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__add__(other)
return Data(array_new, self.time, self.metadata)
def __sub__(self, other):
def __sub__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__sub__(other)
return Data(array_new, self.time, self.metadata)
def __mul__(self, other):
def __mul__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__mul__(other)
return Data(array_new, self.time, self.metadata)
def __lt__(self, other):
def __lt__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__lt__(other)
return Data(array_new, self.time, self.metadata)
def __gt__(self, other):
def __gt__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__gt__(other)
return Data(array_new, self.time, self.metadata)
def __rsub__(self, other):
def __rsub__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__rsub__(other)
return Data(array_new, self.time, self.metadata)
def __truediv__(self, other):
def __truediv__(self, other) -> 'Data':
other = other.array if isinstance(other, Data) else other
array_new = self.array.__truediv__(other)
return Data(array_new, self.time, self.metadata)
def mean(self, scope='time'):
def mean(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.mean(keepdims=True)
new_time = self.time.mean(axis=0, keepdims=True)
......@@ -204,7 +205,7 @@ class Data:
# Integrated with trapz: 3.5
# Timediff: 6
# Mean: 3.5 / 6 = 0.583
def mean_int(self, scope='time'):
def mean_int(self, scope='time') -> 'Data':
if scope == 'all':
new_data = self.mean_int(scope='time')
new_array = new_data.array.mean(keepdims=True)
......@@ -222,7 +223,7 @@ class Data:
# mean of the timestamps
new_time = self.time.mean(axis=0, keepdims=True)
# integrate using the timestamps
new_array = np.trapz(self.array, x=self.time, axis=0)
new_array = np.trapz(self.array, x=self.time, axis=0) # type: ignore
# compute time difference
time_min = self.time.min(axis=0, keepdims=True)
time_max = self.time.max(axis=0, keepdims=True)
......@@ -233,7 +234,7 @@ class Data:
return Data(new_array, new_time, copy.deepcopy(self.metadata))
else:
return self.mean(scope=scope)
def sum(self, scope='time'):
def sum(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.sum(keepdims=True)
new_time = self.time.sum(axis=0, keepdims=True)
......@@ -257,11 +258,11 @@ class Data:
new_col.append(col_sum)
new_array = np.column_stack(new_col)
return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata))
def any(self, scope='time'):
def any(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.any(keepdims=True)
new_time = np.expand_dims(np.array([]), axis=1)
new_metadata = []
new_metadata: list = []
return Data(new_array, new_time, new_metadata)
elif scope == 'time':
new_time = self.time.any(axis=0, keepdims=True)
......@@ -278,11 +279,11 @@ class Data:
new_col.append(col_any)
new_array = np.column_stack(new_col)
return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata))
def min(self, scope='time'):
def min(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.min(keepdims=True)
new_time = np.expand_dims(np.array([]), axis=1)
new_metadata = []
new_metadata: list = []
return Data(new_array, new_time, new_metadata)
elif scope == 'time':
new_time = self.time.min(axis=0, keepdims=True)
......@@ -299,11 +300,11 @@ class Data:
new_col.append(col_any)
new_array = np.column_stack(new_col)
return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata))
def max(self, scope='time'):
def max(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.max(keepdims=True)
new_time = np.expand_dims(np.array([]), axis=1)
new_metadata = []
new_metadata: list = []
return Data(new_array, new_time, new_metadata)
elif scope == 'time':
new_time = self.time.max(axis=0, keepdims=True)
......@@ -320,11 +321,11 @@ class Data:
new_col.append(col_any)
new_array = np.column_stack(new_col)
return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata))
def std(self, scope='time'):
def std(self, scope='time') -> 'Data':
if scope == 'all':
new_array = self.array.std(keepdims=True)
new_time = np.expand_dims(np.array([]), axis=1)
new_metadata = []
new_metadata: list = []
return Data(new_array, new_time, new_metadata)
elif scope == 'time':
new_time = self.time.std(axis=0, keepdims=True)
......@@ -341,7 +342,7 @@ class Data:
new_col.append(col_any)
new_array = np.column_stack(new_col)
return Data(new_array, copy.deepcopy(self.time), copy.deepcopy(self.metadata))
def slope(self):
def slope(self) -> pint.Quantity:
# use `magnitude` to strip the Pint unit, because `fit` will strip the unit anyway and
# produce unnecessary warnings
s = []
......@@ -362,7 +363,7 @@ class Data:
pf_conv = pf.convert()
pfs.append(pf_conv)
return pfs
def to_array(self):
def to_array(self) -> np.ndarray:
time = self.time
if type(time) == prule.unit_registry.Quantity:
time = time.magnitude
......@@ -372,7 +373,7 @@ class Data:
if len(time) == 0:
return array
return np.concatenate((time, array), axis=1)
def to(self, *args, **kwargs):
def to(self, *args, **kwargs) -> 'Data':
array_new = copy.deepcopy(self.array).to(*args, **kwargs)
md_new = copy.deepcopy(self.metadata)
time_new = copy.deepcopy(self.time)
......
import typing
import sys
import datetime
debug_settings = None
debug_settings: typing.Optional[dict] = None
# debug_log_terms_dir: Path to directory for storing term values during evaluation
class Timing:
......