Skip to content
Snippets Groups Projects
Commit 983c7710 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

added testcase for intercomm

parent 1c00014a
Branches
No related tags found
1 merge request!6More Work on infrastructure IV
...@@ -214,3 +214,27 @@ def get_communicator(comm_create_func, name): ...@@ -214,3 +214,27 @@ def get_communicator(comm_create_func, name):
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]: if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_operation(cmpicf.mpi_group_free()) b.register_operation(cmpicf.mpi_group_free())
return b return b
def get_intercomm(comm_create_func, name):
"""
:param comm_create_func: the function used to create the new communicator
:param name: name of the communicator variable
:return: instruction block with name "comm_create" that will initialize the communicator with the given initialization function,
does include the allocation of a stack variable with the procided name for the communicator also includes a "intercomm_base_comm, that is not freed
"""
assert comm_create_func in ["mpi_intercomm_create"]
assert name != "intercomm_base_comm"
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation(MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_operation("MPI_Comm " + name + ";")
b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "1", CorrectParameterFactory().get("tag"),"&" + name), kind=0)
b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "0", CorrectParameterFactory().get("tag"),"&" + name),
kind='not0')
#b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
...@@ -5,7 +5,7 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock ...@@ -5,7 +5,7 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_communicator from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_communicator, get_intercomm
from itertools import chain from itertools import chain
...@@ -18,6 +18,7 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -18,6 +18,7 @@ class InvalidCommErrorP2P(ErrorGenerator):
missmatching_comms = ["MPI_COMM_SELF", "mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup", missmatching_comms = ["MPI_COMM_SELF", "mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
"mpi_comm_idup_with_info", "mpi_comm_create","mpi_comm_create_group","mpi_comm_split","mpi_comm_split_type","mpi_comm_create_from_group" "mpi_comm_idup_with_info", "mpi_comm_create","mpi_comm_create_group","mpi_comm_split","mpi_comm_split_type","mpi_comm_create_from_group"
] ]
intercomms = ["mpi_intercomm_create"]
# TODO test with: # TODO test with:
# 'MPI_Intercomm_create', # 'MPI_Intercomm_create',
...@@ -26,7 +27,7 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -26,7 +27,7 @@ class InvalidCommErrorP2P(ErrorGenerator):
# as extended testcases # as extended testcases
comms_to_check = invalid_comm + missmatching_comms comms_to_check = invalid_comm + missmatching_comms +intercomms
functions_to_check = ["mpi_send", functions_to_check = ["mpi_send",
"mpi_recv", "mpi_irecv", "mpi_recv", "mpi_irecv",
...@@ -72,6 +73,9 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -72,6 +73,9 @@ class InvalidCommErrorP2P(ErrorGenerator):
if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF": if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF":
b = get_communicator(comm_to_use,comm_to_use) b = get_communicator(comm_to_use,comm_to_use)
tm.insert_block(b, block_name="alloc") tm.insert_block(b, block_name="alloc")
if comm_to_use in self.intercomms:
b = get_intercomm(comm_to_use,comm_to_use)
tm.insert_block(b, block_name="alloc")
error_string = "ParamMatching" error_string = "ParamMatching"
if comm_to_use in self.invalid_comm: if comm_to_use in self.invalid_comm:
...@@ -98,9 +102,18 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -98,9 +102,18 @@ class InvalidCommErrorP2P(ErrorGenerator):
# missmatch is between both # missmatch is between both
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error() tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error()
if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF": # an intercomm has only one rank (the otehr group)
if comm_to_use in self.intercomms:
if tm.get_block("MPICALL").get_operation(kind=0, index=0).has_arg("source"):
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_arg("source","0")
if tm.get_block("MPICALL").get_operation(kind=1, index=0).has_arg("source"):
tm.get_block("MPICALL").get_operation(kind=1, index=0).set_arg("source", "0")
if comm_to_use in self.missmatching_comms + self.intercomms and comm_to_use != "MPI_COMM_SELF":
b = InstructionBlock("comm_free") b = InstructionBlock("comm_free")
b.register_operation(MPICallFactory().mpi_comm_free("&" + comm_to_use)) b.register_operation(MPICallFactory().mpi_comm_free("&" + comm_to_use))
if comm_to_use in self.intercomms:
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
tm.register_instruction_block(b) tm.register_instruction_block(b)
return tm return tm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment