Skip to content
Snippets Groups Projects
Commit b28fdf2f authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Refactoring: CorrectParameterFactory Returns AllocCall instead of InstructionBlock

parent 736a10b8
No related branches found
No related tags found
1 merge request!5more work on infrastructure III
...@@ -17,7 +17,9 @@ class Invalid_negative_rank_error: ...@@ -17,7 +17,9 @@ class Invalid_negative_rank_error:
tm.set_description("Invalid Rank: %s" % self.invalid_ranks[self.rank_to_use]) tm.set_description("Invalid Rank: %s" % self.invalid_ranks[self.rank_to_use])
# include the buffer allocation in the template (all ranks execute it) # include the buffer allocation in the template (all ranks execute it)
tm.register_instruction_block(correct_params.get_buffer_alloc()) alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(correct_params.get_buffer_alloc())
tm.register_instruction_block(alloc_block)
send = MPI_Call_Factory().mpi_send( send = MPI_Call_Factory().mpi_send(
correct_params.get("BUFFER"), correct_params.get("BUFFER"),
......
...@@ -15,15 +15,10 @@ class CorrectParameterFactory: ...@@ -15,15 +15,10 @@ class CorrectParameterFactory:
pass pass
def get_buffer_alloc(self): def get_buffer_alloc(self):
b = InstructionBlock("alloc") return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
b.register_operation(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False), kind='all')
return b
def get_buffer_free(self): def get_buffer_free(self):
b = InstructionBlock("free") return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
b.register_operation(get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)),
kind='all')
return b
def get(self, param, func=None): def get(self, param, func=None):
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf"]: if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf"]:
......
...@@ -48,7 +48,15 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"): ...@@ -48,7 +48,15 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
tm = TemplateManager() tm = TemplateManager()
cf = CorrectParameterFactory() cf = CorrectParameterFactory()
tm.register_instruction_block(cf.get_buffer_alloc()) alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(cf.get_buffer_alloc())
if send_func in sendrecv_funcs:
# spilt send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_name("recv_buf")
alloc_block.register_operation(alloc)
tm.register_instruction_block(alloc_block)
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]: if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
b = InstructionBlock("buf_attach") b = InstructionBlock("buf_attach")
...@@ -57,10 +65,8 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"): ...@@ -57,10 +65,8 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
b.register_operation(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size)) b.register_operation(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
tm.register_instruction_block(b) tm.register_instruction_block(b)
if send_func in sendrecv_funcs:
# spilt send and recv buf
b = cf.get_buffer_alloc()
b.get_operation('all', 0).set_name("recv_buf")
tm.register_instruction_block(b) tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory() cmpicf = CorrectMPICallFactory()
...@@ -121,7 +127,7 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"): ...@@ -121,7 +127,7 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
b.register_operation(CorrectMPICallFactory().mpi_wait(), 0) b.register_operation(CorrectMPICallFactory().mpi_wait(), 0)
tm.register_instruction_block(b) tm.register_instruction_block(b)
tm.register_instruction_block(cf.get_buffer_free())
if send_func in ["mpi_bsend", "mpi_ibsend","mpi_bsend_init"]: if send_func in ["mpi_bsend", "mpi_ibsend","mpi_bsend_init"]:
b = InstructionBlock("buf_detach") b = InstructionBlock("buf_detach")
...@@ -130,11 +136,12 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"): ...@@ -130,11 +136,12 @@ def get_send_recv_template(send_func="mpi_isend", recv_func="mpi_irecv"):
b.register_operation("free(mpi_buf);") b.register_operation("free(mpi_buf);")
tm.register_instruction_block(b) tm.register_instruction_block(b)
free_block = InstructionBlock("buf_free")
free_block.register_operation(cf.get_buffer_free())
if send_func in sendrecv_funcs: if send_func in sendrecv_funcs:
# spilt send and recv buf # spilt send and recv buf
b = InstructionBlock("buf_free")
b.register_operation("free(recv_buf);") b.register_operation("free(recv_buf);")
tm.register_instruction_block(b) tm.register_instruction_block(free_block)
if send_func in persistent_send_funcs: if send_func in persistent_send_funcs:
# spilt send and recv buf # spilt send and recv buf
...@@ -157,7 +164,14 @@ def get_collective_template(collective_func, seperate=True): ...@@ -157,7 +164,14 @@ def get_collective_template(collective_func, seperate=True):
tm = TemplateManager() tm = TemplateManager()
cf = CorrectParameterFactory() cf = CorrectParameterFactory()
tm.register_instruction_block(cf.get_buffer_alloc()) alloc_block = InstructionBlock("alloc")
alloc_block.register_operation(cf.get_buffer_alloc())
if False:
# spilt send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_name("recv_buf")
alloc_block.register_operation(alloc)
tm.register_instruction_block(alloc_block)
cmpicf = CorrectMPICallFactory() cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, collective_func) call_creator_function = getattr(cmpicf, collective_func)
...@@ -172,6 +186,8 @@ def get_collective_template(collective_func, seperate=True): ...@@ -172,6 +186,8 @@ def get_collective_template(collective_func, seperate=True):
tm.register_instruction_block(b) tm.register_instruction_block(b)
b.register_operation(cf.get_buffer_free()) free_block = InstructionBlock("buf_free")
free_block.register_operation(cf.get_buffer_free())
tm.register_instruction_block(free_block)
return tm return tm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment