Skip to content
Snippets Groups Projects
Commit 14f28f6b authored by Jammer, Tim's avatar Jammer, Tim
Browse files

added intercomm_merge case

parent 221bf89f
No related branches found
No related tags found
1 merge request!6More Work on infrastructure IV
...@@ -220,21 +220,49 @@ def get_intercomm(comm_create_func, name): ...@@ -220,21 +220,49 @@ def get_intercomm(comm_create_func, name):
""" """
:param comm_create_func: the function used to create the new communicator :param comm_create_func: the function used to create the new communicator
:param name: name of the communicator variable :param name: name of the communicator variable
:return: instruction block with name "comm_create" that will initialize the communicator with the given initialization function, :return Tuple InstructionBlock, InstructionBlock: instruction block with name "comm_create" that will initialize the communicator with the given initialization function,
does include the allocation of a stack variable with the procided name for the communicator also includes a "intercomm_base_comm, that is not freed does include the allocation of a stack variable with the provided name for the communicator
may also contain other stack variables as needed
and the block containing all the necessary frees
""" """
assert comm_create_func in ["mpi_intercomm_create"] assert comm_create_func in ["mpi_intercomm_create", "mpi_intercom_create_from_group","mpi_intercomm_merge"]
assert name != "intercomm_base_comm" assert name != "intercomm_base_comm"
if comm_create_func == "mpi_intercomm_create":
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation(MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_operation("MPI_Comm " + name + ";")
b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)", CorrectParameterFactory().get("tag"),"&" + name))
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
if comm_create_func == "mpi_intercom_create_from_group":
assert False and "NOT IMPLEMENTED"
b = InstructionBlock("comm_create") b = InstructionBlock("comm_create")
b_free = InstructionBlock("comm_free")
b.register_operation("MPI_Comm intercomm_base_comm;") b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation(MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm")) b.register_operation(MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_operation("MPI_Group intercomm_base_comm_group;")
b.register_operation(
MPICallFactory().mpi_comm_group("intercomm_base_comm", "&intercomm_base_comm_group"))
b.register_operation("MPI_Comm " + name + ";") b.register_operation("MPI_Comm " + name + ";")
b.register_operation( b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "1", CorrectParameterFactory().get("tag"),"&" + name), kind=0) MPICallFactory().mpi_intercomm_create("mpi_intercom_create_from_group", "0", "MPI_COMM_WORLD", "1", CorrectParameterFactory().get("tag"),"&" + name),)
return b
if comm_create_func== "mpi_intercomm_merge":
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation("MPI_Comm to_merge_intercomm_comm;")
b.register_operation( b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "0", CorrectParameterFactory().get("tag"),"&" + name), MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
kind='not0') b.register_operation("MPI_Comm " + name + ";")
#b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm")) b.register_operation(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&to_merge_intercomm_comm"))
b.register_operation(MPICallFactory().mpi_intercomm_merge("to_merge_intercomm_comm","rank %2","&"+name))
b.register_operation(MPICallFactory().mpi_comm_free("&to_merge_intercomm_comm"))
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b return b
return None
...@@ -18,10 +18,9 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -18,10 +18,9 @@ class InvalidCommErrorP2P(ErrorGenerator):
missmatching_comms = ["MPI_COMM_SELF", "mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup", missmatching_comms = ["MPI_COMM_SELF", "mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
"mpi_comm_idup_with_info", "mpi_comm_create","mpi_comm_create_group","mpi_comm_split","mpi_comm_split_type","mpi_comm_create_from_group" "mpi_comm_idup_with_info", "mpi_comm_create","mpi_comm_create_group","mpi_comm_split","mpi_comm_split_type","mpi_comm_create_from_group"
] ]
intercomms = ["mpi_intercomm_create"] intercomms = ["mpi_intercomm_create","mpi_intercomm_merge"]
# TODO test with: # TODO test with:
# 'MPI_Intercomm_create',
# 'MPI_Intercomm_create_from_groups', # 'MPI_Intercomm_create_from_groups',
# 'MPI_Intercomm_merge' # 'MPI_Intercomm_merge'
...@@ -103,7 +102,8 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -103,7 +102,8 @@ class InvalidCommErrorP2P(ErrorGenerator):
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error() tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error()
# an intercomm has only one rank (the otehr group) # an intercomm has only one rank (the otehr group)
if comm_to_use in self.intercomms: if comm_to_use in self.intercomms and not comm_to_use=="mpi_intercomm_merge":
#intercomm merge results in same comm again
if tm.get_block("MPICALL").get_operation(kind=0, index=0).has_arg("source"): if tm.get_block("MPICALL").get_operation(kind=0, index=0).has_arg("source"):
tm.get_block("MPICALL").get_operation(kind=0, index=0).set_arg("source","0") tm.get_block("MPICALL").get_operation(kind=0, index=0).set_arg("source","0")
if tm.get_block("MPICALL").get_operation(kind=1, index=0).has_arg("source"): if tm.get_block("MPICALL").get_operation(kind=1, index=0).has_arg("source"):
...@@ -112,8 +112,6 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -112,8 +112,6 @@ class InvalidCommErrorP2P(ErrorGenerator):
if comm_to_use in self.missmatching_comms + self.intercomms and comm_to_use != "MPI_COMM_SELF": if comm_to_use in self.missmatching_comms + self.intercomms and comm_to_use != "MPI_COMM_SELF":
b = InstructionBlock("comm_free") b = InstructionBlock("comm_free")
b.register_operation(MPICallFactory().mpi_comm_free("&" + comm_to_use)) b.register_operation(MPICallFactory().mpi_comm_free("&" + comm_to_use))
if comm_to_use in self.intercomms:
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
tm.register_instruction_block(b) tm.register_instruction_block(b)
return tm return tm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment