Skip to content
Snippets Groups Projects
Commit ad17d7e9 authored by Emmanuelle Saillard's avatar Emmanuelle Saillard
Browse files

fix recv buf issue

parent caff3603
No related branches found
No related tags found
No related merge requests found
...@@ -8,10 +8,13 @@ from scripts.Infrastructure.AllocCall import AllocCall, get_free ...@@ -8,10 +8,13 @@ from scripts.Infrastructure.AllocCall import AllocCall, get_free
class CorrectParameterFactory: class CorrectParameterFactory:
# default params # default params
buf_size = "10" buf_size = "10"
recvbuf_size = "10*nprocs"
dtype = ['int', 'MPI_INT'] dtype = ['int', 'MPI_INT']
buf_size_bytes = f"{buf_size}*sizeof({dtype[0]})" buf_size_bytes = f"{buf_size}*sizeof({dtype[0]})"
recvbuf_size_bytes = f"{recvbuf_size}*sizeof({dtype[0]})"
tag = 0 tag = 0
buf_var_name = "buf" buf_var_name = "buf"
recvbuf_var_name = "recvbuf"
winbuf_var_name = "winbuf" winbuf_var_name = "winbuf"
def __init__(self): def __init__(self):
...@@ -20,12 +23,20 @@ class CorrectParameterFactory: ...@@ -20,12 +23,20 @@ class CorrectParameterFactory:
def get_buffer_alloc(self) -> AllocCall: def get_buffer_alloc(self) -> AllocCall:
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False) return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
def get_recvbuffer_alloc(self) -> AllocCall:
return AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False)
def get_buffer_free(self) -> Instruction: def get_buffer_free(self) -> Instruction:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)) return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def get_recvbuffer_free(self) -> Instruction:
return get_free(AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False))
def get(self, param: str) -> str: def get(self, param: str) -> str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]: if param in ["BUFFER", "buf", "buffer", "sendbuf", "origin_addr"]:
return self.buf_var_name return self.buf_var_name
if param in ["recvbuf"]:
return self.recvbuf_var_name
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]: if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
return str(self.buf_size) return str(self.buf_size)
if param in ["sendcounts", "recvcounts"]: if param in ["sendcounts", "recvcounts"]:
......
...@@ -196,6 +196,11 @@ def get_collective_template(collective_func): ...@@ -196,6 +196,11 @@ def get_collective_template(collective_func):
alloc.set_identifier("ALLOC") alloc.set_identifier("ALLOC")
alloc.set_name("buf") alloc.set_name("buf")
tm.register_instruction(alloc) tm.register_instruction(alloc)
if c.has_arg("recvbuf"):
recvalloc = cf.get_recvbuffer_alloc()
recvalloc.set_identifier("ALLOC")
recvalloc.set_name("recvbuf")
tm.register_instruction(recvalloc)
if c.has_arg("comm_cart"): if c.has_arg("comm_cart"):
tm.add_stack_variable("MPI_Comm") tm.add_stack_variable("MPI_Comm")
...@@ -223,6 +228,9 @@ def get_collective_template(collective_func): ...@@ -223,6 +228,9 @@ def get_collective_template(collective_func):
if c.has_arg("buffer") or c.has_arg("sendbuf"): if c.has_arg("buffer") or c.has_arg("sendbuf"):
tm.register_instruction(cf.get_buffer_free(), identifier="FREE") tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
if c.has_arg("recvbuf"):
tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE")
return tm return tm
def get_two_collective_template(collective_func1, collective_func2): def get_two_collective_template(collective_func1, collective_func2):
...@@ -241,6 +249,10 @@ def get_two_collective_template(collective_func1, collective_func2): ...@@ -241,6 +249,10 @@ def get_two_collective_template(collective_func1, collective_func2):
alloc.set_identifier("ALLOC") alloc.set_identifier("ALLOC")
alloc.set_name("buf") alloc.set_name("buf")
tm.register_instruction(alloc) tm.register_instruction(alloc)
recvalloc = cf.get_recvbuffer_alloc()
recvalloc.set_identifier("ALLOC")
recvalloc.set_name("recvbuf")
tm.register_instruction(recvalloc)
cmpicf = CorrectMPICallFactory() cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, collective_func1) call_creator_function = getattr(cmpicf, collective_func1)
...@@ -266,6 +278,7 @@ def get_two_collective_template(collective_func1, collective_func2): ...@@ -266,6 +278,7 @@ def get_two_collective_template(collective_func1, collective_func2):
tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT") tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT")
tm.register_instruction(cf.get_buffer_free(), identifier="FREE") tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE")
return tm return tm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment