Skip to content
Snippets Groups Projects
Commit ab034535 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

updated coll cases to match infrastructure, format code

parent 1c814662
Branches
No related tags found
No related merge requests found
...@@ -6,21 +6,26 @@ from scripts.Infrastructure.Instruction import Instruction ...@@ -6,21 +6,26 @@ from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template, \
get_two_collective_template
class InvalidRankErrorColl(ErrorGenerator): class InvalidRankErrorColl(ErrorGenerator):
functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ] functions_to_use = ["mpi_allgather", "mpi_allreduce", "mpi_alltoall", "mpi_barrier", "mpi_bcast", "mpi_reduce",
functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"] "mpi_scatter", "mpi_exscan", "mpi_gather", "mpi_reduce_scatter_block", "mpi_scan",
"mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter",
"mpi_igather", "mpi_iscan"]
functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv",
"mpi_reduce_scatter", "mpi_scatterv"]
# need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"] # need_buf_funcs = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter", "mpi_reduce_scatter_block"]
def __init__(self): def __init__(self):
pass pass
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
def generate(self, generate_level, real_world_score_table): def generate(self, generate_level, real_world_score_table):
for func_to_use in self.functions_to_use: for func_to_use in self.functions_to_use:
...@@ -34,18 +39,17 @@ class InvalidRankErrorColl(ErrorGenerator): ...@@ -34,18 +39,17 @@ class InvalidRankErrorColl(ErrorGenerator):
yield tm yield tm
for func1 in self.functions_to_use: for func1 in self.functions_to_use:
for func2 in self.functions_to_use: # this generates func1-func2 and func2-func1 -> we need to remove similar cases for func2 in self.functions_to_use: # this generates func1-func2 and func2-func1 -> we need to remove similar cases
tm = get_two_collective_template(func1, func2) tm = get_two_collective_template(func1, func2)
tm.set_description("CallOrdering-unmatched-"+func1+"-"+func2, "Collective mismatch: "+func1+" is matched with "+func2) tm.set_description("CallOrdering-unmatched-" + func1 + "-" + func2,
"Collective mismatch: " + func1 + " is matched with " + func2)
for call in tm.get_instruction("MPICALL", return_list=True): for call in tm.get_instruction("MPICALL", return_list=True):
call.set_has_error() call.set_has_error()
if func1 != func2: # we want different functions if func1 != func2: # we want different functions
yield tm yield tm
if not generate_level <= BASIC_TEST_LEVEL: if not generate_level <= BASIC_TEST_LEVEL:
return return
...@@ -6,20 +6,24 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall ...@@ -6,20 +6,24 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template
from scripts.Infrastructure.Variables import *
class CorrectColl(ErrorGenerator): class CorrectColl(ErrorGenerator):
functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan"] functions_to_use = ["mpi_allgather", "mpi_allreduce", "mpi_alltoall", "mpi_barrier", "mpi_bcast", "mpi_reduce",
"mpi_scatter", "mpi_exscan", "mpi_gather", "mpi_reduce_scatter_block", "mpi_scan",
"mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter",
"mpi_igather", "mpi_iscan"]
functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"] functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
topology_functions = ["mpi_cart_create"] topology_functions = ["mpi_cart_create"]
def __init__(self): def __init__(self):
pass pass
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
def generate(self, generate_full_set): def generate(self, generate_level, real_world_score_table):
# Only one function called by all processes # Only one function called by all processes
for func_to_use in self.functions_to_use: for func_to_use in self.functions_to_use:
...@@ -44,5 +48,5 @@ class CorrectColl(ErrorGenerator): ...@@ -44,5 +48,5 @@ class CorrectColl(ErrorGenerator):
tm.register_instruction(cart_get) tm.register_instruction(cart_get)
yield tm yield tm
if not generate_full_set: if generate_level >= BASIC_TEST_LEVEL:
return return
#! /usr/bin/python3 #! /usr/bin/python3
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory
from scripts.Infrastructure.Variables import * from scripts.Infrastructure.Variables import *
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.TemplateFactory import get_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template
class InvalidComErrorColl(ErrorGenerator): class InvalidComErrorColl(ErrorGenerator):
invalid_com = ["MPI_COMM_NULL", "NULL"] invalid_com = ["MPI_COMM_NULL", "NULL"]
functions_to_use = ["mpi_allgather","mpi_allreduce","mpi_alltoall","mpi_barrier","mpi_bcast", "mpi_reduce", "mpi_scatter","mpi_exscan","mpi_gather", "mpi_reduce_scatter_block", "mpi_scan", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan", "mpi_cart_create" ] functions_to_use = ["mpi_allgather", "mpi_allreduce", "mpi_alltoall", "mpi_barrier", "mpi_bcast", "mpi_reduce",
functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"] "mpi_scatter", "mpi_exscan", "mpi_gather", "mpi_reduce_scatter_block", "mpi_scan",
"mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter",
"mpi_igather", "mpi_iscan", "mpi_cart_create"]
functions_not_supported_yet = ["mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv",
"mpi_reduce_scatter", "mpi_scatterv"]
####functions_to_use = ["mpi_allgather","mpi_allgatherv","mpi_allreduce","mpi_alltoall","mpi_alltoallv","mpi_alltoallw","mpi_barrier","mpi_bcast", "mpi_exscan","mpi_gather", "mpi_gatherv","mpi_reduce", "mpi_reduce_scatter", "mpi_reduce_scatter_block", "mpi_scan", "mpi_scatter", "mpi_scatterv", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan"] ####functions_to_use = ["mpi_allgather","mpi_allgatherv","mpi_allreduce","mpi_alltoall","mpi_alltoallv","mpi_alltoallw","mpi_barrier","mpi_bcast", "mpi_exscan","mpi_gather", "mpi_gatherv","mpi_reduce", "mpi_reduce_scatter", "mpi_reduce_scatter_block", "mpi_scan", "mpi_scatter", "mpi_scatterv", "mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan"]
topology_functions = ["mpi_cart_create"] topology_functions = ["mpi_cart_create"]
...@@ -30,14 +36,14 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -30,14 +36,14 @@ class InvalidComErrorColl(ErrorGenerator):
yield tm yield tm
for fun_to_use in self.topology_functions: for fun_to_use in self.topology_functions:
tm = get_collective_template(func_to_use) tm = get_collective_template(func_to_use)
for com_to_use in ["MPI_COMM_NULL", "NULL", "MPI_COMM_WORLD"]: for com_to_use in ["MPI_COMM_NULL", "NULL", "MPI_COMM_WORLD"]:
tm.set_description("InvalidParam-Comm-"+func_to_use+"-mpi_cart_get", "A function tries to get cartesian information of "+com_to_use) tm.set_description("InvalidParam-Comm-" + func_to_use + "-mpi_cart_get",
"A function tries to get cartesian information of " + com_to_use)
cart_get = CorrectMPICallFactory().mpi_cart_get() cart_get = CorrectMPICallFactory.mpi_cart_get()
cart_get.set_arg("comm_cart", com_to_use) cart_get.set_arg("comm_cart", com_to_use)
tm.register_instruction(cart_get) tm.register_instruction(cart_get)
cart_get.set_has_error() cart_get.set_has_error()
...@@ -46,4 +52,3 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -46,4 +52,3 @@ class InvalidComErrorColl(ErrorGenerator):
# only check for one comm # only check for one comm
if generate_level <= BASIC_TEST_LEVEL: if generate_level <= BASIC_TEST_LEVEL:
return return
...@@ -4,13 +4,13 @@ from scripts.Infrastructure.Variables import * ...@@ -4,13 +4,13 @@ from scripts.Infrastructure.Variables import *
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.TemplateFactory import get_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template
class InvalidComErrorColl(ErrorGenerator): class InvalidComErrorColl(ErrorGenerator):
invalid_op = ["MPI_OP_NULL"] invalid_op = ["MPI_OP_NULL"]
functions_to_use = ["mpi_reduce", "mpi_ireduce", "mpi_allreduce", "mpi_iallreduce"] functions_to_use = ["mpi_reduce", "mpi_ireduce", "mpi_allreduce", "mpi_iallreduce"]
# TODO invalid op+ type combinations aka MPI_MAXLOC with MPI_BYTE or something klie this # TODO invalid op+ type combinations aka MPI_MAXLOC with MPI_BYTE or something klie this
def __init__(self): def __init__(self):
pass pass
...@@ -31,4 +31,3 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -31,4 +31,3 @@ class InvalidComErrorColl(ErrorGenerator):
# only check for one comm # only check for one comm
if generate_level <= BASIC_TEST_LEVEL: if generate_level <= BASIC_TEST_LEVEL:
return return
...@@ -4,15 +4,16 @@ from scripts.Infrastructure.Variables import * ...@@ -4,15 +4,16 @@ from scripts.Infrastructure.Variables import *
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.TemplateFactory import get_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template
class InvalidRankErrorColl(ErrorGenerator): class InvalidRankErrorColl(ErrorGenerator):
invalid_ranks = ["-1", "nprocs", "MPI_PROC_NULL"] invalid_ranks = ["-1", "nprocs", "MPI_PROC_NULL"]
functions_to_use = ["mpi_reduce", "mpi_bcast", "mpi_gather", "mpi_scatter", "mpi_ireduce", "mpi_ibcast", "mpi_igather", "mpi_iscatter"] functions_to_use = ["mpi_reduce", "mpi_bcast", "mpi_gather", "mpi_scatter", "mpi_ireduce", "mpi_ibcast",
"mpi_igather", "mpi_iscatter"]
functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"] functions_not_supported_yet = ["mpi_gatherv", "mpi_scatterv", "mpi_igatherv", "mpi_iscatterv"]
def __init__(self): def __init__(self):
pass pass
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
......
...@@ -8,12 +8,16 @@ from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory ...@@ -8,12 +8,16 @@ from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template
class InvalidComErrorColl(ErrorGenerator): class InvalidComErrorColl(ErrorGenerator):
invalid_type = ["MPI_DATATYPE_NULL", "NULL"] invalid_type = ["MPI_DATATYPE_NULL", "NULL"]
functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter" ] functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan",
func_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan", "mpi_allreduce", "mpi_iallreduce" ] "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce",
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"] "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"]
func_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan",
"mpi_allreduce", "mpi_iallreduce"]
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw",
"mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
def __init__(self): def __init__(self):
pass pass
...@@ -38,4 +42,3 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -38,4 +42,3 @@ class InvalidComErrorColl(ErrorGenerator):
# only check for one comm # only check for one comm
if generate_level <= BASIC_TEST_LEVEL: if generate_level <= BASIC_TEST_LEVEL:
return return
...@@ -6,19 +6,20 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall ...@@ -6,19 +6,20 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template
from scripts.Infrastructure.Variables import *
class InvalidRankErrorColl(ErrorGenerator): class InvalidRankErrorColl(ErrorGenerator):
nbfunc_to_use = ["mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ] nbfunc_to_use = ["mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather",
"mpi_iscan"]
def __init__(self): def __init__(self):
pass pass
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
def generate(self, generate_level, real_world_score_table):
def generate(self, generate_full_set):
for func_to_use in self.nbfunc_to_use: for func_to_use in self.nbfunc_to_use:
tm = get_collective_template(func_to_use) tm = get_collective_template(func_to_use)
...@@ -30,8 +31,7 @@ class InvalidRankErrorColl(ErrorGenerator): ...@@ -30,8 +31,7 @@ class InvalidRankErrorColl(ErrorGenerator):
wait = tm.get_instruction("WAIT", return_list=True) wait = tm.get_instruction("WAIT", return_list=True)
tm.insert_instruction(conflicting_inst, before_instruction=wait) tm.insert_instruction(conflicting_inst, before_instruction=wait)
yield tm yield tm
if not generate_full_set: if generate_level >= BASIC_TEST_LEVEL:
return return
...@@ -7,12 +7,18 @@ from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get ...@@ -7,12 +7,18 @@ from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template from scripts.Infrastructure.TemplateFactory import get_collective_template, get_two_collective_template
class InvalidComErrorColl(ErrorGenerator): class InvalidComErrorColl(ErrorGenerator):
functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan", "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce", "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter" ] functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan",
func_with_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan", "mpi_allreduce", "mpi_iallreduce" ] "mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce",
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw", "mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"] "mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"]
func_with_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce",
"mpi_iscan", "mpi_allreduce", "mpi_iallreduce"]
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw",
"mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
func_with_op = ["mpi_reduce", "mpi_ireduce", "mpi_allreduce", "mpi_iallreduce"] func_with_op = ["mpi_reduce", "mpi_ireduce", "mpi_allreduce", "mpi_iallreduce"]
func_with_root = ["mpi_reduce", "mpi_bcast", "mpi_gather", "mpi_scatter", "mpi_ireduce", "mpi_ibcast", "mpi_igather", "mpi_iscatter"] func_with_root = ["mpi_reduce", "mpi_bcast", "mpi_gather", "mpi_scatter", "mpi_ireduce", "mpi_ibcast",
"mpi_igather", "mpi_iscatter"]
def __init__(self): def __init__(self):
pass pass
...@@ -20,7 +26,7 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -20,7 +26,7 @@ class InvalidComErrorColl(ErrorGenerator):
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
def generate(self, generate_full_set): def generate(self, generate_level, real_world_score_table):
# Generate codes with type mismatch # Generate codes with type mismatch
for func_to_use in self.func_with_one_type_arg: for func_to_use in self.func_with_one_type_arg:
...@@ -88,4 +94,3 @@ class InvalidComErrorColl(ErrorGenerator): ...@@ -88,4 +94,3 @@ class InvalidComErrorColl(ErrorGenerator):
if not generate_full_set: if not generate_full_set:
return return
...@@ -6,25 +6,27 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall ...@@ -6,25 +6,27 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_collective_template
from scripts.Infrastructure.Variables import *
class CorrectColl(ErrorGenerator): class CorrectColl(ErrorGenerator):
nbfunc_to_use = ["mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter", "mpi_igather", "mpi_iscan" ] nbfunc_to_use = ["mpi_ibarrier", "mpi_iallreduce", "mpi_ialltoall", "mpi_ibcast", "mpi_ireduce", "mpi_iscatter",
"mpi_igather", "mpi_iscan"]
functions_not_supported_yet = ["mpi_igatherv", "mpi_iscatterv"] functions_not_supported_yet = ["mpi_igatherv", "mpi_iscatterv"]
def __init__(self): def __init__(self):
pass pass
def get_feature(self): def get_feature(self):
return ["COLL"] return ["COLL"]
def generate(self, generate_level, real_world_score_table):
def generate(self, generate_full_set):
for func_to_use in self.nbfunc_to_use: for func_to_use in self.nbfunc_to_use:
tm = get_collective_template(func_to_use) tm = get_collective_template(func_to_use)
tm.set_description("RequestLifeCycle-"+func_to_use, func_to_use+" is not associated with a completion operation (missing wait)") tm.set_description("RequestLifeCycle-" + func_to_use,
func_to_use + " is not associated with a completion operation (missing wait)")
for call in tm.get_instruction("MPICALL", return_list=True): for call in tm.get_instruction("MPICALL", return_list=True):
wait = tm.get_instruction("WAIT", return_list=True) wait = tm.get_instruction("WAIT", return_list=True)
...@@ -33,5 +35,5 @@ class CorrectColl(ErrorGenerator): ...@@ -33,5 +35,5 @@ class CorrectColl(ErrorGenerator):
yield tm yield tm
if not generate_full_set: if generate_level >= BASIC_TEST_LEVEL:
return return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment