From 8c67cc974dd8db4efc4a916e571740b50622d339 Mon Sep 17 00:00:00 2001 From: Tim Jammer <tim.jammer@tu-darmstadt.de> Date: Fri, 26 Apr 2024 21:06:45 +0200 Subject: [PATCH] fix usage of different communicators --- scripts/errors/coll/ParamMatchingType.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py index fca1b3ba8..4ddc1ed81 100644 --- a/scripts/errors/coll/ParamMatchingType.py +++ b/scripts/errors/coll/ParamMatchingType.py @@ -5,7 +5,7 @@ from scripts.Infrastructure.Variables import * from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \ - predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type + predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm class InvalidComErrorColl(ErrorGenerator): @@ -94,6 +94,12 @@ class InvalidComErrorColl(ErrorGenerator): continue tm = get_collective_template(func_to_use) + comm_var_name = "MPI_COMM_WORLD" + if comm in self.comm_creators: + comm_var_name = get_communicator(comm, tm) + + if comm in self.intercomms: + comm_var_name = get_intercomm(comm, tm) type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1) tm.set_description("ParamMatching-Type-" + func_to_use, @@ -103,6 +109,7 @@ class InvalidComErrorColl(ErrorGenerator): call.set_rank_executing(0) call.set_arg("datatype", type_var_1) call.set_arg("count", 1) + call.set_arg("comm", comm_var_name) call.set_has_error() if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) @@ -114,6 +121,7 @@ class InvalidComErrorColl(ErrorGenerator): c.set_rank_executing('not0') c.set_arg("datatype", type_var_2) c.set_arg("count", 1) + c.set_arg("comm",comm_var_name) c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) @@ -133,6 +141,12 @@ class InvalidComErrorColl(ErrorGenerator): type_2), get_bytes_size_for_type( type_1)) + comm_var_name = "MPI_COMM_WORLD" + if comm in self.comm_creators: + comm_var_name = get_communicator(comm, tm) + + if comm in self.intercomms: + comm_var_name = get_intercomm(comm, tm) tm.set_description("ParamMatching-Type-" + func_to_use, "Wrong datatype matching: %s vs %s" % (type_1, type_2)) @@ -141,6 +155,7 @@ class InvalidComErrorColl(ErrorGenerator): call.set_rank_executing(0) call.set_arg("datatype", type_var_1) call.set_arg("count", get_bytes_size_for_type(type_2)) + call.set_arg("comm", comm_var_name) call.set_has_error() if call.has_arg("recvbuf"): call.set_arg("recvbuf", buf_name_1) @@ -151,7 +166,8 @@ class InvalidComErrorColl(ErrorGenerator): c = CorrectMPICallFactory.get(func_to_use) c.set_rank_executing('not0') c.set_arg("datatype", type_var_2) - call.set_arg("count", get_bytes_size_for_type(type_1)) + c.set_arg("count", get_bytes_size_for_type(type_1)) + c.set_arg("comm", comm_var_name) c.set_has_error() if c.has_arg("recvbuf"): c.set_arg("recvbuf", buf_name_2) -- GitLab