Skip to content
Snippets Groups Projects

P2P

Open
Jammer, Timrequested to merge
p2p into main
1 file
+ 22
16
Compare changes
  • Side-by-side
  • Inline
@@ -112,11 +112,12 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing
if recv_func in probe_pairs:
if recv_func in [["mpi_improbe", "mpi_mrecv"],
["mpi_improbe", "mpi_imrecv"]]:
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=0,identifier="PROBELOOP"), before_instruction=r)
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=0, identifier="PROBELOOP"),
before_instruction=r)
# insertion before the improbe call
tm.register_instruction("}", rank_to_execute=0,identifier="PROBELOOP") # end while
tm.register_instruction("}", rank_to_execute=0, identifier="PROBELOOP") # end while
# the matched recv
tm.register_instruction(CorrectMPICallFactory.get(recv_func[1]), rank_to_execute=0,identifier="MATCHEDRECEIVE")
tm.register_instruction(CorrectMPICallFactory.get(recv_func[1]), rank_to_execute=0, identifier="MATCHEDRECEIVE")
if send_func in persistent_send_funcs:
tm.register_instruction(CorrectMPICallFactory.mpi_start(), rank_to_execute=1, identifier="START")
@@ -192,8 +193,11 @@ def replace_wait(wait_call, tm, wait_func_to_use):
) # end while
tm.remove_instruction(wait_call)
return
status_to_use = wait_call.get_arg("status")
if status_to_use == "MPI_STATUS_IGNORE":
status_to_use = "MPI_STATUSES_IGNORE"
if wait_func_to_use == "mpi_waitall":
test_call = MPICallFactory.mpi_waitall("1", wait_call.get_arg("request"), "&" + wait_call.get_arg("status"))
test_call = MPICallFactory.mpi_waitall("1", wait_call.get_arg("request"), status_to_use)
test_call.set_rank_executing(wait_call.get_rank_executing())
test_call.set_identifier(wait_call.get_identifier())
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
@@ -202,19 +206,20 @@ def replace_wait(wait_call, tm, wait_func_to_use):
if wait_func_to_use == "mpi_testall":
flag_name = tm.add_stack_variable("int")
test_call = MPICallFactory.mpi_testall("1", wait_call.get_arg("request"), "&" + flag_name,
wait_call.get_arg("status"))
status_to_use)
test_call.set_rank_executing(wait_call.get_rank_executing())
test_call.set_identifier(wait_call.get_identifier())
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=wait_call.get_rank_executing()),
before_instruction=wait_call)
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()), before_instruction=wait_call) # end while
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()),
before_instruction=wait_call) # end while
tm.remove_instruction(wait_call)
return
if wait_func_to_use == "mpi_waitany":
idx_name = tm.add_stack_variable("int")
test_call = MPICallFactory.mpi_waitany("1", wait_call.get_arg("request"), "&" + idx_name,
"&" + wait_call.get_arg("status"))
wait_call.get_arg("status"))
test_call.set_rank_executing(wait_call.get_rank_executing())
test_call.set_identifier(wait_call.get_identifier())
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
@@ -230,14 +235,15 @@ def replace_wait(wait_call, tm, wait_func_to_use):
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=wait_call.get_rank_executing()),
before_instruction=wait_call)
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()), before_instruction=wait_call) # end while
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()),
before_instruction=wait_call) # end while
tm.remove_instruction(wait_call)
return
if wait_func_to_use == "mpi_waitsome":
idx_name = tm.add_stack_variable("int")
idx_array = tm.add_stack_variable("int")
test_call = MPICallFactory.mpi_waitsome("1", wait_call.get_arg("request"), "&" + idx_name,
"&" + idx_array, "&" + wait_call.get_arg("status"))
"&" + idx_array, status_to_use)
test_call.set_rank_executing(wait_call.get_rank_executing())
test_call.set_identifier(wait_call.get_identifier())
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
@@ -247,13 +253,14 @@ def replace_wait(wait_call, tm, wait_func_to_use):
flag_name = tm.add_stack_variable("int")
idx_array = tm.add_stack_variable("int")
test_call = MPICallFactory.mpi_testsome("1", wait_call.get_arg("request"), "&" + flag_name, "&" + idx_array,
wait_call.get_arg("status"))
status_to_use)
test_call.set_rank_executing(wait_call.get_rank_executing())
test_call.set_identifier(wait_call.get_identifier())
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=wait_call.get_rank_executing()),
before_instruction=wait_call)
tm.insert_instruction(test_call, before_instruction=wait_call) # insertion before the improbe call
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()), before_instruction=wait_call) # end while
tm.insert_instruction(Instruction("}", rank=wait_call.get_rank_executing()),
before_instruction=wait_call) # end while
tm.remove_instruction(wait_call)
return
assert False and "Not implemented"
@@ -463,24 +470,23 @@ def get_intercomm(comm_create_func: str, tm: TemplateManager, before_idx: int =
return intercomm
if comm_create_func == "mpi_intercomm_merge":
intercomm_base_comm = tm.add_stack_variable("MPI_Comm")
to_merge_intercomm = tm.add_stack_variable("MPI_Comm")
result_comm = tm.add_stack_variable("MPI_Comm")
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&"+intercomm_base_comm)
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&" + intercomm_base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
call = MPICallFactory.mpi_intercomm_create(intercomm_base_comm, "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&"+to_merge_intercomm)
CorrectParameterFactory().get("tag"), "&" + to_merge_intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
call = MPICallFactory.mpi_intercomm_merge(to_merge_intercomm, "rank %2", "&" + result_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
call = MPICallFactory.mpi_comm_free("&"+to_merge_intercomm)
call = MPICallFactory.mpi_comm_free("&" + to_merge_intercomm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
call = MPICallFactory.mpi_comm_free("&"+intercomm_base_comm)
call = MPICallFactory.mpi_comm_free("&" + intercomm_base_comm)
call.set_identifier(identifier)
tm.insert_instruction(call, before_instruction=before_idx)
return result_comm
Loading