diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py index 8a6feabab3962227ead97412f988478f7575d34a..9601ea31497e1d7528a622fde5fa0d60e3755857 100644 --- a/scripts/errors/coll/ParamMatchingType.py +++ b/scripts/errors/coll/ParamMatchingType.py @@ -24,9 +24,13 @@ def get_local_missmatch(type_1, type_2, func_to_use): type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1) for call in tm.get_instruction("MPICALL", return_list=True): - call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) + call.set_has_error() + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_2) call.set_arg("sendbuf", buf_name_2) @@ -50,27 +54,43 @@ def get_global_missmatch(type_1, type_2, count_1, count_2, func_to_use, comm): for call in tm.get_instruction("MPICALL", return_list=True): call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) - call.set_arg("count", count_1) - call.set_arg("comm", comm_var_name) call.set_has_error() + call.set_arg("comm", comm_var_name) if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) call.set_arg("sendbuf", buf_name_1) else: call.set_arg("buffer", buf_name_1) + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) + if call.has_arg("recvcount"): + call.set_arg("recvcount", count_1) + call.set_arg("sendcount", count_1) + else: + call.set_arg("count", count_1) c = CorrectMPICallFactory.get(func_to_use) c.set_rank_executing('not0') - c.set_arg("datatype", type_var_2) - c.set_arg("count", count_2) + call.set_has_error() c.set_arg("comm", comm_var_name) - c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) c.set_arg("sendbuf", buf_name_2) else: c.set_arg("buffer", buf_name_2) + if c.has_arg("recvtype"): + c.set_arg("recvtype", type_var_2) + c.set_arg("sendtype", type_var_2) + else: + c.set_arg("datatype", type_var_2) + if c.has_arg("recvcount"): + c.set_arg("recvcount", count_2) + c.set_arg("sendcount", count_2) + else: + c.set_arg("count", count_2) tm.insert_instruction(c, after_instruction=call) @@ -91,27 +111,41 @@ def get_correct_case(type_1, count_1, func_to_use, comm): for call in tm.get_instruction("MPICALL", return_list=True): call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) - call.set_arg("count", count_1) call.set_arg("comm", comm_var_name) - if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) call.set_arg("sendbuf", buf_name_1) else: call.set_arg("buffer", buf_name_1) + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) + if call.has_arg("recvcount"): + call.set_arg("recvcount", count_1) + call.set_arg("sendcount", count_1) + else: + call.set_arg("count", count_1) c = CorrectMPICallFactory.get(func_to_use) c.set_rank_executing('not0') - c.set_arg("datatype", type_var_2) - c.set_arg("count", count_1) c.set_arg("comm", comm_var_name) - c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) c.set_arg("sendbuf", buf_name_2) else: c.set_arg("buffer", buf_name_2) + if c.has_arg("recvtype"): + c.set_arg("recvtype", type_var_2) + c.set_arg("sendtype", type_var_2) + else: + c.set_arg("datatype", type_var_2) + if c.has_arg("recvcount"): + c.set_arg("recvcount", count_1) + c.set_arg("sendcount", count_1) + else: + c.set_arg("count", count_1) tm.insert_instruction(c, after_instruction=call) @@ -187,4 +221,4 @@ class InvalidComErrorColl(ErrorGenerator): if type_1 not in correct_types_checked: correct_types_checked.add(type_1) - yield get_correct_case(type_1, 1,func_to_use, comm) + yield get_correct_case(type_1, 1, func_to_use, comm)