From 0d73c92a33f13f83582c2119f7191df33c56b8df Mon Sep 17 00:00:00 2001
From: Emmanuelle Saillard <emmanuelle.saillard@inria.fr>
Date: Tue, 30 Apr 2024 08:55:17 +0200
Subject: [PATCH] fix an error and update existing generators

---
 scripts/Infrastructure/TemplateFactory.py | 54 +++++++++++++++++++++--
 scripts/errors/coll/CallOrdering.py       | 26 +++++------
 scripts/errors/coll/Correct.py            |  1 +
 scripts/errors/coll/ParamMatching.py      | 22 ++++++++-
 scripts/errors/coll/RequestLifeCycle.py   |  2 +-
 5 files changed, 83 insertions(+), 22 deletions(-)

diff --git a/scripts/Infrastructure/TemplateFactory.py b/scripts/Infrastructure/TemplateFactory.py
index a61ce90ba..2ef9f31b4 100644
--- a/scripts/Infrastructure/TemplateFactory.py
+++ b/scripts/Infrastructure/TemplateFactory.py
@@ -178,19 +178,18 @@ def get_invalid_param_p2p_case(param, value, check_receive, send_func, recv_func
 
 def get_collective_template(collective_func):
     """
-    Contructs a default template for the given mpi collecive
+    Contructs a default template for the given mpi collective
     Returns:
         TemplateManager Initialized with a default template
         The function is contained in a block named MPICALL
     """
-    need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"]
+    need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
     need_recv_and_send_buf_funcs = []
 
     tm = TemplateManager()
     cf = CorrectParameterFactory()
 
     # spilt send and recv buf
-    # to remove for barrier operation
     #if collective_func in need_buf_funcs:
     alloc = cf.get_buffer_alloc()
     alloc.set_identifier("ALLOC")
@@ -217,11 +216,60 @@ def get_collective_template(collective_func):
     if collective_func.startswith("mpi_i"):
         tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT")
 
+    #if collective_func in need_buf_funcs:
     tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
 
     return tm
 
+def get_two_collective_template(collective_func1, collective_func2):
+    """
+    Contructs a default template for two given mpi collectives
+    Returns:
+        TemplateManager Initialized with a default template
+        The function is contained in a block named MPICALL
+    """
+    need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
+    need_recv_and_send_buf_funcs = []
+
+    tm = TemplateManager()
+    cf = CorrectParameterFactory()
+
+    # spilt send and recv buf
+    #if collective_func in need_buf_funcs:
+    alloc = cf.get_buffer_alloc()
+    alloc.set_identifier("ALLOC")
+    alloc.set_name("buf")
+    tm.register_instruction(alloc)
 
+    cmpicf = CorrectMPICallFactory()
+    call_creator_function = getattr(cmpicf, collective_func1)
+    c = call_creator_function()
+
+    # add request for nonblocking collectives
+    if collective_func1.startswith("mpi_i") or collective_func2.startswith("mpi_i"): 
+        tm.add_stack_variable("MPI_Request")
+
+    # Set parameters for some collectives: sendcount, recvcounts
+    #if collective_func in ["mpi_alltoallv"]:
+    # TODO 
+
+    coll1 = CorrectMPICallFactory.get(collective_func1)
+    coll1.set_identifier("MPICALL")
+    tm.register_instruction(coll1)
+    coll1.set_rank_executing(0)
+
+    coll2 = CorrectMPICallFactory.get(collective_func2)
+    coll2.set_identifier("MPICALL")
+    tm.register_instruction(coll2)
+    coll2.set_rank_executing('not0')
+
+    # add wait function for nonblocking collectives
+    if collective_func1.startswith("mpi_i") or collective_func2.startswith("mpi_i"):
+        tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute='all', identifier="WAIT")
+
+    tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
+
+    return tm
 
 def get_allocated_window(win_alloc_func, name, bufname, ctype, num_elements):
     """
diff --git a/scripts/errors/coll/CallOrdering.py b/scripts/errors/coll/CallOrdering.py
index d46849492..e8e3cdacd 100644
--- a/scripts/errors/coll/CallOrdering.py
+++ b/scripts/errors/coll/CallOrdering.py
@@ -5,12 +5,13 @@ from scripts.Infrastructure.Instruction import Instruction
 from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
 from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
 from scripts.Infrastructure.Template import TemplateManager
-from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template
+from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template
 
 class InvalidRankErrorColl(ErrorGenerator):
     functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
     functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
-
+    #need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
+   
     def __init__(self):
         pass
 
@@ -32,24 +33,17 @@ class InvalidRankErrorColl(ErrorGenerator):
 
             yield tm
 
-        for func_to_use1 in self.functions_to_use:
-            for func_to_use2 in self.functions_to_use:  # pb: func1-func2 and func2-func1 -> remove some cases
-                tm = get_collective_template(func_to_use1)
 
-                tm.set_description("CallOrdering-unmatched-"+func_to_use1+"-"+func_to_use2, "Collective mismatch: "+func_to_use1+" is matched with "+func_to_use2)
+        for func1 in self.functions_to_use:
+            for func2 in self.functions_to_use:  # this generates func1-func2 and func2-func1 -> we need to remove similar cases
+                tm = get_two_collective_template(func1, func2)
+
+                tm.set_description("CallOrdering-unmatched-"+func1+"-"+func2, "Collective mismatch: "+func1+" is matched with "+func2)
 
                 for call in tm.get_instruction("MPICALL", return_list=True):
-                    call.set_rank_executing(0)
                     call.set_has_error()
-
-                    c = CorrectMPICallFactory.get(func_to_use2)
-                    if c.get_function() != call.get_function():
-                        if c.get_function().startswith("mpi_i"): 
-                            tm.add_stack_variable("MPI_Request") # not working..
-                        c.set_rank_executing('not0')
-                        c.set_has_error()
-                        tm.insert_instruction(c, after_instruction=call)
-                        yield tm
+                if func1 != func2: # we want different functions
+                    yield tm
 
 
             if not generate_full_set:
diff --git a/scripts/errors/coll/Correct.py b/scripts/errors/coll/Correct.py
index 3ac85f779..44a6c165c 100644
--- a/scripts/errors/coll/Correct.py
+++ b/scripts/errors/coll/Correct.py
@@ -10,6 +10,7 @@ from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_c
 class CorrectColl(ErrorGenerator):
     functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
     functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
+    
 
     def __init__(self):
         pass
diff --git a/scripts/errors/coll/ParamMatching.py b/scripts/errors/coll/ParamMatching.py
index e55e522cd..892c12515 100644
--- a/scripts/errors/coll/ParamMatching.py
+++ b/scripts/errors/coll/ParamMatching.py
@@ -5,7 +5,7 @@ from scripts.Infrastructure.Instruction import Instruction
 from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
 from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
 from scripts.Infrastructure.Template import TemplateManager
-from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template
+from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template
 
 class InvalidComErrorColl(ErrorGenerator):
     functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"  ]
@@ -21,6 +21,7 @@ class InvalidComErrorColl(ErrorGenerator):
 
     def generate(self, generate_full_set):
 
+        # Generate codes with type mismatch
         for func_to_use in self.func_with_one_type_arg:
             tm = get_collective_template(func_to_use)
             type_to_use = "MPI_INT"
@@ -37,7 +38,7 @@ class InvalidComErrorColl(ErrorGenerator):
 
             yield tm
 
-
+        # Generate codes with op mismatch
         for func_to_use in self.func_with_op:
             tm = get_collective_template(func_to_use)
             op_to_use = "MPI_SUM"
@@ -54,6 +55,23 @@ class InvalidComErrorColl(ErrorGenerator):
 
             yield tm
 
+        # Generate codes with communicator mismatch
+        for func_to_use in self.functions_to_use:
+            tm = get_collective_template(func_to_use)
+            com_to_use = "MPI_COMM_SELF"
+            tm.set_description("ParamMatching-Com-"+func_to_use, "Wrong communicator matching")
+
+            for call in tm.get_instruction("MPICALL", return_list=True):
+                call.set_rank_executing(0)
+                call.set_arg("comm", com_to_use) 
+                call.set_has_error()
+                c = CorrectMPICallFactory.get(func_to_use)
+                c.set_rank_executing('not0')
+                c.set_has_error()
+                tm.insert_instruction(c, after_instruction=call)
+
+            yield tm
+
         # only check for one comm
         if not generate_full_set:
             return
diff --git a/scripts/errors/coll/RequestLifeCycle.py b/scripts/errors/coll/RequestLifeCycle.py
index 6da03bf89..f90d0ff6f 100644
--- a/scripts/errors/coll/RequestLifeCycle.py
+++ b/scripts/errors/coll/RequestLifeCycle.py
@@ -24,7 +24,7 @@ class CorrectColl(ErrorGenerator):
         for func_to_use in self.nbfunc_to_use:
             tm = get_collective_template(func_to_use)
 
-            tm.set_description("RequestLifeCycle-"+func_to_use, func_to_use+" has no completion")
+            tm.set_description("RequestLifeCycle-"+func_to_use, func_to_use+" is not associated with a completion operation (missing wait)")
 
             for call in tm.get_instruction("MPICALL", return_list=True):
                 wait = tm.get_instruction("WAIT", return_list=True)
-- 
GitLab