Skip to content
Snippets Groups Projects
Commit 293fccff authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Refactoring: moved the code generating a communicator

parent 51e6dd02
No related branches found
No related tags found
1 merge request!6More Work on infrastructure IV
...@@ -187,3 +187,32 @@ def get_collective_template(collective_func, seperate=True): ...@@ -187,3 +187,32 @@ def get_collective_template(collective_func, seperate=True):
tm.register_instruction_block(free_block) tm.register_instruction_block(free_block)
return tm return tm
def get_communicator(comm_create_func, name):
"""
:param comm_create_func: teh function used to create the new communicator
:param name: name of the communicator variable
:return: instruction block with name "comm_create" that will initialize the communicator with the given initialization function, does include the allocation of a stack variable with the procided name for the communicator
"""
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm " + name + ";")
if comm_create_func.startswith("mpi_comm_i"):
b.register_operation("MPI_Request comm_create_req;")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_operation("MPI_Group group;")
b.register_operation(CorrectMPICallFactory().mpi_comm_group())
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, comm_create_func)
call = call_creator_function()
call.set_arg("newcomm", "&" + name)
if comm_create_func.startswith("mpi_comm_i"):
call.set_arg("request", "&comm_create_req")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
call.set_arg("group", "group") # not &group
b.register_operation(call)
if comm_create_func.startswith("mpi_comm_i"):
b.register_operation(MPICallFactory().mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE"))
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_operation(cmpicf.mpi_group_free())
return b
...@@ -5,7 +5,7 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock ...@@ -5,7 +5,7 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_communicator
from itertools import chain from itertools import chain
...@@ -70,27 +70,7 @@ class InvalidCommErrorP2P(ErrorGenerator): ...@@ -70,27 +70,7 @@ class InvalidCommErrorP2P(ErrorGenerator):
tm = get_send_recv_template(send_func, recv_func) tm = get_send_recv_template(send_func, recv_func)
if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF": if comm_to_use in self.missmatching_comms and comm_to_use != "MPI_COMM_SELF":
b = InstructionBlock("comm_create") b = get_communicator(comm_to_use,comm_to_use)
b.register_operation("MPI_Comm " + comm_to_use + ";")
if comm_to_use.startswith("mpi_comm_i"):
b.register_operation("MPI_Request comm_create_req;")
if comm_to_use in ["mpi_comm_create","mpi_comm_create_group"]:
b.register_operation("MPI_Group group;")
b.register_operation(CorrectMPICallFactory().mpi_comm_group())
cmpicf = CorrectMPICallFactory()
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, comm_to_use)
call = call_creator_function()
call.set_arg("newcomm", "&" + comm_to_use)
if comm_to_use.startswith("mpi_comm_i"):
call.set_arg("request", "&comm_create_req")
if comm_to_use in ["mpi_comm_create","mpi_comm_create_group"]:
call.set_arg("group", "group") # not &group
b.register_operation(call)
if comm_to_use.startswith("mpi_comm_i"):
b.register_operation(MPICallFactory().mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE"))
if comm_to_use in ["mpi_comm_create","mpi_comm_create_group"]:
b.register_operation(cmpicf.mpi_group_free())
tm.insert_block(b, block_name="alloc") tm.insert_block(b, block_name="alloc")
error_string = "ParamMatching" error_string = "ParamMatching"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment