Skip to content
Snippets Groups Projects
Commit eb6b0e3c authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Renamed operation with instruction (#6)

parent 636ff42f
No related branches found
No related tags found
1 merge request!9Infrastructure: Type Hints, Instruction class and lists of instructions
......@@ -18,7 +18,7 @@ class Invalid_negative_rank_error:
# include the buffer allocation in the template (all ranks execute it)
alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(correct_params.get_buffer_alloc())
alloc_block.register_instruction(correct_params.get_buffer_alloc())
tm.register_instruction_block(alloc_block)
send = MPI_Call_Factory().mpi_send(
......@@ -33,9 +33,9 @@ class Invalid_negative_rank_error:
b = InstructionBlock()
# only rank 0 executes the send
b.register_operation(send, 0)
b.register_instruction(send, 0)
# only rank 1 executes the recv
b.register_operation(get_matching_recv(send), 1)
b.register_instruction(get_matching_recv(send), 1)
tm.register_instruction_block(b)
return tm
......
......@@ -33,7 +33,7 @@ class InstructionBlock:
assert not isinstance(name, int)
self.name = name
def register_operation(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all'):
def register_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all'):
"""
Registers an operation based on rank.
......@@ -106,7 +106,7 @@ class InstructionBlock:
return result_str
def has_operation(self, kind: int | str = 'all', index: int = 0) -> bool:
def has_instruction(self, kind: int | str = 'all', index: int = 0) -> bool:
"""
Checks if the Block has an operation with the given index and kind
Parameters:
......@@ -121,7 +121,7 @@ class InstructionBlock:
except (KeyError, IndexError) as e:
return False
def get_operation(self, kind: int | str = 'all', index: str | int = 0) -> Instruction | typing.List[Instruction]:
def get_instruction(self, kind: int | str = 'all', index: str | int = 0) -> Instruction | typing.List[Instruction]:
"""
Retrieve the registered operation. Will raise IndexError if not present
Parameters:
......@@ -136,7 +136,7 @@ class InstructionBlock:
as_int = int(index) # will Raise ValueError if not integer
return self.operations[kind][as_int]
def replace_operation(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
def replace_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
index: str | int = 0):
"""
Replace the registered operation. Will raise IndexError if not present
......@@ -164,7 +164,7 @@ class InstructionBlock:
raise IndexError("Operation Not Found")
self.operations[kind][as_int] = op
def insert_operation(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
def insert_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
before_index: int = 0):
"""
Inserts an operation before the specified one. Will raise IndexError if not present
......@@ -185,7 +185,7 @@ class InstructionBlock:
else:
self.operations[kind].insert(before_index, op)
def remove_operation(self, kind: str | int = 'all', index: str | int = 0):
def remove_instruction(self, kind: str | int = 'all', index: str | int = 0):
"""
Removes the registered operation. Will raise IndexError if not present
Parameters:
......
#! /usr/bin/python3
from __future__ import annotations
import typing
from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock
......@@ -18,7 +22,7 @@ def get_default_template(mpi_func):
pass
def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing.Tuple[str, str] = "mpi_irecv"):
"""
Constructs a default template for the given MPI send/recv function pair
Returns:
......@@ -27,50 +31,75 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
"""
# currently supported:
assert send_func in ["mpi_send", "mpi_ssend", "mpi_isend", "mpi_issend", "mpi_sendrecv", "mpi_rsend", "mpi_irsend",
"mpi_bsend", "mpi_ibsend", "mpi_sendrecv", "mpi_sendrecv_replace", "mpi_isendrecv",
"mpi_isendrecv_replace", "mpi_send_init", "mpi_ssend_init", "mpi_bsend_init", "mpi_rsend_init",
"mpi_psend_init"]
assert recv_func in ["mpi_recv", "mpi_irecv", "mpi_sendrecv", "mpi_sendrecv_replace", "mpi_isendrecv",
"mpi_isendrecv_replace", "mpi_recv_init", "mpi_precv_init"]
# mprobe and mrecv combinations allowed
probe_pairs = [["mpi_mprobe", "mpi_mrecv"], ["mpi_mprobe", "mpi_imrecv"], ["mpi_improbe", "mpi_mrecv"],
["mpi_improbe", "mpi_imrecv"]]
sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"]
persistent_send_funcs = ["mpi_send_init", "mpi_ssend_init", "mpi_bsend_init", "mpi_rsend_init", "mpi_psend_init"]
persistent_recv_funcs = ["mpi_recv_init", "mpi_precv_init"]
isend_funcs = ["mpi_isend", "mpi_issend", "mpi_irsend", "mpi_ibsend"]
irecv_funcs = ["mpi_irecv"]
assert (send_func in ["mpi_send", "mpi_ssend", "mpi_rsend", "mpi_bsend"]
+ sendrecv_funcs + isend_funcs + persistent_send_funcs)
assert recv_func in ["mpi_recv"] + sendrecv_funcs + irecv_funcs + persistent_recv_funcs + probe_pairs
if send_func in sendrecv_funcs or recv_func == sendrecv_funcs:
assert recv_func == send_func
# default template generation only supports if both use same mechanism
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
assert recv_func in ["mpi_irecv", "mpi_recv_init"] # else: deadlock
assert recv_func in ["mpi_irecv", "mpi_recv_init", "mpi_precv_init"] # else: deadlock
tm = TemplateManager()
cf = CorrectParameterFactory()
alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(cf.get_buffer_alloc())
alloc_block.register_instruction(cf.get_buffer_alloc())
if send_func in sendrecv_funcs:
# split send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_name("recv_buf")
alloc_block.register_operation(alloc)
tm.register_instruction_block(alloc_block)
alloc_block.register_instruction(alloc)
if recv_func in probe_pairs:
alloc_block.register_instruction("MPI_Message msg;")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
b = InstructionBlock("buf_attach")
buf_size = "sizeof(int)*10 + MPI_BSEND_OVERHEAD"
b.register_operation(AllocCall("char", buf_size, "mpi_buf"))
b.register_operation(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
alloc_block.register_instruction(AllocCall("char", buf_size, "mpi_buf"))
alloc_block.register_instruction(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
if (send_func in isend_funcs + persistent_send_funcs or
recv_func in persistent_recv_funcs + irecv_funcs + probe_pairs):
alloc_block.register_instruction("MPI_Request request;", 'all')
if recv_func in probe_pairs:
alloc_block.register_instruction("int flag=0;")
tm.register_instruction_block(alloc_block)
# end preparation of all local variables
# before the send/recv block
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
# sender needs to wait until recv is started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 1)
tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory()
send_func_creator_function = getattr(cmpicf, send_func)
s = send_func_creator_function()
recv_func_creator_function = getattr(cmpicf, recv_func)
r = recv_func_creator_function()
# get the send and recv block
recv_to_use = recv_func
if recv_func in probe_pairs:
recv_to_use = recv_func[0]
s = cmpicf.get(send_func)
r = cmpicf.get(recv_to_use)
if send_func in sendrecv_funcs:
# sending the second msg
......@@ -81,65 +110,68 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
if r.has_arg("recvbuf"):
r.set_arg("recvbuf", "recv_buf")
if (send_func.startswith("mpi_i") or recv_func.startswith("mpi_i")
or send_func in persistent_send_funcs or recv_func in persistent_recv_funcs):
b = InstructionBlock("MPI_REQUEST")
b.register_operation("MPI_Request request;", 'all')
tm.register_instruction_block(b)
b = InstructionBlock("MPICALL")
b.register_instruction(s, 1)
b.register_instruction(r, 0)
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
b.register_operation(CorrectMPICallFactory().mpi_barrier(), 1)
tm.register_instruction_block(b)
if recv_func in probe_pairs:
if recv_func in [["mpi_improbe", "mpi_mrecv"],
["mpi_improbe", "mpi_imrecv"]]:
b.insert_instruction("while (!flag){", 0, 0)
# insertion before the improbe call
b.register_instruction("}", 0) # end while
b.register_instruction(CorrectMPICallFactory().get(recv_func[1]), 0)
b = InstructionBlock("MPICALL")
b.register_operation(s, 1)
b.register_operation(r, 0)
if send_func in persistent_send_funcs:
b.register_operation(cmpicf.mpi_start(), 1)
b.register_instruction(cmpicf.mpi_start(), 1)
if send_func == "mpi_psend_init":
# the pready takes a Request, NOT a Request*
b.register_operation(MPICallFactory().mpi_pready("0", cf.get("request")[1:]), 1)
b.register_instruction(MPICallFactory().mpi_pready("0", cf.get("request")[1:]), 1)
if recv_func in persistent_recv_funcs:
b.register_operation(cmpicf.mpi_start(), 0) #
b.register_instruction(cmpicf.mpi_start(), 0) #
# parrived is not necessary
tm.register_instruction_block(b)
# after send/recv
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
b.register_operation(CorrectMPICallFactory().mpi_barrier(), 1)
# barrier indicating recv has started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 0)
tm.register_instruction_block(b)
if send_func.startswith("mpi_i") or send_func in persistent_send_funcs:
# wait for op to complete
if send_func in isend_funcs + persistent_send_funcs:
b = InstructionBlock("WAIT")
b.register_operation(CorrectMPICallFactory().mpi_wait(), 1)
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 1)
tm.register_instruction_block(b)
if recv_func.startswith("mpi_i") or recv_func in persistent_recv_funcs:
if recv_func in irecv_funcs + persistent_recv_funcs + [["mpi_mprobe", "mpi_imrecv"],
["mpi_improbe", "mpi_imrecv"]]:
b = InstructionBlock("WAIT")
b.register_operation(CorrectMPICallFactory().mpi_wait(), 0)
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 0)
tm.register_instruction_block(b)
# end MPI operation
# cleanup
free_block = InstructionBlock("buf_free")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
b = InstructionBlock("buf_detach")
b.register_operation("int freed_size;")
b.register_operation(MPICallFactory().mpi_buffer_detach("mpi_buf", "&freed_size"))
b.register_operation("free(mpi_buf);")
tm.register_instruction_block(b)
free_block = InstructionBlock("buf_detach")
free_block.register_instruction("int freed_size;")
free_block.register_instruction(MPICallFactory().mpi_buffer_detach("mpi_buf", "&freed_size"))
free_block.register_instruction("free(mpi_buf);")
free_block.register_instruction(cf.get_buffer_free())
free_block = InstructionBlock("buf_free")
free_block.register_operation(cf.get_buffer_free())
if send_func in sendrecv_funcs:
# split send and recv buf
b.register_operation("free(recv_buf);")
tm.register_instruction_block(free_block)
free_block.register_instruction("free(recv_buf);")
if send_func in persistent_send_funcs:
# split send and recv buf
b = InstructionBlock("req_free")
b.register_operation(cmpicf.mpi_request_free())
tm.register_instruction_block(b)
free_block.register_instruction(cmpicf.mpi_request_free())
tm.register_instruction_block(free_block)
return tm
......@@ -156,12 +188,12 @@ def get_collective_template(collective_func, seperate=True):
cf = CorrectParameterFactory()
alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(cf.get_buffer_alloc())
alloc_block.register_instruction(cf.get_buffer_alloc())
if False:
# split send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_name("recv_buf")
alloc_block.register_operation(alloc)
alloc_block.register_instruction(alloc)
tm.register_instruction_block(alloc_block)
cmpicf = CorrectMPICallFactory()
......@@ -170,15 +202,15 @@ def get_collective_template(collective_func, seperate=True):
b = InstructionBlock("MPICALL")
if seperate:
b.register_operation(c, 1)
b.register_operation(c, 0)
b.register_instruction(c, 1)
b.register_instruction(c, 0)
else:
b.register_operation(c, 'all')
b.register_instruction(c, 'all')
tm.register_instruction_block(b)
free_block = InstructionBlock("buf_free")
free_block.register_operation(cf.get_buffer_free())
free_block.register_instruction(cf.get_buffer_free())
tm.register_instruction_block(free_block)
return tm
......@@ -194,7 +226,7 @@ def get_allocated_window(win_alloc_func, name, bufname, ctype, num_elements):
b = InstructionBlock("win_allocate")
# declare window
b.register_operation(f"MPI_Win {name};")
b.register_instruction(f"MPI_Win {name};")
# extract C data type and window buffer name
# dtype = CorrectParameterFactory().dtype[0]
......@@ -204,12 +236,12 @@ def get_allocated_window(win_alloc_func, name, bufname, ctype, num_elements):
if win_alloc_func == "mpi_win_allocate":
# MPI allocate, only declaration required
b.register_operation(f"{ctype}* {bufname};")
b.register_instruction(f"{ctype}* {bufname};")
win_allocate_call = CorrectMPICallFactory().mpi_win_allocate()
win_allocate_call.set_arg("baseptr", "&" + bufname)
elif win_alloc_func == "mpi_win_create":
# allocate buffer for win_create
b.register_operation(AllocCall(ctype, num_elements, bufname))
b.register_instruction(AllocCall(ctype, num_elements, bufname))
win_allocate_call = CorrectMPICallFactory().mpi_win_create()
win_allocate_call.set_arg("base", bufname)
else:
......@@ -222,14 +254,12 @@ def get_allocated_window(win_alloc_func, name, bufname, ctype, num_elements):
win_allocate_call.set_arg("size", buf_size_bytes)
win_allocate_call.set_arg("disp_unit", f"sizeof({ctype})")
b.register_operation(win_allocate_call)
b.register_instruction(win_allocate_call)
return b
def get_rma_call(rma_func, rank):
b = InstructionBlock(rma_func.replace('mpi_', ''))
cf = CorrectParameterFactory()
......@@ -237,21 +267,20 @@ def get_rma_call(rma_func, rank):
# request-based RMA call, add request
if rma_func.startswith("mpi_r"):
b.register_operation(f"MPI_Request " + cf.get("request")[1:] + ";", kind=rank)
b.register_instruction(f"MPI_Request " + cf.get("request")[1:] + ";", kind=rank)
# some RMA ops require result_addr
if rma_func in ["mpi_get_accumulate", "mpi_rget_accumulate", "mpi_fetch_and_op", "mpi_compare_and_swap"]:
b.register_operation(AllocCall(cf.dtype[0], cf.buf_size, cf.get("result_addr")), kind=rank)
b.register_instruction(AllocCall(cf.dtype[0], cf.buf_size, cf.get("result_addr")), kind=rank)
# some RMA ops require compare_addr
if rma_func in ["mpi_fetch_and_op", "mpi_compare_and_swap"]:
b.register_operation(AllocCall(cf.dtype[0], cf.buf_size, cf.get("compare_addr")), kind=rank)
b.register_instruction(AllocCall(cf.dtype[0], cf.buf_size, cf.get("compare_addr")), kind=rank)
b.register_operation(getattr(cfmpi, rma_func)(), kind=rank)
b.register_instruction(getattr(cfmpi, rma_func)(), kind=rank)
return b
def get_communicator(comm_create_func, name):
"""
:param comm_create_func: the function used to create the new communicator
......@@ -262,12 +291,12 @@ def get_communicator(comm_create_func, name):
"mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
"mpi_comm_split_type", "mpi_comm_create_from_group"]
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm " + name + ";")
b.register_instruction("MPI_Comm " + name + ";")
if comm_create_func.startswith("mpi_comm_i"):
b.register_operation("MPI_Request comm_create_req;")
b.register_instruction("MPI_Request comm_create_req;")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_operation("MPI_Group group;")
b.register_operation(CorrectMPICallFactory().mpi_comm_group())
b.register_instruction("MPI_Group group;")
b.register_instruction(CorrectMPICallFactory().mpi_comm_group())
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, comm_create_func)
call = call_creator_function()
......@@ -276,11 +305,11 @@ def get_communicator(comm_create_func, name):
call.set_arg("request", "&comm_create_req")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
call.set_arg("group", "group") # not &group
b.register_operation(call)
b.register_instruction(call)
if comm_create_func.startswith("mpi_comm_i"):
b.register_operation(MPICallFactory().mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE"))
b.register_instruction(MPICallFactory().mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE"))
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_operation(cmpicf.mpi_group_free())
b.register_instruction(cmpicf.mpi_group_free())
return b
......@@ -297,31 +326,31 @@ def get_intercomm(comm_create_func, name):
assert name != "intercomm_base_comm"
if comm_create_func == "mpi_intercomm_create":
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation(
b.register_instruction("MPI_Comm intercomm_base_comm;")
b.register_instruction(
MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_operation("MPI_Comm " + name + ";")
b.register_operation(
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&" + name))
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
b.register_instruction(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
if comm_create_func == "mpi_intercomm_create_from_groups":
b = InstructionBlock("comm_create")
b.register_operation("MPI_Group world_group,even_group,odd_group;")
b.register_operation(MPICallFactory().mpi_comm_group("MPI_COMM_WORLD", "&world_group"))
b.register_operation(
b.register_instruction("MPI_Group world_group,even_group,odd_group;")
b.register_instruction(MPICallFactory().mpi_comm_group("MPI_COMM_WORLD", "&world_group"))
b.register_instruction(
MPICallFactory().mpi_comm_group("intercomm_base_comm", "&intercomm_base_comm_group"))
b.register_operation("int[3] triplet;")
b.register_operation("triplet[0] =0;")
b.register_operation("triplet[1] =size;")
b.register_operation("triplet[2] =2;")
b.register_operation(MPICallFactory().mpi_group_incl("world_group", "1","&triplet", "even_group"))
b.register_operation("triplet[0] =1;")
b.register_operation(MPICallFactory().mpi_group_incl("world_group", "1","&triplet", "odd_group"))
b.register_operation("MPI_Comm " + name + ";")
b.register_operation(
b.register_instruction("int[3] triplet;")
b.register_instruction("triplet[0] =0;")
b.register_instruction("triplet[1] =size;")
b.register_instruction("triplet[2] =2;")
b.register_instruction(MPICallFactory().mpi_group_incl("world_group", "1", "&triplet", "even_group"))
b.register_instruction("triplet[0] =1;")
b.register_instruction(MPICallFactory().mpi_group_incl("world_group", "1", "&triplet", "odd_group"))
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create_from_groups("(rank % 2 ? even_group:odd_group)", "0",
"(!(rank % 2) ? even_group:odd_group)", "0",
CorrectParameterFactory().get("stringtag"),
......@@ -331,18 +360,18 @@ def get_intercomm(comm_create_func, name):
return b
if comm_create_func == "mpi_intercomm_merge":
b = InstructionBlock("comm_create")
b.register_operation("MPI_Comm intercomm_base_comm;")
b.register_operation("MPI_Comm to_merge_intercomm_comm;")
b.register_operation(
b.register_instruction("MPI_Comm intercomm_base_comm;")
b.register_instruction("MPI_Comm to_merge_intercomm_comm;")
b.register_instruction(
MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_operation("MPI_Comm " + name + ";")
b.register_operation(
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&to_merge_intercomm_comm"))
b.register_operation(MPICallFactory().mpi_intercomm_merge("to_merge_intercomm_comm", "rank %2", "&" + name))
b.register_instruction(MPICallFactory().mpi_intercomm_merge("to_merge_intercomm_comm", "rank %2", "&" + name))
b.register_operation(MPICallFactory().mpi_comm_free("&to_merge_intercomm_comm"))
b.register_operation(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
b.register_instruction(MPICallFactory().mpi_comm_free("&to_merge_intercomm_comm"))
b.register_instruction(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
return None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment