From cb161b24dba64a26b31c97c3bb2fa7103a152a0b Mon Sep 17 00:00:00 2001 From: Tim Jammer <tim.jammer@tu-darmstadt.de> Date: Mon, 29 Apr 2024 14:28:00 +0200 Subject: [PATCH] fix usage of multiple type arg functions --- scripts/errors/coll/ParamMatchingType.py | 64 ++++++++++++++++++------ 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py index 8a6feabab..9601ea314 100644 --- a/scripts/errors/coll/ParamMatchingType.py +++ b/scripts/errors/coll/ParamMatchingType.py @@ -24,9 +24,13 @@ def get_local_missmatch(type_1, type_2, func_to_use): type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1) for call in tm.get_instruction("MPICALL", return_list=True): - call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) + call.set_has_error() + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_2) call.set_arg("sendbuf", buf_name_2) @@ -50,27 +54,43 @@ def get_global_missmatch(type_1, type_2, count_1, count_2, func_to_use, comm): for call in tm.get_instruction("MPICALL", return_list=True): call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) - call.set_arg("count", count_1) - call.set_arg("comm", comm_var_name) call.set_has_error() + call.set_arg("comm", comm_var_name) if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) call.set_arg("sendbuf", buf_name_1) else: call.set_arg("buffer", buf_name_1) + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) + if call.has_arg("recvcount"): + call.set_arg("recvcount", count_1) + call.set_arg("sendcount", count_1) + else: + call.set_arg("count", count_1) c = CorrectMPICallFactory.get(func_to_use) c.set_rank_executing('not0') - c.set_arg("datatype", type_var_2) - c.set_arg("count", count_2) + call.set_has_error() c.set_arg("comm", comm_var_name) - c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) c.set_arg("sendbuf", buf_name_2) else: c.set_arg("buffer", buf_name_2) + if c.has_arg("recvtype"): + c.set_arg("recvtype", type_var_2) + c.set_arg("sendtype", type_var_2) + else: + c.set_arg("datatype", type_var_2) + if c.has_arg("recvcount"): + c.set_arg("recvcount", count_2) + c.set_arg("sendcount", count_2) + else: + c.set_arg("count", count_2) tm.insert_instruction(c, after_instruction=call) @@ -91,27 +111,41 @@ def get_correct_case(type_1, count_1, func_to_use, comm): for call in tm.get_instruction("MPICALL", return_list=True): call.set_rank_executing(0) - call.set_arg("datatype", type_var_1) - call.set_arg("count", count_1) call.set_arg("comm", comm_var_name) - if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) call.set_arg("sendbuf", buf_name_1) else: call.set_arg("buffer", buf_name_1) + if call.has_arg("recvtype"): + call.set_arg("recvtype", type_var_1) + call.set_arg("sendtype", type_var_1) + else: + call.set_arg("datatype", type_var_1) + if call.has_arg("recvcount"): + call.set_arg("recvcount", count_1) + call.set_arg("sendcount", count_1) + else: + call.set_arg("count", count_1) c = CorrectMPICallFactory.get(func_to_use) c.set_rank_executing('not0') - c.set_arg("datatype", type_var_2) - c.set_arg("count", count_1) c.set_arg("comm", comm_var_name) - c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) c.set_arg("sendbuf", buf_name_2) else: c.set_arg("buffer", buf_name_2) + if c.has_arg("recvtype"): + c.set_arg("recvtype", type_var_2) + c.set_arg("sendtype", type_var_2) + else: + c.set_arg("datatype", type_var_2) + if c.has_arg("recvcount"): + c.set_arg("recvcount", count_1) + c.set_arg("sendcount", count_1) + else: + c.set_arg("count", count_1) tm.insert_instruction(c, after_instruction=call) @@ -187,4 +221,4 @@ class InvalidComErrorColl(ErrorGenerator): if type_1 not in correct_types_checked: correct_types_checked.add(type_1) - yield get_correct_case(type_1, 1,func_to_use, comm) + yield get_correct_case(type_1, 1, func_to_use, comm) -- GitLab