diff --git a/scripts/Infrastructure/CorrectParameter.py b/scripts/Infrastructure/CorrectParameter.py
index 6f0c2c2967eaa9555d09b73e6009ae8987337fb7..0a2fc99c27730054ddfb8d70c5c424b6fa7c3886 100644
--- a/scripts/Infrastructure/CorrectParameter.py
+++ b/scripts/Infrastructure/CorrectParameter.py
@@ -26,7 +26,9 @@ class CorrectParameterFactory:
     def get(self, param: str) -> str:
         if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
             return self.buf_var_name
-        if param in ["COUNT", "count", "sendcount", "sendcounts", "recvcount", "recvcounts", "origin_count", "target_count", "result_count"]:
+        if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
+            return str(self.buf_size)
+        if param in ["sendcounts", "recvcounts"]:
             return str(self.buf_size)
         if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype", "result_datatype"]:
@@ -99,6 +101,18 @@ class CorrectParameterFactory:
             return "resultbuf"
         if param in ["compare_addr"]:
             return "comparebuf"
+        if param in ["comm_cart"]:
+            return "&mpi_comm_0"
+        if param in ["comm_old"]:
+            return "MPI_COMM_WORLD"
+        if param in ["ndims"]:
+            return "2"
+        if param in ["dims"]:
+            return "dims"
+        if param in ["periods"]:
+            return "periods"
+        if param in ["reorder"]:
+            return "0"
         print("Not Implemented: " + param)
         assert False, "Param not known"
diff --git a/scripts/Infrastructure/TemplateFactory.py b/scripts/Infrastructure/TemplateFactory.py
index 2ef9f31b4836b3fe156790b1babb4827aaee41f2..a90a6567040b2b3214ed94f3fa545a9ec99106e7 100644
--- a/scripts/Infrastructure/TemplateFactory.py
+++ b/scripts/Infrastructure/TemplateFactory.py
@@ -183,30 +183,31 @@ def get_collective_template(collective_func):
     TemplateManager Initialized with a default template
     The function is contained in a block named MPICALL
     """
-    need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
-    need_recv_and_send_buf_funcs = []
     tm = TemplateManager()
     cf = CorrectParameterFactory()
 
-    # spilt send and recv buf
-    #if collective_func in need_buf_funcs:
-    alloc = cf.get_buffer_alloc()
-    alloc.set_identifier("ALLOC")
-    alloc.set_name("buf")
-    tm.register_instruction(alloc)
-
     cmpicf = CorrectMPICallFactory()
     call_creator_function = getattr(cmpicf, collective_func)
     c = call_creator_function()
 
+    if c.has_arg("buffer") or c.has_arg("sendbuf"):
+        alloc = cf.get_buffer_alloc()
+        alloc.set_identifier("ALLOC")
+        alloc.set_name("buf")
+        tm.register_instruction(alloc)
+
+    if c.has_arg("comm_cart"):
+        tm.add_stack_variable("MPI_Comm")
+        # TODO: create proper instructions
+        tm.register_instruction(Instruction("int periods[2]={1,1};"), identifier="ALLOC")
+        tm.register_instruction(Instruction("int dims[2]={0,0};"), identifier="ALLOC")  # use MPI_Dims_create(nprocs,2,dims);
+
+    # add request for nonblocking collectives
     if collective_func.startswith("mpi_i"):
         tm.add_stack_variable("MPI_Request")
 
-    # Set parameters for some collectives: sendcount, recvcounts
-    #if collective_func in ["mpi_alltoallv"]:
-    # TODO
     coll = CorrectMPICallFactory.get(collective_func)
     coll.set_identifier("MPICALL")
@@ -216,8 +217,8 @@ def get_collective_template(collective_func):
     if collective_func.startswith("mpi_i"):
         tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT")
 
-    #if collective_func in need_buf_funcs:
-    tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
+    if c.has_arg("buffer") or c.has_arg("sendbuf"):
+        tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
 
     return tm
@@ -228,14 +229,11 @@ def get_two_collective_template(collective_func1, collective_func2):
     TemplateManager Initialized with a default template
     The function is contained in a block named MPICALL
     """
-    need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
-    need_recv_and_send_buf_funcs = []
     tm = TemplateManager()
     cf = CorrectParameterFactory()
 
-    # spilt send and recv buf
-    #if collective_func in need_buf_funcs:
+    # todo: split send and recv buf
     alloc = cf.get_buffer_alloc()
     alloc.set_identifier("ALLOC")
     alloc.set_name("buf")
@@ -249,9 +247,6 @@ def get_two_collective_template(collective_func1, collective_func2):
     if collective_func1.startswith("mpi_i") or collective_func2.startswith("mpi_i"):
         tm.add_stack_variable("MPI_Request")
 
-    # Set parameters for some collectives: sendcount, recvcounts
-    #if collective_func in ["mpi_alltoallv"]:
-    # TODO
     coll1 = CorrectMPICallFactory.get(collective_func1)
     coll1.set_identifier("MPICALL")
diff --git a/scripts/errors/coll/Correct.py b/scripts/errors/coll/Correct.py
index 44a6c165c8cde1f8a21a46ba4dffd41fdfebadbd..97ed5bbc5bb393bf2a9424683dc1cdbed7f5784b 100644
--- a/scripts/errors/coll/Correct.py
+++ b/scripts/errors/coll/Correct.py
@@ -11,6 +11,7 @@ class CorrectColl(ErrorGenerator):
     functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
     functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
+    topology_functions = ["mpi_cart_create"]
 
     def __init__(self):
         pass
@@ -29,3 +30,12 @@ class CorrectColl(ErrorGenerator):
 
         if not generate_full_set:
             return
+
+        for func_to_use in self.topology_functions:
+            tm = get_collective_template(func_to_use)
+
+            tm.set_description("Correct-"+func_to_use, "Correct code")
+            yield tm
+
+            if not generate_full_set:
+                return
diff --git a/scripts/errors/coll/LocalConcurrency.py b/scripts/errors/coll/LocalConcurrency.py
index c3209915aab698c3f4611c287d42b414354c5161..811ce739c13b1f5463933416072d7b46af6df5b5 100644
--- a/scripts/errors/coll/LocalConcurrency.py
+++ b/scripts/errors/coll/LocalConcurrency.py
@@ -8,7 +8,7 @@ from scripts.Infrastructure.Template import TemplateManager
 from scripts.Infrastructure.TemplateFactory import get_collective_template
 
 class InvalidRankErrorColl(ErrorGenerator):
-    nbfunc_to_use = ["mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
+    nbfunc_to_use = ["mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
 
     def __init__(self):
         pass
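
Note for reviewers: below is a minimal C sketch of the kind of Correct-mpi_cart_create case these template changes appear to be building toward. The names mpi_comm_0, dims, and periods and the values ndims=2, reorder=0 come from the diff above; the MPI_Dims_create call is only suggested by the TODO comment in get_collective_template, and the surrounding boilerplate and the MPI_Comm_free cleanup are assumptions, not generated output from this repository.

```c
#include <mpi.h>

int main(int argc, char **argv) {
    int nprocs;
    MPI_Comm mpi_comm_0; /* stack variable from tm.add_stack_variable("MPI_Comm") */

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &nprocs);

    /* ALLOC block: the Instruction strings registered by the template */
    int periods[2] = {1, 1};
    int dims[2] = {0, 0};
    /* the TODO comment suggests letting MPI choose the grid shape */
    MPI_Dims_create(nprocs, 2, dims);

    /* MPICALL block: arguments as returned by CorrectParameterFactory.get():
       comm_old=MPI_COMM_WORLD, ndims=2, dims, periods, reorder=0, &mpi_comm_0 */
    MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &mpi_comm_0);

    MPI_Comm_free(&mpi_comm_0); /* assumed cleanup; not yet part of the template */
    MPI_Finalize();
    return 0;
}
```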