From eb977b6bf87d09aa2f649e313072089117339aa0 Mon Sep 17 00:00:00 2001
From: Tim Jammer <tim.jammer@tu-darmstadt.de>
Date: Mon, 29 Apr 2024 12:36:01 +0200
Subject: [PATCH] Refactoring: extract method

---
 scripts/errors/dtypes/DtypeMissmatch.py | 130 +++++++++++-------------
 1 file changed, 58 insertions(+), 72 deletions(-)

diff --git a/scripts/errors/dtypes/DtypeMissmatch.py b/scripts/errors/dtypes/DtypeMissmatch.py
index ae3e86e43..a763721f9 100644
--- a/scripts/errors/dtypes/DtypeMissmatch.py
+++ b/scripts/errors/dtypes/DtypeMissmatch.py
@@ -16,6 +16,58 @@ from itertools import chain
 
 from scripts.Infrastructure.Variables import *
 
+# TODO refactoring into different file
+# test if the tool chan deal with messages send over different communicators
+predefined_comms = ["MPI_COMM_WORLD"]
+comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
+                 "mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
+                 "mpi_comm_split_type", "mpi_comm_create_from_group"
+                 ]
+intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]
+
+
+def get_local_missmatch(type_1, type_2, send_func, recv_func):
+    tm = get_send_recv_template(send_func, recv_func)
+    tm.set_description("LocalParameterMissmatch-Dtype-" + send_func,
+                       "datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2)
+    type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10)
+    # local missmatch
+    for call in tm.get_instruction(identifier="MPICALL", return_list=True):
+        call.set_has_error()
+        call.set_arg("buf", buf_name_1)
+        call.set_arg("datatype", type_var_2)
+
+    return tm
+
+
+def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm):
+    tm = get_send_recv_template(send_func, recv_func)
+    tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
+                       "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
+    comm_var_name = "MPI_COMM_WORLD"
+    if comm in comm_creators:
+        comm_var_name = get_communicator(comm, tm)
+
+    if comm in intercomms:
+        comm_var_name = get_intercomm(comm, tm)
+
+    type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, size_1, size_2)
+    call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
+    call.set_has_error()
+    call.set_arg("buf", buf_name_1)
+    call.set_arg("datatype", type_var_1)
+    call.set_arg("count", size_1)
+    call.set_arg("comm", comm_var_name)
+    call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
+    call.set_has_error()
+    call.set_arg("buf", buf_name_2)
+    call.set_arg("datatype", type_var_2)
+    call.set_arg("count", size_2)
+    call.set_arg("comm", comm_var_name)
+
+    return tm
+
+
 class DtypeMissmatch(ErrorGenerator):
     invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"]
     send_funcs = ["mpi_send",
@@ -27,14 +79,6 @@ class DtypeMissmatch(ErrorGenerator):
 
     sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"]
 
-    # test if the tool chan deal with messages send over different communicators
-    predefined_comms = ["MPI_COMM_WORLD"]
-    comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
-                     "mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
-                     "mpi_comm_split_type", "mpi_comm_create_from_group"
-                     ]
-    intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]
-
     def __init__(self):
         pass
 
@@ -79,20 +123,10 @@ class DtypeMissmatch(ErrorGenerator):
                         checked_types.add(type_1)
                         checked_types.add(type_2)
 
-                        tm = get_send_recv_template(send_func, recv_func)
-                        tm.set_description("LocalParameterMissmatch-Dtype-" + send_func,
-                                           "datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2)
-                        type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10)
-
-                        # local missmatch
-                        for call in tm.get_instruction(identifier="MPICALL", return_list=True):
-                            call.set_has_error()
-                            call.set_arg("buf", buf_name_1)
-                            call.set_arg("datatype", type_var_2)
-
+                        tm = get_local_missmatch(type_1, type_2, send_func, recv_func)
                         yield tm
 
-                        for comm in self.predefined_comms + self.comm_creators + self.intercomms:
+                        for comm in predefined_comms + comm_creators + intercomms:
                             if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL:
                                 continue
                             if generate_level == REAL_WORLD_TEST_LEVEL:
@@ -103,61 +137,13 @@ class DtypeMissmatch(ErrorGenerator):
                                                          recv_func, datatype=type_2.lower(), communicator=comm)):
                                     # not relevant in real world
                                     continue
-                            tm = get_send_recv_template(send_func, recv_func)
-                            tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
-                                               "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
-                            comm_var_name = "MPI_COMM_WORLD"
-                            if comm in self.comm_creators:
-                                comm_var_name = get_communicator(comm, tm)
-
-                            if comm in self.intercomms:
-                                comm_var_name = get_intercomm(comm, tm)
-
-                            type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
-
-                            # global missmatch with size 1
-                            call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
-                            call.set_has_error()
-                            call.set_arg("buf", buf_name_1)
-                            call.set_arg("datatype", type_var_1)
-                            call.set_arg("count", 1)
-                            call.set_arg("comm", comm_var_name)
-                            call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
-                            call.set_has_error()
-                            call.set_arg("buf", buf_name_2)
-                            call.set_arg("datatype", type_var_2)
-                            call.set_arg("count", 1)
-                            call.set_arg("comm", comm_var_name)
+                            tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
 
                             yield tm
-
-                            tm = get_send_recv_template(send_func, recv_func)
-                            tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
-                                               "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
-                            comm_var_name = "MPI_COMM_WORLD"
-                            if comm in self.comm_creators:
-                                comm_var_name = get_communicator(comm, tm)
-
-                            if comm in self.intercomms:
-                                comm_var_name = get_intercomm(comm, tm)
                             # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
-                            type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2,
-                                                                                              get_bytes_size_for_type(
-                                                                                                  type_2),
-                                                                                              get_bytes_size_for_type(
-                                                                                                  type_1))
-                            call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
-                            call.set_has_error()
-                            call.set_arg("buf", buf_name_1)
-                            call.set_arg("datatype", type_var_1)
-                            call.set_arg("count", get_bytes_size_for_type(type_2))
-                            call.set_arg("comm", comm_var_name)
-                            call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
-                            call.set_has_error()
-                            call.set_arg("buf", buf_name_2)
-                            call.set_arg("datatype", type_var_2)
-                            call.set_arg("count", get_bytes_size_for_type(type_1))
-                            call.set_arg("comm", comm_var_name)
+                            tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
+                                                      get_bytes_size_for_type(type_1),
+                                                      send_func, recv_func, comm)
 
                             yield tm
                             if generate_level <= BASIC_TEST_LEVEL:
-- 
GitLab