Skip to content
Snippets Groups Projects
Verified Commit 17730c7d authored by Simon Schwitanski's avatar Simon Schwitanski :slight_smile:
Browse files

Add templates for window allocations and RMA calls

parent f994cc14
No related branches found
No related tags found
2 merge requests!9Infrastructure: Type Hints, Instruction class and lists of instructions,!8Draft: RMA
...@@ -184,6 +184,74 @@ def get_collective_template(collective_func, seperate=True): ...@@ -184,6 +184,74 @@ def get_collective_template(collective_func, seperate=True):
return tm return tm
def get_allocated_window(win_alloc_func, name, bufname, ctype, num_elements):
"""
Constructs a window allocation using Win_allocate or Win_create.
:param win_alloc_func: The window allocation to use (mpi_win_allocate or mpi_win_create).
:param name: name of the window
"""
b = InstructionBlock("win_allocate")
# declare window
b.register_operation(f"MPI_Win {name};")
# extract C data type and window buffer name
# dtype = CorrectParameterFactory().dtype[0]
# winbuf_name = CorrectParameterFactory().winbuf_var_name
# winbuf_size = CorrectParameterFactory().buf_size_bytes
win_allocate_call = None
if win_alloc_func == "mpi_win_allocate":
# MPI allocate, only declaration required
b.register_operation(f"{ctype}* {bufname};")
win_allocate_call = CorrectMPICallFactory().mpi_win_allocate()
win_allocate_call.set_arg("baseptr", "&" + bufname)
elif win_alloc_func == "mpi_win_create":
# allocate buffer for win_create
b.register_operation(AllocCall(ctype, num_elements, bufname))
win_allocate_call = CorrectMPICallFactory().mpi_win_create()
win_allocate_call.set_arg("base", bufname)
else:
assert False
# set common parameters for both calls
win_allocate_call.set_arg("win", "&" + name)
buf_size_bytes = num_elements + "*sizeof(" + ctype + ")"
win_allocate_call.set_arg("size", buf_size_bytes)
win_allocate_call.set_arg("disp_unit", f"sizeof({ctype})")
b.register_operation(win_allocate_call)
return b
def get_rma_call(rma_func, rank):
b = InstructionBlock(rma_func.replace('mpi_',''))
cf = CorrectParameterFactory()
cfmpi = CorrectMPICallFactory()
# request-based RMA call, add request
if rma_func.startswith("mpi_r"):
b.register_operation(f"MPI_Request " + cf.get("request")[1:] + ";", kind=rank)
# some RMA ops require result_addr
if rma_func in ["mpi_get_accumulate", "mpi_rget_accumulate", "mpi_fetch_and_op", "mpi_compare_and_swap"]:
b.register_operation(AllocCall(cf.dtype[0], cf.buf_size, cf.get("result_addr")), kind=rank)
# some RMA ops require compare_addr
if rma_func in ["mpi_fetch_and_op", "mpi_compare_and_swap"]:
b.register_operation(AllocCall(cf.dtype[0], cf.buf_size, cf.get("compare_addr")), kind=rank)
b.register_operation(getattr(cfmpi, rma_func)(), kind=rank)
return b
def get_communicator(comm_create_func, name): def get_communicator(comm_create_func, name):
""" """
:param comm_create_func: teh function used to create the new communicator :param comm_create_func: teh function used to create the new communicator
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment