Skip to content
Snippets Groups Projects
Commit 558c56f8 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

added sendrecv case

parent 3571a8df
No related branches found
No related tags found
1 merge request!4Devel tj
......@@ -21,31 +21,32 @@ class CorrectParameterFactory:
def get_buffer_free(self):
b = InstructionBlock("free")
b.register_operation(get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)), kind='all')
b.register_operation(get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)),
kind='all')
return b
def get(self, param, func=None):
if param == "BUFFER" or param == "buf" or param == "buffer" or param == "sendbuf" or param == "recvbuf":
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf"]:
return self.buf_var_name
if param == "COUNT" or param == "count":
if param in ["COUNT", "count", "sendcount", "recvcount"]:
return str(self.buf_size)
if param == "DATATYPE" or param == "datatype":
if param in ["DATATYPE", "datatype", "sendtype", "recvtype"]:
return self.dtype[1]
if param == "DEST" or param == "dest":
if param in ["DEST", "dest"]:
return "0"
if param == "SRC" or param == "source":
if param in ["SRC", "source"]:
return "1"
if param == "RANK" or param == "root":
if param in ["RANK", "root"]:
return "0"
if param == "TAG" or param == "tag":
if param in ["TAG", "tag", "sendtag", "recvtag"]:
return str(self.tag)
if param == "COMM" or param == "comm":
if param in ["COMM", "comm"]:
return "MPI_COMM_WORLD"
if param == "STATUS" or param == "status":
if param in ["STATUS", "status"]:
return "MPI_STATUS_IGNORE"
if param == "OPERATION" or param == "op":
if param in ["OPERATION", "op"]:
return "MPI_SUM"
if param=="REQUEST" or param == "request":
if param in ["REQUEST", "request"]:
return "&request"
print("Not Implemented: " + param)
......
......@@ -26,16 +26,21 @@ def get_send_recv_template(send_func, recv_func):
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
"""
# currently supported:
assert send_func in ["mpi_send", "mpi_ssend", "mpi_isend", "mpi_issend", "mpi_sendrecv", "mpi_rsend", "mpi_irsend",
"mpi_bsend", "mpi_ibsend"]
assert recv_func in ["mpi_recv", "mpi_irecv", "mpi_sendrecv"]
"mpi_bsend", "mpi_ibsend","mpi_sendrecv", "mpi_sendrecv_replace", "mpi_isendrecv",
"mpi_isendrecv_replace"]
assert recv_func in ["mpi_recv", "mpi_irecv", "mpi_sendrecv", "mpi_sendrecv_replace", "mpi_isendrecv",
"mpi_isendrecv_replace"]
if send_func == "mpi_sendrecv" or recv_func == "mpi_sendrecv":
sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"]
if send_func in sendrecv_funcs or recv_func == sendrecv_funcs:
assert recv_func == send_func
assert False and "NOT IMPLEMENTED YET"
# default template generation only supports if both use same mechanism
if send_func in ["mpi_rsend", "mpi_irsend"]:
assert recv_func == "mpi_irecv"
assert recv_func == "mpi_irecv" # else: deadlock
tm = TemplateManager()
cf = CorrectParameterFactory()
......@@ -49,6 +54,12 @@ def get_send_recv_template(send_func, recv_func):
b.register_operation(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
tm.register_instruction_block(b)
if send_func in sendrecv_funcs:
# spilt send and recv buf
b = cf.get_buffer_alloc()
b.get_operation('all',0).set_name("recv_buf")
tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory()
send_func_creator_function = getattr(cmpicf, send_func)
s = send_func_creator_function()
......@@ -56,6 +67,15 @@ def get_send_recv_template(send_func, recv_func):
recv_func_creator_function = getattr(cmpicf, recv_func)
r = recv_func_creator_function()
if send_func in sendrecv_funcs:
# sending the second msg
s.set_arg("source", "0")
r.set_arg("dest", "1")
if s.has_arg("recvbuf"):
s.set_arg("recvbuf", "recv_buf")
if r.has_arg("recvbuf"):
r.set_arg("recvbuf", "recv_buf")
if send_func.startswith("mpi_i") or recv_func.startswith("mpi_i"):
b = InstructionBlock("MPI_REQUEST")
b.register_operation("MPI_Request request;", 'all')
......@@ -93,7 +113,12 @@ def get_send_recv_template(send_func, recv_func):
b.register_operation("int freed_size;")
b.register_operation(MPICallFactory().mpi_buffer_detach("mpi_buf", "&freed_size"))
b.register_operation("free(mpi_buf);")
tm.register_instruction_block(b)
if send_func in sendrecv_funcs:
# spilt send and recv buf
b = InstructionBlock("buf_free")
b.register_operation("free(recv_buf);")
tm.register_instruction_block(b)
return tm
......
......@@ -9,13 +9,17 @@ from scripts.Infrastructure.TemplateFactory import get_send_recv_template
from itertools import chain
sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"]
class InvalidRankErrorP2P(ErrorGenerator):
invalid_ranks = ["-1", "nprocs", "MPI_PROC_NULL"]
functions_to_check = ["mpi_send",
"mpi_recv", "mpi_irecv",
"mpi_isend", "mpi_ssend", "mpi_issend", "mpi_rsend","mpi_irsend","mpi_bsend","mpi_ibsend"]
recv_funcs = ["mpi_recv", "mpi_irecv"]
"mpi_isend", "mpi_ssend", "mpi_issend", "mpi_rsend", "mpi_irsend", "mpi_bsend", "mpi_ibsend"
] + sendrecv_funcs + sendrecv_funcs
# chekc sendrecv funcs two times: the send and recv part
recv_funcs = ["mpi_recv", "mpi_irecv"] + sendrecv_funcs
def __init__(self):
pass
......@@ -41,6 +45,11 @@ class InvalidRankErrorP2P(ErrorGenerator):
check_receive = True
recv_func = send_func
send_func = "mpi_send"
if recv_func in sendrecv_funcs:
send_func = recv_func
if i % len(self.functions_to_check) >= len(self.functions_to_check) - len(sendrecv_funcs):
# check the send part of sendrecv
check_receive = False
tm = get_send_recv_template(send_func, recv_func)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment