From a04de33bebb0360433a4e00ac6676235297cd844 Mon Sep 17 00:00:00 2001
From: Emmanuelle Saillard <emmanuelle.saillard@inria.fr>
Date: Thu, 2 May 2024 11:32:48 +0200
Subject: [PATCH] add invalid comm for topology function

---
 scripts/errors/coll/Correct.py     | 13 +------------
 scripts/errors/coll/InvalidComm.py |  6 +++---
 scripts/errors/coll/InvalidType.py |  2 +-
 3 files changed, 5 insertions(+), 16 deletions(-)

diff --git a/scripts/errors/coll/Correct.py b/scripts/errors/coll/Correct.py
index 97ed5bbc5..0e09ce4a8 100644
--- a/scripts/errors/coll/Correct.py
+++ b/scripts/errors/coll/Correct.py
@@ -8,10 +8,8 @@ from scripts.Infrastructure.Template import TemplateManager
 from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template
 
 class CorrectColl(ErrorGenerator):
-    functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
+    functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan", "mpi_cart_create" ]
     functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
-    
-    topology_functions = ["mpi_cart_create"]
 
     def __init__(self):
         pass
@@ -30,12 +28,3 @@ class CorrectColl(ErrorGenerator):
 
             if not generate_full_set:
                 return
-            
-        for func_to_use in self.topology_functions:
-            tm = get_collective_template(func_to_use)
-
-            tm.set_description("Correct-"+func_to_use, "Correct code")
-            yield tm
-
-            if not generate_full_set:
-                return
diff --git a/scripts/errors/coll/InvalidComm.py b/scripts/errors/coll/InvalidComm.py
index 204c0c672..a4240dbb7 100644
--- a/scripts/errors/coll/InvalidComm.py
+++ b/scripts/errors/coll/InvalidComm.py
@@ -9,7 +9,7 @@ from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_c
 
 class InvalidComErrorColl(ErrorGenerator):
     invalid_com = ["MPI_COMM_NULL", "NULL"]
-    functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ]
+    functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan", "mpi_cart_create" ]
     functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
     ####functions_to_use = ["mpi_allgather","mpi_allgatherv","mpi_allreduce","mpi_alltoall","mpi_alltoallv","mpi_alltoallw","mpi_barrier","mpi_bcast", "mpi_exscan","mpi_gather", "mpi_gatherv","mpi_reduce", "mpi_reduce_scatter", "mpi_reduce_scatter_block", "mpi_scan", "mpi_scatter", "mpi_scatterv", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan"]
 
@@ -24,10 +24,10 @@ class InvalidComErrorColl(ErrorGenerator):
         for com_to_use in  self.invalid_com:
             for func_to_use in self.functions_to_use:
                 tm = get_collective_template(func_to_use)
-                arg_to_replace = "comm"
-
+                
                 tm.set_description("InvalidParam-Comm-"+func_to_use, "Invalid communicator: %s" % com_to_use)
                 for call in tm.get_instruction("MPICALL", return_list=True):
+                    arg_to_replace = "comm" if call.has_arg("comm") else "comm_old"
                     call.set_arg(arg_to_replace, com_to_use)
                     call.set_has_error()
 
diff --git a/scripts/errors/coll/InvalidType.py b/scripts/errors/coll/InvalidType.py
index 4b9b1d32c..5fb0e977d 100644
--- a/scripts/errors/coll/InvalidType.py
+++ b/scripts/errors/coll/InvalidType.py
@@ -8,7 +8,7 @@ from scripts.Infrastructure.Template import TemplateManager
 from scripts.Infrastructure.TemplateFactory import get_collective_template
 
 class InvalidComErrorColl(ErrorGenerator):
-    invalid_type = ["MPI_DATATYPE_NULL"]
+    invalid_type = ["MPI_DATATYPE_NULL", "NULL"]
     functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"  ]
     func_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan", "mpi_allreduce", "mpi_iallreduce" ]
     functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
-- 
GitLab