From cb161b24dba64a26b31c97c3bb2fa7103a152a0b Mon Sep 17 00:00:00 2001
From: Tim Jammer <tim.jammer@tu-darmstadt.de>
Date: Mon, 29 Apr 2024 14:28:00 +0200
Subject: [PATCH] fix usage of multiple type arg functions

---
 scripts/errors/coll/ParamMatchingType.py | 64 ++++++++++++++++++------
 1 file changed, 49 insertions(+), 15 deletions(-)

diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py
index 8a6feabab..9601ea314 100644
--- a/scripts/errors/coll/ParamMatchingType.py
+++ b/scripts/errors/coll/ParamMatchingType.py
@@ -24,9 +24,13 @@ def get_local_missmatch(type_1, type_2, func_to_use):
     type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
 
     for call in tm.get_instruction("MPICALL", return_list=True):
-        call.set_rank_executing(0)
-        call.set_arg("datatype", type_var_1)
+
         call.set_has_error()
+        if call.has_arg("recvtype"):
+            call.set_arg("recvtype", type_var_1)
+            call.set_arg("sendtype", type_var_1)
+        else:
+            call.set_arg("datatype", type_var_1)
         if call.has_arg("recvbuf"):
             call.set_arg("recvbuf", buf_name_2)
             call.set_arg("sendbuf", buf_name_2)
@@ -50,27 +54,43 @@ def get_global_missmatch(type_1, type_2, count_1, count_2, func_to_use, comm):
 
     for call in tm.get_instruction("MPICALL", return_list=True):
         call.set_rank_executing(0)
-        call.set_arg("datatype", type_var_1)
-        call.set_arg("count", count_1)
-        call.set_arg("comm", comm_var_name)
         call.set_has_error()
+        call.set_arg("comm", comm_var_name)
         if call.has_arg("recvbuf"):
             call.set_arg("recvbuf", buf_name_1)
             call.set_arg("sendbuf", buf_name_1)
         else:
             call.set_arg("buffer", buf_name_1)
+        if call.has_arg("recvtype"):
+            call.set_arg("recvtype", type_var_1)
+            call.set_arg("sendtype", type_var_1)
+        else:
+            call.set_arg("datatype", type_var_1)
+        if call.has_arg("recvcount"):
+            call.set_arg("recvcount", count_1)
+            call.set_arg("sendcount", count_1)
+        else:
+            call.set_arg("count", count_1)
 
         c = CorrectMPICallFactory.get(func_to_use)
         c.set_rank_executing('not0')
-        c.set_arg("datatype", type_var_2)
-        c.set_arg("count", count_2)
+        call.set_has_error()
         c.set_arg("comm", comm_var_name)
-        c.set_has_error()
         if c.has_arg("recvbuf"):
             c.set_arg("recvbuf", buf_name_2)
             c.set_arg("sendbuf", buf_name_2)
         else:
             c.set_arg("buffer", buf_name_2)
+        if c.has_arg("recvtype"):
+            c.set_arg("recvtype", type_var_2)
+            c.set_arg("sendtype", type_var_2)
+        else:
+            c.set_arg("datatype", type_var_2)
+        if c.has_arg("recvcount"):
+            c.set_arg("recvcount", count_2)
+            c.set_arg("sendcount", count_2)
+        else:
+            c.set_arg("count", count_2)
 
         tm.insert_instruction(c, after_instruction=call)
 
@@ -91,27 +111,41 @@ def get_correct_case(type_1, count_1, func_to_use, comm):
 
     for call in tm.get_instruction("MPICALL", return_list=True):
         call.set_rank_executing(0)
-        call.set_arg("datatype", type_var_1)
-        call.set_arg("count", count_1)
         call.set_arg("comm", comm_var_name)
-
         if call.has_arg("recvbuf"):
             call.set_arg("recvbuf", buf_name_1)
             call.set_arg("sendbuf", buf_name_1)
         else:
             call.set_arg("buffer", buf_name_1)
+        if call.has_arg("recvtype"):
+            call.set_arg("recvtype", type_var_1)
+            call.set_arg("sendtype", type_var_1)
+        else:
+            call.set_arg("datatype", type_var_1)
+        if call.has_arg("recvcount"):
+            call.set_arg("recvcount", count_1)
+            call.set_arg("sendcount", count_1)
+        else:
+            call.set_arg("count", count_1)
 
         c = CorrectMPICallFactory.get(func_to_use)
         c.set_rank_executing('not0')
-        c.set_arg("datatype", type_var_2)
-        c.set_arg("count", count_1)
         c.set_arg("comm", comm_var_name)
-        c.set_has_error()
         if c.has_arg("recvbuf"):
             c.set_arg("recvbuf", buf_name_2)
             c.set_arg("sendbuf", buf_name_2)
         else:
             c.set_arg("buffer", buf_name_2)
+        if c.has_arg("recvtype"):
+            c.set_arg("recvtype", type_var_2)
+            c.set_arg("sendtype", type_var_2)
+        else:
+            c.set_arg("datatype", type_var_2)
+        if c.has_arg("recvcount"):
+            c.set_arg("recvcount", count_1)
+            c.set_arg("sendcount", count_1)
+        else:
+            c.set_arg("count", count_1)
 
         tm.insert_instruction(c, after_instruction=call)
 
@@ -187,4 +221,4 @@ class InvalidComErrorColl(ErrorGenerator):
 
             if type_1 not in correct_types_checked:
                 correct_types_checked.add(type_1)
-                yield get_correct_case(type_1, 1,func_to_use, comm)
+                yield get_correct_case(type_1, 1, func_to_use, comm)
-- 
GitLab