Skip to content
Snippets Groups Projects
Commit 82e0f8be authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Merge branch 'devel-TJ' into 'main'

more work on infrastructure II

See merge request !3
parents c01e9805 bf4a0fe6
Branches
No related tags found
1 merge request!3more work on infrastructure II
......@@ -49,7 +49,7 @@ class AllocCall:
.replace("@{NAME}@", self._name)
.replace("@{TYPE}@", self._type)
.replace("@{FUNCTION}@", func)
.replace("@{NUM}@", self._num_elements)
.replace("@{NUM}@", str(self._num_elements))
.replace("@{SEP}", delim))
def set_num_elements(self, num_elements):
......
#! /usr/bin/python3
from collections import OrderedDict
from scripts.Infrastructure.MPICall import MPI_Call
from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.Template import InstructionBlock
from scripts.Infrastructure.AllocCall import AllocCall
class MPI_Call_Factory:
def mpi_send(self, *args):
return MPI_Call("MPI_Send",
OrderedDict([("BUFFER", args[0]), ("COUNT", args[1]), ("DATATYPE", args[2]), ("SRC", args[3]),
("TAG", args[4]), ("COMM", args[5])]),
"1.0")
def mpi_recv(self, *args):
return MPI_Call("MPI_Recv",
OrderedDict([("BUFFER", args[0]), ("COUNT", args[1]), ("DATATYPE", args[2]), ("SRC", args[3]),
("TAG", args[4]), ("COMM", args[5]), ("STATUS", args[6])]),
"1.0")
class Correct_Parameter:
class CorrectParameterFactory:
# default params
buf_size = 10
dtype = ['int', 'MPI_INT']
tag = 0
buf_var_name = "buf"
def __init__(self):
pass
def get_buffer_alloc(self):
b = InstructionBlock()
b.register_operation(("int* buf = (int*) malloc(%d* sizeof(%s));" % (self.buf_size, self.dtype[0])), kind='all')
b = InstructionBlock("alloc")
b.register_operation(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False), kind='all')
return b
def get(self, param, func=None):
if param == "BUFFER":
return "buf"
if param == "COUNT":
if param == "BUFFER" or param == "buf" or param == "buffer" or param == "sendbuf" or param == "recvbuf":
return self.buf_var_name
if param == "COUNT" or param == "count":
return str(self.buf_size)
if param == "DATATYPE":
if param == "DATATYPE" or param == "datatype":
return self.dtype[1]
if param == "SRC":
if param == "DEST" or param == "dest":
return "0"
if param == "SRC" or param == "source":
return "1"
if param == "RANK" or param == "root":
return "0"
if param == "TAG":
if param == "TAG" or param == "tag":
return str(self.tag)
if param == "COMM":
if param == "COMM" or param == "comm":
return "MPI_COMM_WORLD"
if param == "STATUS":
if param == "STATUS" or param == "status":
return "MPI_STATUS_IGNORE"
if param == "OPERATION" or param == "op":
return "MPI_SUM"
print("Not Implemented: " + param)
assert False, "Param not known"
# todo also for send and non default args
def get_matching_recv(call):
correct_params = Correct_Parameter()
recv = MPI_Call_Factory().mpi_recv(
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
correct_params.get("COUNT"),
correct_params.get("DATATYPE"),
......
......@@ -16,10 +16,22 @@ from collections import OrderedDict
from scripts.Infrastructure.MPICall import MPI_Call
class MPI_Call_Factory:
class MPICallFactory:
"""
correct_call_factory_header="""
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
class CorrectMPICallFactory:
"""
template_correct = """
def @{FUNC_KEY}@(self):
correct_params = CorrectParameterFactory()
return MPICallFactory().@{FUNC_KEY}@(@{PARAMS}@)
"""
def main():
# read in the "official" standards json to get all mpi functions and there params
......@@ -31,19 +43,25 @@ def main():
class_str = file_header
correct_class_str = correct_call_factory_header
version_dict = get_mpi_version_dict()
for key, api_spec in api_specs.items():
spec = api_specs[key]
name = spec['name']
dict_str = "["
correct_param_str = ""
i = 0
for param in spec['parameters']:
if 'c_parameter' not in param['suppress']:
dict_str = dict_str + "(\"" + param['name'] + "\", args[" + str(i) + "]),"
correct_param_str = correct_param_str + "correct_params.get(\""+param['name']+"\"),"
i = i + 1
pass
dict_str = dict_str + "]"
correct_param_str=correct_param_str[:-1]# remove last ,
ver = "4.0"
# everyting not in dict is 4.0
......@@ -57,17 +75,13 @@ def main():
.replace("@{VERSION}@", ver))
class_str = class_str+ function_def_str
with open(output_file,"w") as outfile:
outfile.write(class_str)
# def mpi_send(self, *args):
# return MPI_Call("MPI_Send",
# OrderedDict([("BUFFER", args[0]), ("COUNT", args[1]), ("DATATYPE", args[2]), ("SRC", args[3]),
# ("TAG", args[4]), ("COMM", args[5])]),
# "1.0")
correct_function_def_str =(template_correct
.replace("@{FUNC_KEY}@", key)
.replace("@{PARAMS}@", correct_param_str))
correct_class_str=correct_class_str+ correct_function_def_str
with open(output_file,"w") as outfile:
outfile.write(class_str+correct_class_str)
if __name__ == "__main__":
main()
This diff is collapsed.
This diff is collapsed.
......@@ -175,7 +175,7 @@ class TemplateManager:
to_return = [b for b in self._blocks if b.name == block_name]
if len(to_return) == 0:
raise IndexError("Block Not Found")
if len(to_return) > 0:
if len(to_return) > 1:
raise IndexError("Multiple Blocks Found")
return to_return[0]
......
#! /usr/bin/python3
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICall import MPI_Call
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory
from scripts.Infrastructure.Template import TemplateManager
def get_default_template(mpi_func):
"""
Contructs a default template for the given mpi function
Returns:
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
"""
pass
def get_send_recv_template(send_func, recv_func):
"""
Contructs a default template for the given mpi send recv function pair
Returns:
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
"""
assert send_func == "mpi_send" or send_func == "mpi_ssend"
assert recv_func == "mpi_recv"
tm = TemplateManager()
cf = CorrectParameterFactory()
tm.register_instruction_block(cf.get_buffer_alloc())
cmpicf = CorrectMPICallFactory()
send_func_creator_function = getattr(cmpicf, send_func)
s = send_func_creator_function()
recv_func_creator_function = getattr(cmpicf, recv_func)
r = recv_func_creator_function()
b = InstructionBlock("MPICALL")
b.register_operation(s, 1)
b.register_operation(r, 0)
tm.register_instruction_block(b)
return tm
def get_collective_template(collective_func,seperate=True):
"""
Contructs a default template for the given mpi collecive
Returns:
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL
with seperate calls for rank 1 and 2 if seperate ==True
"""
tm = TemplateManager()
cf = CorrectParameterFactory()
tm.register_instruction_block(cf.get_buffer_alloc())
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, collective_func)
c = call_creator_function()
b = InstructionBlock("MPICALL")
if seperate:
b.register_operation(c, 1)
b.register_operation(c, 0)
else:
b.register_operation(c,'all')
tm.register_instruction_block(b)
return tm
......@@ -2,13 +2,14 @@
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICallFactory import MPI_Call_Factory
from scripts.Infrastructure.CorrectParameter import Correct_Parameter,get_matching_recv
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template
class Invalid_negative_rank_error(ErrorGenerator):
invalid_ranks = ["-1", "size", "NULL", "MPI_PROC_NULL"]
class InvalidRankErrorP2P(ErrorGenerator):
invalid_ranks = ["-1", "size", "MPI_PROC_NULL"]
def __init__(self):
pass
......@@ -24,29 +25,43 @@ class Invalid_negative_rank_error(ErrorGenerator):
return ["P2P"]
def generate(self, i):
tm = TemplateManager()
correct_params = Correct_Parameter()
tm.set_description("InvalidParam-Rank-MPI_Send", "Invalid Rank: %s" % self.invalid_ranks[i])
# include the buffer allocation in the template (all ranks execute it)
tm.register_instruction_block(correct_params.get_buffer_alloc())
send = MPI_Call_Factory().mpi_send(
correct_params.get("BUFFER"),
correct_params.get("COUNT"),
correct_params.get("DATATYPE"),
self.invalid_ranks[i], # invalid rank
correct_params.get("TAG"),
correct_params.get("COMM"),
)
send.set_has_error()
b = InstructionBlock()
# only rank 0 execute the send
b.register_operation(send, 0)
# only rank 1 execute the recv
b.register_operation(get_matching_recv(send), 1)
tm.register_instruction_block(b)
rank_to_use = self.invalid_ranks[i]
tm = get_send_recv_template("mpi_send", "mpi_recv")
tm.set_description("InvalidParam-Rank-MPI_Send", "Invalid Rank: %s" % rank_to_use)
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_arg("source", rank_to_use)
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error()
return tm
class InvalidRankErrorColl(ErrorGenerator):
invalid_ranks = ["-1", "size", "MPI_PROC_NULL"]
functions_to_use = ["mpi_reduce", "mpi_bcast"]
def __init__(self):
pass
def get_num_errors(self):
return len(self.invalid_ranks) * len(self.functions_to_use)
# the number of errors to produce in the extended mode (all possible combinations)
def get_num_errors_extended(self):
return len(self.invalid_ranks) * len(self.functions_to_use)
def get_feature(self):
return ["COLL"]
def generate(self, i):
rank_to_use = self.invalid_ranks[i // len(self.functions_to_use)]
func_to_use = self.functions_to_use[i % len(self.functions_to_use)]
tm = get_collective_template(func_to_use, seperate=False)
arg_to_replace = "root"
tm.set_description("InvalidParam-Rank-"+func_to_use, "Invalid Rank: %s" % rank_to_use)
tm.get_block("MPICALL").get_operation(kind='all', index=0).set_arg(arg_to_replace, rank_to_use)
tm.get_block("MPICALL").get_operation(kind='all', index=0).set_has_error()
return tm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment