From a0eda710e156b1ba58287317ca8a105e80795a90 Mon Sep 17 00:00:00 2001
From: Tim Jammer <tim.jammer@tu-darmstadt.de>
Date: Mon, 29 Apr 2024 13:19:06 +0200
Subject: [PATCH] Refactoring: simplified code further

---
 scripts/errors/dtypes/DtypeMissmatch.py | 172 ++++++++++++++----------
 1 file changed, 104 insertions(+), 68 deletions(-)

diff --git a/scripts/errors/dtypes/DtypeMissmatch.py b/scripts/errors/dtypes/DtypeMissmatch.py
index a763721f9..28b01b0ef 100644
--- a/scripts/errors/dtypes/DtypeMissmatch.py
+++ b/scripts/errors/dtypes/DtypeMissmatch.py
@@ -1,5 +1,6 @@
 #! /usr/bin/python3
 from copy import copy
+from random import shuffle
 
 from scripts.Infrastructure.AllocCall import AllocCall
 from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
@@ -40,6 +41,32 @@ def get_local_missmatch(type_1, type_2, send_func, recv_func):
     return tm
 
 
+def get_correct_case(type_1, size_1, send_func, recv_func, comm):
+    tm = get_send_recv_template(send_func, recv_func)
+    tm.set_description("Correct-" + send_func,
+                       "")
+    comm_var_name = "MPI_COMM_WORLD"
+    if comm in comm_creators:
+        comm_var_name = get_communicator(comm, tm)
+
+    if comm in intercomms:
+        comm_var_name = get_intercomm(comm, tm)
+
+    type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, size_1, size_1)
+    call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
+    call.set_arg("buf", buf_name_1)
+    call.set_arg("datatype", type_var_1)
+    call.set_arg("count", size_1)
+    call.set_arg("comm", comm_var_name)
+    call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
+    call.set_arg("buf", buf_name_2)
+    call.set_arg("datatype", type_var_2)
+    call.set_arg("count", size_1)
+    call.set_arg("comm", comm_var_name)
+
+    return tm
+
+
 def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm):
     tm = get_send_recv_template(send_func, recv_func)
     tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
@@ -68,6 +95,25 @@ def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, c
     return tm
 
 
