Skip to content
Snippets Groups Projects
Commit 36d5f394 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

fix default send/recv Template and template.to_str

parent cdf5a99c
No related branches found
No related tags found
1 merge request!14Infrastructure: Remove Instructionblock
...@@ -27,10 +27,10 @@ class Instruction(object): ...@@ -27,10 +27,10 @@ class Instruction(object):
def set_identifier(self, identifier: str): def set_identifier(self, identifier: str):
self._identifier = identifier self._identifier = identifier
def get_ranks_executing(self) -> str | int: def get_rank_executing(self) -> str | int:
return self._rank return self._rank
def set_ranks_executing(self, rank: str | int): def set_rank_executing(self, rank: str | int):
if isinstance(rank, str): if isinstance(rank, str):
assert rank in ['all', 'not0'] assert rank in ['all', 'not0']
self._rank = rank self._rank = rank
......
...@@ -4,6 +4,7 @@ from __future__ import annotations ...@@ -4,6 +4,7 @@ from __future__ import annotations
import typing import typing
from scripts.Infrastructure.Instruction import Instruction from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICall import MPICall
template = """// @{generatedby}@ template = """// @{generatedby}@
/* ///////////////////////// The MPI Bug Bench //////////////////////// /* ///////////////////////// The MPI Bug Bench ////////////////////////
...@@ -114,14 +115,14 @@ class TemplateManager: ...@@ -114,14 +115,14 @@ class TemplateManager:
code_string = "" code_string = ""
current_rank = 'all' current_rank = 'all'
for inst in self._instructions: for inst in self._instructions:
if inst.get_ranks_executing != current_rank: if inst.get_rank_executing != current_rank:
if current_rank != 'all': if current_rank != 'all':
code_string = code_string + "}\n" # end previous if code_string = code_string + "}\n" # end previous if
current_rank = inst.get_ranks_executing() current_rank = inst.get_rank_executing()
if current_rank == 'not0': if current_rank == 'not0':
code_string = code_string + "if (rank!=0){\n" code_string = code_string + "if (rank!=0){\n"
elif current_rank != 'all': elif current_rank != 'all':
code_string = code_string + "if (rank==%D){\n" % current_rank code_string = code_string + "if (rank==%d){\n" % current_rank
code_string += str(inst) + "\n" code_string += str(inst) + "\n"
# end for inst # end for inst
...@@ -148,23 +149,35 @@ class TemplateManager: ...@@ -148,23 +149,35 @@ class TemplateManager:
.replace("@{version}@", version) .replace("@{version}@", version)
.replace("@{test_code}@", code_string)) .replace("@{test_code}@", code_string))
def register_instruction(self, inst: str | Instruction | typing.List[Instruction], identifier: str = None): def register_instruction(self, inst: str | Instruction | typing.List[Instruction], identifier: str = None,
rank_to_execute: str | int = None):
""" """
Registers an instruction block with the template. inserting it at the end, before the mpi finalize Registers an instruction block with the template. inserting it at the end, before the mpi finalize
Parameters: Parameters:
- inst: The instruction to register. - inst: The instruction to register.
- optional: identifier: overwirtes the identifier of the instructioneith the provided one (no override if None) - optional: identifier: overwirtes the identifier of the instructioneith the provided one (no override if None)
- optional: rank_to_execute: overwirtes the rank_to_execute of the instructioneith the provided one (no override if None)
""" """
if isinstance(inst, list): if isinstance(inst, list):
if identifier is not None: if identifier is not None:
for i in inst: for i in inst:
i.set_identifier(identifier) i.set_identifier(identifier)
if rank_to_execute is not None:
for i in inst:
i.set_ranks_executing(rank_to_execute)
self._instructions.extend(inst) self._instructions.extend(inst)
elif isinstance(inst, str):
if rank_to_execute is not None:
self._instructions.append(Instruction(inst, rank=rank_to_execute, identifier=identifier))
else:
# use default ('all')
self._instructions.append(Instruction(inst, identifier=identifier))
else: else:
if isinstance(inst, str):
self._instructions.append(Instruction(inst, identifier))
if identifier is not None: if identifier is not None:
inst.set_identifier(identifier) inst.set_identifier(identifier)
if rank_to_execute is not None:
for i in inst:
i.set_ranks_executing(rank_to_execute)
self._instructions.append(inst) self._instructions.append(inst)
def get_version(self) -> str: def get_version(self) -> str:
...@@ -174,9 +187,9 @@ class TemplateManager: ...@@ -174,9 +187,9 @@ class TemplateManager:
str: The MPI version used. str: The MPI version used.
""" """
max_v = "0.0" max_v = "0.0"
for block in self._blocks: for inst in self._instructions:
assert isinstance(block, InstructionBlock) if isinstance(inst,MPICall):
max_v = max(block.get_version(), max_v) max_v = max(inst.get_version(), max_v)
return max_v return max_v
def set_description(self, descr_short: str, descr_full: str): def set_description(self, descr_short: str, descr_full: str):
......
...@@ -5,7 +5,7 @@ import typing ...@@ -5,7 +5,7 @@ import typing
from scripts.Infrastructure.AllocCall import AllocCall from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICall import MPICall from scripts.Infrastructure.MPICall import MPICall
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory, MPICallFactory from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory, MPICallFactory
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
...@@ -24,7 +24,7 @@ def get_default_template(mpi_func): ...@@ -24,7 +24,7 @@ def get_default_template(mpi_func):
def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing.Tuple[str, str] = "mpi_irecv"): def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing.Tuple[str, str] = "mpi_irecv"):
""" """
Contructs a default template for the given mpi send recv function pair Contructs a default template for the given mpi send recv function pair it contains a send from rank 1 to rank 0
Returns: Returns:
TemplateManager Initialized with a default template TemplateManager Initialized with a default template
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2) The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
...@@ -58,48 +58,40 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing ...@@ -58,48 +58,40 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing
tm = TemplateManager() tm = TemplateManager()
cf = CorrectParameterFactory() cf = CorrectParameterFactory()
alloc_block = InstructionBlock("alloc") tm.register_instruction(cf.get_buffer_alloc(), identifier="ALLOC")
alloc_block.register_instruction(cf.get_buffer_alloc())
if send_func in sendrecv_funcs: if send_func in sendrecv_funcs:
# spilt send and recv buf # spilt send and recv buf
alloc = cf.get_buffer_alloc() alloc = cf.get_buffer_alloc()
alloc.set_identifier("ALLOC")
alloc.set_name("recv_buf") alloc.set_name("recv_buf")
alloc_block.register_instruction(alloc) tm.register_instruction(alloc)
if recv_func in probe_pairs: if recv_func in probe_pairs:
alloc_block.register_instruction("MPI_Message msg;") tm.register_instruction("MPI_Message msg;", identifier="ALLOC")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]: if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
buf_size = "sizeof(int)*10 + MPI_BSEND_OVERHEAD" buf_size = "sizeof(int)*10 + MPI_BSEND_OVERHEAD"
alloc_block.register_instruction(AllocCall("char", buf_size, "mpi_buf")) tm.register_instruction(AllocCall("char", buf_size, "mpi_buf"), identifier="ALLOC")
alloc_block.register_instruction(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size)) tm.register_instruction(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size), identifier="ALLOC")
if (send_func in isend_funcs + persistent_send_funcs or if (send_func in isend_funcs + persistent_send_funcs or
recv_func in persistent_recv_funcs + irecv_funcs + probe_pairs): recv_func in persistent_recv_funcs + irecv_funcs + probe_pairs):
alloc_block.register_instruction("MPI_Request request;", 'all') tm.register_instruction("MPI_Request request;", identifier="ALLOC")
if recv_func in probe_pairs: if recv_func in probe_pairs:
alloc_block.register_instruction("int flag=0;") tm.register_instruction("int flag=0;", identifier="ALLOC")
tm.register_instruction_block(alloc_block)
# end preperation of all local variables # end preperation of all local variables
# before the send/recv block # before the send/recv block
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]: if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
# sender needs to wait until recv is started # sender needs to wait until recv is started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 1) tm.register_instruction(CorrectMPICallFactory.mpi_barrier(), identifier="SYNC", rank_to_execute=1)
tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory()
# get the send and recv block # get the send and recv block
recv_to_use = recv_func recv_to_use = recv_func
if recv_func in probe_pairs: if recv_func in probe_pairs:
recv_to_use = recv_func[0] recv_to_use = recv_func[0]
s = cmpicf.get(send_func) s = CorrectMPICallFactory.get(send_func)
r = cmpicf.get(recv_to_use) r = CorrectMPICallFactory.get(recv_to_use)
if send_func in sendrecv_funcs: if send_func in sendrecv_funcs:
# sending the second msg # sending the second msg
...@@ -110,68 +102,59 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing ...@@ -110,68 +102,59 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing
if r.has_arg("recvbuf"): if r.has_arg("recvbuf"):
r.set_arg("recvbuf", "recv_buf") r.set_arg("recvbuf", "recv_buf")
b = InstructionBlock("MPICALL") s.set_identifier("MPICALL")
b.register_instruction(s, 1) s.set_rank_executing(1)
b.register_instruction(r, 0) r.set_identifier("MPICALL")
r.set_rank_executing(0)
tm.register_instruction(s)
tm.register_instruction(r)
if recv_func in probe_pairs: if recv_func in probe_pairs:
if recv_func in [["mpi_improbe", "mpi_mrecv"], if recv_func in [["mpi_improbe", "mpi_mrecv"],
["mpi_improbe", "mpi_imrecv"]]: ["mpi_improbe", "mpi_imrecv"]]:
b.insert_instruction("while (!flag){", 0, 0) tm.insert_instruction(Instruction("while (!flag){", rank=0), before_instruction="MPICALL")
# insertion before the improbe call # insertion before the improbe call
b.register_instruction("}", 0) # end while tm.register_instruction("}", rank_to_execute=0) # end while
b.register_instruction(CorrectMPICallFactory().get(recv_func[1]), 0) tm.register_instruction(CorrectMPICallFactory.get(recv_func[1]), rank_to_execute=0)
if send_func in persistent_send_funcs: if send_func in persistent_send_funcs:
b.register_instruction(cmpicf.mpi_start(), 1) tm.register_instruction(CorrectMPICallFactory.mpi_start(), rank_to_execute=1, identifier="START")
if send_func == "mpi_psend_init": if send_func == "mpi_psend_init":
# the pready takes a Request NOt a request* # the pready takes a Request NOt a request*
b.register_instruction(MPICallFactory().mpi_pready("0", cf.get("request")[1:]), 1) tm.register_instruction(MPICallFactory.mpi_pready("0", cf.get("request")[1:]), rank_to_execute=1)
if recv_func in persistent_recv_funcs: if recv_func in persistent_recv_funcs:
b.register_instruction(cmpicf.mpi_start(), 0) # tm.register_instruction(CorrectMPICallFactory.mpi_start(), rank_to_execute=0, identifier="START")
# parrived is not necessary # parrived is not necessary
tm.register_instruction_block(b)
# after send/recv # after send/recv
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]: if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
# barrier indicating recv has started # barrier indicating recv has started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 0) tm.register_instruction(CorrectMPICallFactory.mpi_barrier(), rank_to_execute=0, identifier="SYNC")
tm.register_instruction_block(b)
# wait for op to complete # wait for op to complete
if send_func in isend_funcs + persistent_send_funcs:
b = InstructionBlock("WAIT")
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 1)
tm.register_instruction_block(b)
if send_func in isend_funcs + persistent_send_funcs:
tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute=1, identifier="WAIT")
if recv_func in irecv_funcs + persistent_recv_funcs + [["mpi_mprobe", "mpi_imrecv"], if recv_func in irecv_funcs + persistent_recv_funcs + [["mpi_mprobe", "mpi_imrecv"],
["mpi_improbe", "mpi_imrecv"]]: ["mpi_improbe", "mpi_imrecv"]]:
b = InstructionBlock("WAIT") tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute=0, identifier="WAIT")
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 0)
tm.register_instruction_block(b)
# end MPI operation # end MPI operation
# cleanup # cleanup
free_block = InstructionBlock("buf_free")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]: if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
free_block = InstructionBlock("buf_detach") tm.register_instruction("int freed_size;", identifier="FREE")
free_block.register_instruction("int freed_size;") tm.register_instruction(MPICallFactory.mpi_buffer_detach("mpi_buf", "&freed_size"), identifier="FREE")
free_block.register_instruction(MPICallFactory().mpi_buffer_detach("mpi_buf", "&freed_size")) tm.register_instruction("free(mpi_buf);", identifier="FREE")
free_block.register_instruction("free(mpi_buf);")
free_block.register_instruction(cf.get_buffer_free()) tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
if send_func in sendrecv_funcs: if send_func in sendrecv_funcs:
# spilt send and recv buf # spilt send and recv buf
free_block.register_instruction("free(recv_buf);") tm.register_instruction("free(recv_buf);", identifier="FREE")
if send_func in persistent_send_funcs: if send_func in persistent_send_funcs:
free_block.register_instruction(cmpicf.mpi_request_free()) tm.register_instruction(CorrectMPICallFactory.mpi_request_free(), identifier="FREE")
tm.register_instruction_block(free_block)
return tm return tm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment