Skip to content
Snippets Groups Projects
Commit 8c67cc97 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

fix usage of different communicators

parent e76730f3
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ from scripts.Infrastructure.Variables import * ...@@ -5,7 +5,7 @@ from scripts.Infrastructure.Variables import *
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \ from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \
predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm
class InvalidComErrorColl(ErrorGenerator): class InvalidComErrorColl(ErrorGenerator):
...@@ -94,6 +94,12 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -94,6 +94,12 @@ class InvalidComErrorColl(ErrorGenerator):
continue continue
tm = get_collective_template(func_to_use) tm = get_collective_template(func_to_use)
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1) type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
tm.set_description("ParamMatching-Type-" + func_to_use, tm.set_description("ParamMatching-Type-" + func_to_use,
...@@ -103,6 +109,7 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -103,6 +109,7 @@ class InvalidComErrorColl(ErrorGenerator):
call.set_rank_executing(0) call.set_rank_executing(0)
call.set_arg("datatype", type_var_1) call.set_arg("datatype", type_var_1)
call.set_arg("count", 1) call.set_arg("count", 1)
call.set_arg("comm", comm_var_name)
call.set_has_error() call.set_has_error()
if call.has_arg("recvbuf"): if call.has_arg("recvbuf"):
call.set_arg("recvbuf", buf_name_1) call.set_arg("recvbuf", buf_name_1)
...@@ -114,6 +121,7 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -114,6 +121,7 @@ class InvalidComErrorColl(ErrorGenerator):
c.set_rank_executing('not0') c.set_rank_executing('not0')
c.set_arg("datatype", type_var_2) c.set_arg("datatype", type_var_2)
c.set_arg("count", 1) c.set_arg("count", 1)
c.set_arg("comm",comm_var_name)
c.set_has_error() c.set_has_error()
if c.has_arg("recvbuf"): if c.has_arg("recvbuf"):
c.set_arg("recvbuf", buf_name_2) c.set_arg("recvbuf", buf_name_2)
...@@ -133,6 +141,12 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -133,6 +141,12 @@ class InvalidComErrorColl(ErrorGenerator):
type_2), type_2),
get_bytes_size_for_type( get_bytes_size_for_type(
type_1)) type_1))
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
comm_var_name = get_intercomm(comm, tm)
tm.set_description("ParamMatching-Type-" + func_to_use, tm.set_description("ParamMatching-Type-" + func_to_use,
"Wrong datatype matching: %s vs %s" % (type_1, type_2)) "Wrong datatype matching: %s vs %s" % (type_1, type_2))
...@@ -141,6 +155,7 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -141,6 +155,7 @@ class InvalidComErrorColl(ErrorGenerator):
call.set_rank_executing(0) call.set_rank_executing(0)
call.set_arg("datatype", type_var_1) call.set_arg("datatype", type_var_1)
call.set_arg("count", get_bytes_size_for_type(type_2)) call.set_arg("count", get_bytes_size_for_type(type_2))
call.set_arg("comm", comm_var_name)
call.set_has_error() call.set_has_error()
if call.has_arg("recvbuf"): if call.has_arg("recvbuf"):
call.set_arg("recvbuf", buf_name_1) call.set_arg("recvbuf", buf_name_1)
...@@ -151,7 +166,8 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -151,7 +166,8 @@ class InvalidComErrorColl(ErrorGenerator):
c = CorrectMPICallFactory.get(func_to_use) c = CorrectMPICallFactory.get(func_to_use)
c.set_rank_executing('not0') c.set_rank_executing('not0')
c.set_arg("datatype", type_var_2) c.set_arg("datatype", type_var_2)
call.set_arg("count", get_bytes_size_for_type(type_1)) c.set_arg("count", get_bytes_size_for_type(type_1))
c.set_arg("comm", comm_var_name)
c.set_has_error() c.set_has_error()
if c.has_arg("recvbuf"): if c.has_arg("recvbuf"):
c.set_arg("recvbuf", buf_name_2) c.set_arg("recvbuf", buf_name_2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment