diff --git a/scripts/Infrastructure/TemplateFactory.py b/scripts/Infrastructure/TemplateFactory.py index 7cee75f02c88aaa0110a99c5bad7dae39700df81..c17899344642d1a70bc86c80f80bece6eb1a62cb 100644 --- a/scripts/Infrastructure/TemplateFactory.py +++ b/scripts/Infrastructure/TemplateFactory.py @@ -77,7 +77,7 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing recv_func in persistent_recv_funcs + irecv_funcs + probe_pairs): tm.add_stack_variable("MPI_Request") if recv_func in probe_pairs: - tm.add_stack_variable("int") # the flag parameter + flag_name = tm.add_stack_variable("int") # end preperation of all local variables # before the send/recv block @@ -112,9 +112,10 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing if recv_func in probe_pairs: if recv_func in [["mpi_improbe", "mpi_mrecv"], ["mpi_improbe", "mpi_imrecv"]]: - tm.insert_instruction(Instruction("while (!flag){", rank=0), before_instruction=r) + tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=0), before_instruction=r) # insertion before the improbe call tm.register_instruction("}", rank_to_execute=0) # end while + # the matched recv tm.register_instruction(CorrectMPICallFactory.get(recv_func[1]), rank_to_execute=0) if send_func in persistent_send_funcs: @@ -401,6 +402,7 @@ def get_intercomm(comm_create_func, name, identifier="COMM"): return None + # todo also for send and non default args def get_matching_recv(call: MPICall) -> MPICall: correct_params = CorrectParameterFactory() @@ -415,4 +417,3 @@ def get_matching_recv(call: MPICall) -> MPICall: ) return recv -