Skip to content
Snippets Groups Projects
Commit 3571a8df authored by Jammer, Tim's avatar Jammer, Tim
Browse files

added Bsend case

parent 8ff0e14a
No related branches found
No related tags found
1 merge request!4Devel tj
#! /usr/bin/python3 #! /usr/bin/python3
from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICall import MPI_Call from scripts.Infrastructure.MPICall import MPI_Call
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory, MPICallFactory
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
...@@ -25,7 +26,8 @@ def get_send_recv_template(send_func, recv_func): ...@@ -25,7 +26,8 @@ def get_send_recv_template(send_func, recv_func):
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2) The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
""" """
assert send_func in ["mpi_send", "mpi_ssend", "mpi_isend", "mpi_issend", "mpi_sendrecv", "mpi_rsend", "mpi_irsend"] assert send_func in ["mpi_send", "mpi_ssend", "mpi_isend", "mpi_issend", "mpi_sendrecv", "mpi_rsend", "mpi_irsend",
"mpi_bsend", "mpi_ibsend"]
assert recv_func in ["mpi_recv", "mpi_irecv", "mpi_sendrecv"] assert recv_func in ["mpi_recv", "mpi_irecv", "mpi_sendrecv"]
if send_func == "mpi_sendrecv" or recv_func == "mpi_sendrecv": if send_func == "mpi_sendrecv" or recv_func == "mpi_sendrecv":
...@@ -40,6 +42,13 @@ def get_send_recv_template(send_func, recv_func): ...@@ -40,6 +42,13 @@ def get_send_recv_template(send_func, recv_func):
tm.register_instruction_block(cf.get_buffer_alloc()) tm.register_instruction_block(cf.get_buffer_alloc())
if send_func in ["mpi_bsend", "mpi_ibsend"]:
b = InstructionBlock("buf_attach")
buf_size = "sizeof(int)*10 + MPI_BSEND_OVERHEAD"
b.register_operation(AllocCall("char", buf_size, "mpi_buf"))
b.register_operation(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory() cmpicf = CorrectMPICallFactory()
send_func_creator_function = getattr(cmpicf, send_func) send_func_creator_function = getattr(cmpicf, send_func)
s = send_func_creator_function() s = send_func_creator_function()
...@@ -79,6 +88,14 @@ def get_send_recv_template(send_func, recv_func): ...@@ -79,6 +88,14 @@ def get_send_recv_template(send_func, recv_func):
b.register_operation(cf.get_buffer_free()) b.register_operation(cf.get_buffer_free())
if send_func in ["mpi_bsend", "mpi_ibsend"]:
b = InstructionBlock("buf_detach")
b.register_operation("int freed_size;")
b.register_operation(MPICallFactory().mpi_buffer_detach("mpi_buf","&freed_size"))
b.register_operation("free(mpi_buf);")
tm.register_instruction_block(b)
return tm return tm
......
...@@ -14,7 +14,7 @@ class InvalidRankErrorP2P(ErrorGenerator): ...@@ -14,7 +14,7 @@ class InvalidRankErrorP2P(ErrorGenerator):
invalid_ranks = ["-1", "nprocs", "MPI_PROC_NULL"] invalid_ranks = ["-1", "nprocs", "MPI_PROC_NULL"]
functions_to_check = ["mpi_send", functions_to_check = ["mpi_send",
"mpi_recv", "mpi_irecv", "mpi_recv", "mpi_irecv",
"mpi_isend", "mpi_ssend", "mpi_issend", "mpi_rsend","mpi_irsend"] "mpi_isend", "mpi_ssend", "mpi_issend", "mpi_rsend","mpi_irsend","mpi_bsend","mpi_ibsend"]
recv_funcs = ["mpi_recv", "mpi_irecv"] recv_funcs = ["mpi_recv", "mpi_irecv"]
def __init__(self): def __init__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment