From ad17d7e9649c1200b859f87bff6dd1217a232c5d Mon Sep 17 00:00:00 2001 From: Emmanuelle Saillard <emmanuelle.saillard@inria.fr> Date: Mon, 6 May 2024 15:26:17 +0200 Subject: [PATCH] fix recv buf issue --- scripts/Infrastructure/CorrectParameter.py | 13 ++++++++++++- scripts/Infrastructure/TemplateFactory.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/scripts/Infrastructure/CorrectParameter.py b/scripts/Infrastructure/CorrectParameter.py index db8d5517d..d136ad1c1 100644 --- a/scripts/Infrastructure/CorrectParameter.py +++ b/scripts/Infrastructure/CorrectParameter.py @@ -8,10 +8,13 @@ from scripts.Infrastructure.AllocCall import AllocCall, get_free class CorrectParameterFactory: # default params buf_size = "10" + recvbuf_size = "10*nprocs" dtype = ['int', 'MPI_INT'] buf_size_bytes = f"{buf_size}*sizeof({dtype[0]})" + recvbuf_size_bytes = f"{recvbuf_size}*sizeof({dtype[0]})" tag = 0 buf_var_name = "buf" + recvbuf_var_name = "recvbuf" winbuf_var_name = "winbuf" def __init__(self): @@ -20,12 +23,20 @@ class CorrectParameterFactory: def get_buffer_alloc(self) -> AllocCall: return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False) + def get_recvbuffer_alloc(self) -> AllocCall: + return AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False) + def get_buffer_free(self) -> Instruction: return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)) + + def get_recvbuffer_free(self) -> Instruction: + return get_free(AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False)) def get(self, param: str) -> str: - if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]: + if param in ["BUFFER", "buf", "buffer", "sendbuf", "origin_addr"]: return self.buf_var_name + if param in ["recvbuf"]: + return self.recvbuf_var_name if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]: return str(self.buf_size) if param in ["sendcounts", "recvcounts"]: diff --git a/scripts/Infrastructure/TemplateFactory.py b/scripts/Infrastructure/TemplateFactory.py index 1d36fa979..cbb163934 100644 --- a/scripts/Infrastructure/TemplateFactory.py +++ b/scripts/Infrastructure/TemplateFactory.py @@ -196,6 +196,11 @@ def get_collective_template(collective_func): alloc.set_identifier("ALLOC") alloc.set_name("buf") tm.register_instruction(alloc) + if c.has_arg("recvbuf"): + recvalloc = cf.get_recvbuffer_alloc() + recvalloc.set_identifier("ALLOC") + recvalloc.set_name("recvbuf") + tm.register_instruction(recvalloc) if c.has_arg("comm_cart"): tm.add_stack_variable("MPI_Comm") @@ -223,6 +228,9 @@ def get_collective_template(collective_func): if c.has_arg("buffer") or c.has_arg("sendbuf"): tm.register_instruction(cf.get_buffer_free(), identifier="FREE") + if c.has_arg("recvbuf"): + tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE") + return tm def get_two_collective_template(collective_func1, collective_func2): @@ -241,6 +249,10 @@ def get_two_collective_template(collective_func1, collective_func2): alloc.set_identifier("ALLOC") alloc.set_name("buf") tm.register_instruction(alloc) + recvalloc = cf.get_recvbuffer_alloc() + recvalloc.set_identifier("ALLOC") + recvalloc.set_name("recvbuf") + tm.register_instruction(recvalloc) cmpicf = CorrectMPICallFactory() call_creator_function = getattr(cmpicf, collective_func1) @@ -266,6 +278,7 @@ def get_two_collective_template(collective_func1, collective_func2): tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT") tm.register_instruction(cf.get_buffer_free(), identifier="FREE") + tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE") return tm -- GitLab