Skip to content
Snippets Groups Projects
Commit 870ffe5b authored by Emmanuelle Saillard's avatar Emmanuelle Saillard
Browse files

add support for topology

parent e6838f6b
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,9 @@ class CorrectParameterFactory:
def get(self, param: str) -> str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
return self.buf_var_name
if param in ["COUNT", "count", "sendcount", "sendcounts", "recvcount", "recvcounts", "origin_count", "target_count", "result_count"]:
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
return str(self.buf_size)
if param in ["sendcounts", "recvcounts"]:
return str(self.buf_size)
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
"result_datatype"]:
......@@ -99,6 +101,18 @@ class CorrectParameterFactory:
return "resultbuf"
if param in ["compare_addr"]:
return "comparebuf"
if param in ["comm_cart"]:
return "&mpi_comm_0"
if param in ["comm_old"]:
return "MPI_COMM_WORLD"
if param in ["ndims"]:
return "2"
if param in ["dims"]:
return "dims"
if param in ["periods"]:
return "periods"
if param in ["reorder"]:
return "0"
print("Not Implemented: " + param)
assert False, "Param not known"
......
......@@ -183,30 +183,31 @@ def get_collective_template(collective_func):
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL
"""
need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
need_recv_and_send_buf_funcs = []
tm = TemplateManager()
cf = CorrectParameterFactory()
# spilt send and recv buf
#if collective_func in need_buf_funcs:
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, collective_func)
c = call_creator_function()
if c.has_arg("buffer") or c.has_arg("sendbuf"):
alloc = cf.get_buffer_alloc()
alloc.set_identifier("ALLOC")
alloc.set_name("buf")
tm.register_instruction(alloc)
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, collective_func)
c = call_creator_function()
if c.has_arg("comm_cart"):
tm.add_stack_variable("MPI_Comm")
# TODO: create proper instructions
tm.register_instruction(Instruction("int periods[2]={1,1};"), identifier="ALLOC")
tm.register_instruction(Instruction("int dims[2]={0,0};"), identifier="ALLOC") # use MPI_Dims_create(nprocs,2,dims);
# add request for nonblocking collectives
if collective_func.startswith("mpi_i"):
tm.add_stack_variable("MPI_Request")
# Set parameters for some collectives: sendcount, recvcounts
#if collective_func in ["mpi_alltoallv"]:
# TODO
coll = CorrectMPICallFactory.get(collective_func)
coll.set_identifier("MPICALL")
......@@ -216,7 +217,7 @@ def get_collective_template(collective_func):
if collective_func.startswith("mpi_i"):
tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT")
#if collective_func in need_buf_funcs:
if c.has_arg("buffer") or c.has_arg("sendbuf"):
tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
return tm
......@@ -228,14 +229,11 @@ def get_two_collective_template(collective_func1, collective_func2):
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL
"""
need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
need_recv_and_send_buf_funcs = []
tm = TemplateManager()
cf = CorrectParameterFactory()
# spilt send and recv buf
#if collective_func in need_buf_funcs:
# todo: spilt send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_identifier("ALLOC")
alloc.set_name("buf")
......@@ -249,9 +247,6 @@ def get_two_collective_template(collective_func1, collective_func2):
if collective_func1.startswith("mpi_i") or collective_func2.startswith("mpi_i"):
tm.add_stack_variable("MPI_Request")
# Set parameters for some collectives: sendcount, recvcounts
#if collective_func in ["mpi_alltoallv"]:
# TODO
coll1 = CorrectMPICallFactory.get(collective_func1)
coll1.set_identifier("MPICALL")
......
......@@ -11,6 +11,7 @@ class CorrectColl(ErrorGenerator):
functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
topology_functions = ["mpi_cart_create"]
def __init__(self):
pass
......@@ -29,3 +30,12 @@ class CorrectColl(ErrorGenerator):
if not generate_full_set:
return
for func_to_use in self.topology_functions:
tm = get_collective_template(func_to_use)
tm.set_description("Correct-"+func_to_use, "Correct code")
yield tm
if not generate_full_set:
return
......@@ -8,7 +8,7 @@ from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_collective_template
class InvalidRankErrorColl(ErrorGenerator):
nbfunc_to_use = ["mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
nbfunc_to_use = ["mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
def __init__(self):
pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment