Skip to content
Snippets Groups Projects
Commit b9a5ebda authored by Jammer, Tim's avatar Jammer, Tim
Browse files

fix Communicator cases

parent 6d6f9037
Branches
No related tags found
No related merge requests found
......@@ -402,6 +402,7 @@ def get_communicator(comm_create_func: str, tm: TemplateManager, before_idx: int
"mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
"mpi_comm_split_type", "mpi_comm_create_from_group"]
newcomm = tm.add_stack_variable("MPI_Comm")
instr_list=[]
if comm_create_func.startswith("mpi_comm_i"):
req_name = tm.add_stack_variable("MPI_Request")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
......@@ -409,7 +410,8 @@ def get_communicator(comm_create_func: str, tm: TemplateManager, before_idx: int
group = CorrectMPICallFactory.mpi_comm_group()
group.set_identifier(identifier)
group.set_arg("group", "&" + group_variable)
tm.insert_instruction(group, before_instruction=before_idx)
instr_list.append(group)
call = CorrectMPICallFactory.get(comm_create_func)
call.set_arg("newcomm", "&" + newcomm)
......@@ -418,16 +420,17 @@ def get_communicator(comm_create_func: str, tm: TemplateManager, before_idx: int
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
call.set_arg("group", group_variable) # not &group
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
if comm_create_func.startswith("mpi_comm_i"):
wait = MPICallFactory.mpi_wait("&" + req_name, "MPI_STATUS_IGNORE")
wait.set_identifier(identifier)
tm.insert_instruction(wait, before_instruction=before_idx)
instr_list.append(wait)
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
group_free = CorrectMPICallFactory.mpi_group_free()
group_free.set_arg("group", "&" + group_variable)
group_free.set_identifier(identifier)
tm.insert_instruction(group_free, before_instruction=before_idx)
instr_list.append(group_free)
tm.insert_instruction(instr_list,before_instruction=before_idx)
return newcomm
......@@ -438,19 +441,21 @@ def get_intercomm(comm_create_func: str, tm: TemplateManager, before_idx: int =
"""
assert comm_create_func in ["mpi_intercomm_create", "mpi_intercomm_create_from_groups", "mpi_intercomm_merge"]
instr_list=[]
if comm_create_func == "mpi_intercomm_create":
base_comm = tm.add_stack_variable("MPI_Comm")
intercomm = tm.add_stack_variable("MPI_Comm")
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&" + base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_intercomm_create(base_comm, "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&" + intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_comm_free("&" + base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
tm.insert_instruction(instr_list, before_instruction=before_idx)
return intercomm
if comm_create_func == "mpi_intercomm_create_from_groups":
intercomm = tm.add_stack_variable("MPI_Comm")
......@@ -459,24 +464,24 @@ def get_intercomm(comm_create_func: str, tm: TemplateManager, before_idx: int =
odd_group = tm.add_stack_variable("MPI_Comm")
call = MPICallFactory.mpi_comm_group("MPI_COMM_WORLD", "&" + world_group)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_comm_group("intercomm_base_comm", "&intercomm_base_comm_group")
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
inst = Instruction("int[3] triplet;"
"triplet[0] =0;"
"triplet[1] =size;"
"triplet[2] =2;", identifier=identifier)
tm.insert_instruction(inst, before_instruction=before_idx)
instr_list.append(inst)
call = MPICallFactory.mpi_group_incl(world_group, "1", "&triplet", even_group)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
inst = Instruction("triplet[0] =1;", identifier=identifier)
tm.insert_instruction(inst, before_instruction=before_idx)
instr_list.append(inst)
call = MPICallFactory.mpi_group_incl(world_group, "1", "&triplet", odd_group)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_intercomm_create_from_groups("(rank % 2 ? " + even_group + ":" + odd_group + ")",
"0",
"(!(rank % 2) ? " + even_group + ":" + odd_group + ")",
......@@ -486,7 +491,8 @@ def get_intercomm(comm_create_func: str, tm: TemplateManager, before_idx: int =
CorrectParameterFactory().get("errhandler"),
"&" + intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
tm.insert_instruction(instr_list, before_instruction=before_idx)
return intercomm
if comm_create_func == "mpi_intercomm_merge":
......@@ -495,20 +501,21 @@ def get_intercomm(comm_create_func: str, tm: TemplateManager, before_idx: int =
result_comm = tm.add_stack_variable("MPI_Comm")
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&" + intercomm_base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_intercomm_create(intercomm_base_comm, "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&" + to_merge_intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_intercomm_merge(to_merge_intercomm, "rank %2", "&" + result_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_comm_free("&" + to_merge_intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
call = MPICallFactory.mpi_comm_free("&" + intercomm_base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
instr_list.append(call)
tm.insert_instruction(instr_list, before_instruction=before_idx)
return result_comm
return None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment