Skip to content
Snippets Groups Projects
Select Git revision
  • 8775aca047e94e64d744c87dad0bffc3a90bfe6d
  • main default protected
  • fortran
  • parcoach
  • fix-rma-lockunlock
  • paper_repro
  • usertypes
  • must-toolcoverage
  • toolcoverage
  • tools
  • must-json
  • merged
  • tools-parallel
  • coll
  • rma
  • dtypes
  • p2p
  • infrastructure-patch-3
  • infrastructure-patch2
  • devel-TJ
  • infrasructure-patch-1
21 results

InvalidComm.py

Blame
  • InvalidComm.py 6.88 KiB
    #! /usr/bin/python3
    from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
    from scripts.Infrastructure.InstructionBlock import InstructionBlock
    from scripts.Infrastructure.MPICall import MPICall
    from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
    from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
    from scripts.Infrastructure.Template import TemplateManager
    from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_communicator, get_intercomm
    
    from itertools import chain
    
    # Combined send+receive calls: when one of these is under test it is used as
    # both the send call and the receive call of the generated template.
    sendrecv_funcs = [
        "mpi_sendrecv",
        "mpi_sendrecv_replace",
    ]
    
    
    class InvalidCommErrorP2P(ErrorGenerator):
        """Generate P2P test cases in which one call is given a bad communicator.

        For every function in ``functions_to_check`` and every communicator in
        ``comms_to_check`` a send/recv template is produced in which the call
        under test is given that communicator:

        * ``invalid_comm`` entries are not valid communicators and are reported
          as ``InvalidParam``.
        * ``missmatching_comms`` and ``intercomms`` entries are communicators the
          peer call does not use; they are reported as ``ParamMatching`` and both
          the call under test and its peer are flagged as erroneous.
        """
        invalid_comm = ["MPI_COMM_NULL", "NULL"]

        missmatching_comms = ["MPI_COMM_SELF", "mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
                              "mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
                              "mpi_comm_split_type", "mpi_comm_create_from_group"
                              ]
        intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]

        # as extended testcases

        comms_to_check = invalid_comm + missmatching_comms + intercomms

        mprobe_funcs = ["mpi_mprobe", "mpi_improbe"]
        probe_funcs = ["mpi_probe", "mpi_iprobe"]

        functions_to_check = ["mpi_send",
                              "mpi_recv", "mpi_irecv",
                              "mpi_isend", "mpi_ssend", "mpi_issend", "mpi_rsend", "mpi_irsend", "mpi_bsend", "mpi_ibsend",
                              "mpi_send_init", "mpi_ssend_init", "mpi_bsend_init", "mpi_rsend_init", "mpi_psend_init",
                              "mpi_precv_init", "mpi_recv_init"
                              ] + sendrecv_funcs + mprobe_funcs + probe_funcs

        # functions tested on the receiving side of the template
        recv_funcs = ["mpi_recv", "mpi_irecv", "mpi_recv_init",
                      "mpi_precv_init"] + sendrecv_funcs + mprobe_funcs + probe_funcs

        def __init__(self):
            # NOTE(review): this override does not call ErrorGenerator.__init__;
            # confirm that is intended.
            pass

        def get_feature(self):
            """Return the feature category of the generated test cases."""
            return ["P2P"]

        def generate(self, generate_full_set):
            """Yield one template per (function, communicator) combination.

            :param generate_full_set: if false, stop after every communicator has
                been tried with the first entry of ``functions_to_check``.
            """
            # Communicators that differ from the one the peer call uses, so the
            # peer call is flagged as erroneous as well. Built once per call from
            # the instance attributes (subclass overrides are honoured).
            peer_mismatch_comms = set(self.missmatching_comms) | set(self.intercomms)

            for func in self.functions_to_check:
                check_receive = func in self.recv_funcs
                for comm_to_use in self.comms_to_check:
                    if check_receive and func == "mpi_irecv" and comm_to_use in peer_mismatch_comms:
                        # combination repeated: the mpi_send/mpi_irecv pair on a
                        # mismatching comm is already covered by the mpi_send case,
                        # which flags both calls. Skip before building anything.
                        continue
                    yield self._build_case(func, comm_to_use, check_receive, peer_mismatch_comms)
                if not generate_full_set:
                    return

        def _select_calls(self, func, check_receive):
            """Choose the send and receive calls of the template that tests ``func``.

            :returns: ``(send_func, recv_func, recv_func_to_use)``. ``recv_func``
                is the receive-side function under consideration;
                ``recv_func_to_use`` is what is handed to the template (probe
                variants are replaced by the calls that actually receive).
            """
            if check_receive:
                recv_func = func
                # sendrecv variants serve as both the sending and the receiving call
                send_func = func if func in sendrecv_funcs else "mpi_send"
            else:
                # a send function under test is paired with a plain mpi_irecv
                recv_func = "mpi_irecv"
                send_func = func

            if recv_func in self.probe_funcs:
                # the probe itself is inserted separately; mpi_recv does the receive
                recv_func_to_use = "mpi_recv"
            elif recv_func in self.mprobe_funcs:
                recv_func_to_use = [recv_func, "mpi_mrecv"]
            else:
                recv_func_to_use = recv_func
            return send_func, recv_func, recv_func_to_use

        @staticmethod
        def _insert_probe_calls(block, func):
            """Insert the probe call under test at the front of the receiving side (kind 0)."""
            if func == "mpi_probe":
                block.insert_instruction(CorrectMPICallFactory.get(func), kind=0, before_index=0)
            if func == "mpi_iprobe":
                # loop on mpi_iprobe until flag is non-zero (presumably set by the call)
                block.insert_instruction("int flag=0;", kind=0, before_index=0)
                block.insert_instruction("while (!flag){", kind=0, before_index=1)
                block.insert_instruction(CorrectMPICallFactory.get(func), kind=0, before_index=2)
                block.insert_instruction("}", kind=0, before_index=3)  # end while

        def _build_case(self, func, comm_to_use, check_receive, peer_mismatch_comms):
            """Build the template in which the ``func`` call is given ``comm_to_use``."""
            send_func, recv_func, recv_func_to_use = self._select_calls(func, check_receive)
            tm = get_send_recv_template(send_func, recv_func_to_use)

            # create the communicator under test right after allocation
            # (MPI_COMM_SELF needs no creation)
            if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF":
                tm.insert_block(get_communicator(comm_to_use, comm_to_use), after_block_name="alloc")
            if comm_to_use in self.intercomms:
                tm.insert_block(get_intercomm(comm_to_use, comm_to_use), after_block_name="alloc")

            error_string = "InvalidParam" if comm_to_use in self.invalid_comm else "ParamMatching"
            # func is the call under test for both send- and receive-side checks
            tm.set_description(error_string + "-Comm-" + func, error_string + ": %s" % comm_to_use)

            block = tm.get_block("MPICALL")
            self._insert_probe_calls(block, func)

            # in this generator receive-side calls are addressed with kind 0 and
            # send-side calls with kind 1
            kind = 0 if check_receive else 1
            # position of the call under test: mpi_iprobe sits behind the flag
            # declaration and loop header inserted above; mpi_improbe presumably
            # has one template instruction in front of it (TODO confirm)
            idx = {"mpi_improbe": 1, "mpi_iprobe": 2}.get(recv_func, 0)
            call_under_test = block.get_instruction(kind=kind, index=idx)
            call_under_test.set_arg("comm", comm_to_use)
            call_under_test.set_has_error()
            if comm_to_use in peer_mismatch_comms:
                # the mismatch is between both calls, so the peer is erroneous too
                block.get_instruction(kind=1 - kind, index=0).set_has_error()

            # an intercomm has only one rank (the other group), so all source
            # ranks must be 0; mpi_intercomm_merge yields an "equivalent" comm
            # again and is left unchanged
            if comm_to_use in self.intercomms and comm_to_use != "mpi_intercomm_merge":
                for side in (0, 1):
                    for inst in block.get_instruction(kind=side, index='all'):
                        if isinstance(inst, MPICall) and inst.has_arg("source"):
                            inst.set_arg("source", "0")

            # free the communicator created above
            if comm_to_use in peer_mismatch_comms and comm_to_use != "MPI_COMM_SELF":
                free_block = InstructionBlock("comm_free")
                free_block.register_instruction(MPICallFactory().mpi_comm_free("&" + comm_to_use))
                tm.register_instruction_block(free_block)

            return tm