diff --git a/scripts/errors/pt2pt/InvalidComm.py b/scripts/errors/pt2pt/InvalidComm.py index 1272833638438b0d955f8223238560c5a995f749..95e114f77e956588c185fcdf053e06b4a96ee3f9 100644 --- a/scripts/errors/pt2pt/InvalidComm.py +++ b/scripts/errors/pt2pt/InvalidComm.py @@ -59,7 +59,7 @@ class InvalidCommErrorP2P(ErrorGenerator): recv_func = "mpi_irecv" # not implemented continue - #TODO add probe call + # TODO add probe call for comm_to_use in self.comms_to_check: if comm_to_use in self.missmatching_comms + self.intercomms and recv_func == "mpi_irecv" and generate_level < FULL_TEST_LEVEL: @@ -74,14 +74,16 @@ class InvalidCommErrorP2P(ErrorGenerator): if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF": comm_var_name = get_communicator(comm_to_use, tm) - # use precprcessor to make this two identifieres same - # alternatively, chang the arg in the MPI call to teh result variable name - tm.insert_instruction(Instruction("#define " + comm_to_use + " " + comm_var_name), - before_instruction=0) + # change the arg in the MPI call to the result variable name + for call in tm.get_instruction(identifier="MPICALL", return_list=True): + if call.get_arg("comm") == comm_to_use: + call.set_arg("comm", comm_var_name) + if comm_to_use in self.intercomms: comm_var_name = get_intercomm(comm_to_use, tm) - tm.insert_instruction(Instruction("#define " + comm_to_use + " " + comm_var_name), - before_instruction=0) + for call in tm.get_instruction(identifier="MPICALL", return_list=True): + if call.get_arg("comm") == comm_to_use: + call.set_arg("comm", comm_var_name) # if intercomm: set rank to 0 instead of 1 as ther is only one rank in intercomm if comm_to_use in self.intercomms and not comm_to_use == "mpi_intercomm_merge":