+def is_combination_compatible(s, r):
+    t1, send_func, c1 = s
+    t2, recv_func, c2 = r
+
+    if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv",
+                                                                                        "mpi_recv_init",
+                                                                                        "mpi_precv_init"]:
+        # leads to deadlock
+        return False
+
+    if t1 in predefined_types and t2 in predefined_types and predefined_mpi_dtype_consants[
+        t1] == predefined_mpi_dtype_consants[t2] and not (t1 == "MPI_BYTE" or t2 == "MPI_BYTE"):
+        # one type is just the alias of another, this is allowed
+        # but BYTE may not be mixed with other types see standard section 3.3.1
+        return False
+
+    return c1 == c2 and t1 != t2
+
+
 class DtypeMissmatch(ErrorGenerator):
     invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"]
     send_funcs = ["mpi_send",
@@ -86,72 +132,62 @@ class DtypeMissmatch(ErrorGenerator):
         return ["P2P"]
 
     def generate(self, generate_level, real_world_score_table):
-        for send_func in self.send_funcs:
+
+        # (type,func,comm)
+        important_sends = []
+        important_recvs = []  #
+        for type in predefined_types + user_defined_types:
+            for send_func in self.send_funcs:
+                for comm in predefined_comms + comm_creators + intercomms:
+                    important_sends.append((type, send_func, comm))
+
+        for type in predefined_types + user_defined_types:
             for recv_func in self.recv_funcs:
-                if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv",
-                                                                                                    "mpi_recv_init",
-                                                                                                    "mpi_precv_init"]:
-                    # invalid combination resulting in deadlock
-                    continue
-
-                checked_types = set()
-
-                for type_1 in predefined_types + user_defined_types:
-                    for type_2 in predefined_types + user_defined_types:
-                        if type_1 == type_2:
-                            # skip: valid case
-                            continue
-                        if type_1 in predefined_types and type_2 in predefined_types and predefined_mpi_dtype_consants[
-                            type_1] == predefined_mpi_dtype_consants[type_2]:
-                            # one type is just the alias of another, this is allowed
-                            if not (type_2 == "MPI_BYTE" or type_1 == "MPI_BYTE"):
-                                # but BYTE may not be mixed with other types see standard section 3.3.1
-                                continue
-                        if generate_level < REAL_WORLD_TEST_LEVEL and (
-                                type_1 in checked_types or type_2 in checked_types):
-                            # unnecessary repetition
-                            continue
-
-                        if generate_level == REAL_WORLD_TEST_LEVEL:
-                            if not is_combination_important(real_world_score_table, send_func,
-                                                            datatype=type_1.lower()) or not is_combination_important(
-                                real_world_score_table, recv_func, datatype=type_2.lower()):
-                                # not relevant in real world
-                                # print("irrelevant: %s %s -> %s %s"%(send_func,type_1,recv_func,type_2))
-                                continue
-
-                        checked_types.add(type_1)
-                        checked_types.add(type_2)
-
-                        tm = get_local_missmatch(type_1, type_2, send_func, recv_func)
-                        yield tm
-
-                        for comm in predefined_comms + comm_creators + intercomms:
-                            if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL:
-                                continue
-                            if generate_level == REAL_WORLD_TEST_LEVEL:
-                                if (not is_combination_important(real_world_score_table, send_func,
-                                                                 datatype=type_1.lower(),
-                                                                 communicator=comm) or not
-                                is_combination_important(real_world_score_table,
-                                                         recv_func, datatype=type_2.lower(), communicator=comm)):
-                                    # not relevant in real world
-                                    continue
-                            tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
-
-                            yield tm
-                            # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
-                            tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
-                                                      get_bytes_size_for_type(type_1),
-                                                      send_func, recv_func, comm)
-
-                            yield tm
-                            if generate_level <= BASIC_TEST_LEVEL:
-                                return
-
-                # end for each pair of send/recv
-                if generate_level < REAL_WORLD_TEST_LEVEL:
-                    return
-
-            # TODO mrecv?
-            # TODO sendrecv?
+                for comm in predefined_comms + comm_creators + intercomms:
+                    important_recvs.append((type, recv_func, comm))
+
+        # filter to only important ones
+        if generate_level == REAL_WORLD_TEST_LEVEL:
+            important_sends = [(t, f, c) for (t, f, c) in important_sends if
+                               is_combination_important(real_world_score_table, f,
+                                                        datatype=t.lower(),
+                                                        communicator=c)]
+            important_recvs = [(t, f, c) for (t, f, c) in important_recvs if
+                               is_combination_important(real_world_score_table, f,
+                                                        datatype=t.lower(),
+                                                        communicator=c)]
+
+        print("number of important recvs:")
+        print(len(important_recvs))
+
+        print("number of important sends:")
+        print(len(important_sends))
+
+        # all possible combinations
+        combinations_to_use = [(s, r) for s in important_sends for r in important_recvs if
+                               is_combination_compatible(s, r)]
+        # "re-format"
+        combinations_to_use = [(t1, t2, s, r, c) for (t1, s, c), (t2, r, _) in combinations_to_use]
+
+        print("combinations:")
+        print(len(combinations_to_use))
+
+        correct_types_checked = set()
+        for type_1, type_2, send_func, recv_func, comm in combinations_to_use:
+            # local missmatch only for one communicator
+            if comm == "MPI_COMM_WORLD":
+                yield get_local_missmatch(type_1, type_2, send_func, recv_func)
+
+            # global missmatch: communicator is important
+            yield get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
+
+            # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
+            yield get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
+                                       get_bytes_size_for_type(type_1), send_func, recv_func, comm)
+
+            if type_1 not in correct_types_checked:
+                correct_types_checked.add(type_1)
+                yield get_correct_case(type_1, 1, send_func, recv_func, comm)
+
+        # TODO mrecv?
+        # TODO sendrecv?
-- 
GitLab