diff --git a/scripts/Infrastructure/CorrectParameter.py b/scripts/Infrastructure/CorrectParameter.py index db8d5517dd03772fb196b6a6fd44aba7c5c91be6..d136ad1c11d1e39d0f84a1fc7a4d6b368194b16d 100644 --- a/scripts/Infrastructure/CorrectParameter.py +++ b/scripts/Infrastructure/CorrectParameter.py @@ -8,10 +8,13 @@ from scripts.Infrastructure.AllocCall import AllocCall, get_free class CorrectParameterFactory: # default params buf_size = "10" + recvbuf_size = "10*nprocs" dtype = ['int', 'MPI_INT'] buf_size_bytes = f"{buf_size}*sizeof({dtype[0]})" + recvbuf_size_bytes = f"{recvbuf_size}*sizeof({dtype[0]})" tag = 0 buf_var_name = "buf" + recvbuf_var_name = "recvbuf" winbuf_var_name = "winbuf" def __init__(self): @@ -20,12 +23,20 @@ class CorrectParameterFactory: def get_buffer_alloc(self) -> AllocCall: return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False) + def get_recvbuffer_alloc(self) -> AllocCall: + return AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False) + def get_buffer_free(self) -> Instruction: return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)) + + def get_recvbuffer_free(self) -> Instruction: + return get_free(AllocCall(self.dtype[0], self.recvbuf_size, self.recvbuf_var_name, use_malloc=False)) def get(self, param: str) -> str: - if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]: + if param in ["BUFFER", "buf", "buffer", "sendbuf", "origin_addr"]: return self.buf_var_name + if param in ["recvbuf"]: + return self.recvbuf_var_name if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]: return str(self.buf_size) if param in ["sendcounts", "recvcounts"]: diff --git a/scripts/Infrastructure/TemplateFactory.py b/scripts/Infrastructure/TemplateFactory.py index 1d36fa97984d93c37e73ffefb5d33a0b38eeccbe..cbb1639343718d1f39697e87e685082a2f274f06 100644 --- a/scripts/Infrastructure/TemplateFactory.py +++ b/scripts/Infrastructure/TemplateFactory.py @@ -196,6 +196,11 @@ def get_collective_template(collective_func): alloc.set_identifier("ALLOC") alloc.set_name("buf") tm.register_instruction(alloc) + if c.has_arg("recvbuf"): + recvalloc = cf.get_recvbuffer_alloc() + recvalloc.set_identifier("ALLOC") + recvalloc.set_name("recvbuf") + tm.register_instruction(recvalloc) if c.has_arg("comm_cart"): tm.add_stack_variable("MPI_Comm") @@ -223,6 +228,9 @@ def get_collective_template(collective_func): if c.has_arg("buffer") or c.has_arg("sendbuf"): tm.register_instruction(cf.get_buffer_free(), identifier="FREE") + if c.has_arg("recvbuf"): + tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE") + return tm def get_two_collective_template(collective_func1, collective_func2): @@ -241,6 +249,10 @@ def get_two_collective_template(collective_func1, collective_func2): alloc.set_identifier("ALLOC") alloc.set_name("buf") tm.register_instruction(alloc) + recvalloc = cf.get_recvbuffer_alloc() + recvalloc.set_identifier("ALLOC") + recvalloc.set_name("recvbuf") + tm.register_instruction(recvalloc) cmpicf = CorrectMPICallFactory() call_creator_function = getattr(cmpicf, collective_func1) @@ -266,6 +278,7 @@ def get_two_collective_template(collective_func1, collective_func2): tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT") tm.register_instruction(cf.get_buffer_free(), identifier="FREE") + tm.register_instruction(cf.get_recvbuffer_free(), identifier="FREE") return tm