From 8c67cc974dd8db4efc4a916e571740b50622d339 Mon Sep 17 00:00:00 2001
From: Tim Jammer <tim.jammer@tu-darmstadt.de>
Date: Fri, 26 Apr 2024 21:06:45 +0200
Subject: [PATCH] fix usage of different communicators

---
 scripts/errors/coll/ParamMatchingType.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py
index fca1b3ba8..4ddc1ed81 100644
--- a/scripts/errors/coll/ParamMatchingType.py
+++ b/scripts/errors/coll/ParamMatchingType.py
@@ -5,7 +5,7 @@ from scripts.Infrastructure.Variables import *
 from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
 from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
 from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \
-    predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type
+    predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm
 
 
 class InvalidComErrorColl(ErrorGenerator):
@@ -94,6 +94,12 @@ class InvalidComErrorColl(ErrorGenerator):
                                 continue
 
                         tm = get_collective_template(func_to_use)
+                        comm_var_name = "MPI_COMM_WORLD"
+                        if comm in self.comm_creators:
+                            comm_var_name = get_communicator(comm, tm)
+
+                        if comm in self.intercomms:
+                            comm_var_name = get_intercomm(comm, tm)
                         type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
 
                         tm.set_description("ParamMatching-Type-" + func_to_use,
@@ -103,6 +109,7 @@ class InvalidComErrorColl(ErrorGenerator):
                             call.set_rank_executing(0)
                             call.set_arg("datatype", type_var_1)
                             call.set_arg("count", 1)
+                            call.set_arg("comm", comm_var_name)
                             call.set_has_error()
                             if call.has_arg("recvbuf"):
                                 call.set_arg("recvbuf", buf_name_1)
@@ -114,6 +121,7 @@ class InvalidComErrorColl(ErrorGenerator):
                             c.set_rank_executing('not0')
                             c.set_arg("datatype", type_var_2)
                             c.set_arg("count", 1)
+                            c.set_arg("comm",comm_var_name)
                             c.set_has_error()
                             if c.has_arg("recvbuf"):
                                 c.set_arg("recvbuf", buf_name_2)
@@ -133,6 +141,12 @@ class InvalidComErrorColl(ErrorGenerator):
                                                                                               type_2),
                                                                                           get_bytes_size_for_type(
                                                                                               type_1))
+                        comm_var_name = "MPI_COMM_WORLD"
+                        if comm in self.comm_creators:
+                            comm_var_name = get_communicator(comm, tm)
+
+                        if comm in self.intercomms:
+                            comm_var_name = get_intercomm(comm, tm)
 
                         tm.set_description("ParamMatching-Type-" + func_to_use,
                                            "Wrong datatype matching: %s vs %s" % (type_1, type_2))
@@ -141,6 +155,7 @@ class InvalidComErrorColl(ErrorGenerator):
                             call.set_rank_executing(0)
                             call.set_arg("datatype", type_var_1)
                             call.set_arg("count", get_bytes_size_for_type(type_2))
+                            call.set_arg("comm", comm_var_name)
                             call.set_has_error()
                             if call.has_arg("recvbuf"):
                                 call.set_arg("recvbuf", buf_name_1)
@@ -151,7 +166,8 @@ class InvalidComErrorColl(ErrorGenerator):
                             c = CorrectMPICallFactory.get(func_to_use)
                             c.set_rank_executing('not0')
                             c.set_arg("datatype", type_var_2)
-                            call.set_arg("count", get_bytes_size_for_type(type_1))
+                            c.set_arg("count", get_bytes_size_for_type(type_1))
+                            c.set_arg("comm", comm_var_name)
                             c.set_has_error()
                             if c.has_arg("recvbuf"):
                                 c.set_arg("recvbuf", buf_name_2)
-- 
GitLab