Skip to content
Snippets Groups Projects
Commit aaba72de authored by Jammer, Tim's avatar Jammer, Tim
Browse files

same refactoring for collectives

parent 7f12944a
Branches
No related tags found
No related merge requests found
......@@ -7,16 +7,7 @@ from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICall
from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \
predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm
class InvalidComErrorColl(ErrorGenerator):
functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan",
"mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce",
"mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"]
func_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan",
"mpi_allreduce", "mpi_iallreduce"]
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw",
"mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
# TODO refactor into different file
# test if the tool chan deal with messages send over different communicators
predefined_comms = ["MPI_COMM_WORLD"]
comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
......@@ -25,45 +16,9 @@ class InvalidComErrorColl(ErrorGenerator):
]
intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]
def __init__(self):
pass
def get_feature(self):
return ["COLL"]
def generate(self, generate_level, real_world_score_table):
for func_to_use in self.func_one_type_arg:
checked_types = set()
for type_1 in predefined_types + user_defined_types:
for type_2 in predefined_types + user_defined_types:
if type_1 == type_2:
# skip: valid case
continue
if type_1 in predefined_types and type_2 in predefined_types and predefined_mpi_dtype_consants[
type_1] == predefined_mpi_dtype_consants[type_2]:
# one type is just the alias of another, this is allowed
if not (type_2 == "MPI_BYTE" or type_1 == "MPI_BYTE"):
# but BYTE may not be mixed with other types see standard section 3.3.1
continue
if generate_level < REAL_WORLD_TEST_LEVEL and (
type_1 in checked_types or type_2 in checked_types):
# unnecessary repetition
continue
if generate_level == REAL_WORLD_TEST_LEVEL:
if not is_combination_important(real_world_score_table, func_to_use,
datatype=type_1.lower()):
# not relevant in real world
continue
checked_types.add(type_1)
checked_types.add(type_2)
def get_local_missmatch(type_1, type_2, func_to_use):
tm = get_collective_template(func_to_use)
## local missmatch
tm.set_description("ParamMatching-Type-" + func_to_use,
"Wrong datatype matching: %s vs %s" % (type_1, type_2))
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
......@@ -77,30 +32,18 @@ class InvalidComErrorColl(ErrorGenerator):
call.set_arg("sendbuf", buf_name_2)
else:
call.set_arg("buffer", buf_name_2)
yield tm
return tm
if generate_level < BASIC_TEST_LEVEL:
return
# global missmatch
for comm in self.predefined_comms + self.comm_creators + self.intercomms:
if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL:
continue
if generate_level == REAL_WORLD_TEST_LEVEL:
if (not is_combination_important(real_world_score_table, func_to_use,
datatype=type_1.lower(),
communicator=comm)):
# not relevant in real world
continue
def get_global_missmatch(type_1, type_2, count_1, count_2, func_to_use, comm):
tm = get_collective_template(func_to_use)
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
if comm in comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, count_1, count_2)
tm.set_description("ParamMatching-Type-" + func_to_use,
"Wrong datatype matching: %s vs %s" % (type_1, type_2))
......@@ -108,7 +51,7 @@ class InvalidComErrorColl(ErrorGenerator):
for call in tm.get_instruction("MPICALL", return_list=True):
call.set_rank_executing(0)
call.set_arg("datatype", type_var_1)
call.set_arg("count", 1)
call.set_arg("count", count_1)
call.set_arg("comm", comm_var_name)
call.set_has_error()
if call.has_arg("recvbuf"):
......@@ -120,7 +63,7 @@ class InvalidComErrorColl(ErrorGenerator):
c = CorrectMPICallFactory.get(func_to_use)
c.set_rank_executing('not0')
c.set_arg("datatype", type_var_2)
c.set_arg("count", 1)
c.set_arg("count", count_2)
c.set_arg("comm", comm_var_name)
c.set_has_error()
if c.has_arg("recvbuf"):
......@@ -133,30 +76,25 @@ class InvalidComErrorColl(ErrorGenerator):
yield tm
# missmatch with matching sizes
def get_correct_case(type_1, count_1, func_to_use, comm):
tm = get_collective_template(func_to_use)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2,
get_bytes_size_for_type(
type_2),
get_bytes_size_for_type(
type_1))
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
if comm in comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, count_1, count_1)
tm.set_description("ParamMatching-Type-" + func_to_use,
"Wrong datatype matching: %s vs %s" % (type_1, type_2))
tm.set_description("Correct-" + func_to_use, "")
for call in tm.get_instruction("MPICALL", return_list=True):
call.set_rank_executing(0)
call.set_arg("datatype", type_var_1)
call.set_arg("count", get_bytes_size_for_type(type_2))
call.set_arg("count", count_1)
call.set_arg("comm", comm_var_name)
call.set_has_error()
if call.has_arg("recvbuf"):
call.set_arg("recvbuf", buf_name_1)
call.set_arg("sendbuf", buf_name_1)
......@@ -166,7 +104,7 @@ class InvalidComErrorColl(ErrorGenerator):
c = CorrectMPICallFactory.get(func_to_use)
c.set_rank_executing('not0')
c.set_arg("datatype", type_var_2)
c.set_arg("count", get_bytes_size_for_type(type_1))
c.set_arg("count", count_1)
c.set_arg("comm", comm_var_name)
c.set_has_error()
if c.has_arg("recvbuf"):
......@@ -178,3 +116,76 @@ class InvalidComErrorColl(ErrorGenerator):
tm.insert_instruction(c, after_instruction=call)
yield tm
def is_combination_compatible(t1, t2, f):
if t1 in predefined_types and t2 in predefined_types and predefined_mpi_dtype_consants[
t1] == predefined_mpi_dtype_consants[t2] and not (t1 == "MPI_BYTE" or t2 == "MPI_BYTE"):
# one type is just the alias of another, this is allowed
# but BYTE may not be mixed with other types see standard section 3.3.1
return False
return t1 != t2
class InvalidComErrorColl(ErrorGenerator):
functions_to_use = ["mpi_bcast", "mpi_ibcast", "mpi_reduce", "mpi_ireduce", "mpi_exscan", "mpi_scan", "mpi_iscan",
"mpi_gather", "mpi_igather", "mpi_allgather", "mpi_iallgather", "mpi_allreduce",
"mpi_iallreduce", "mpi_alltoall", "mpi_ialltoall", "mpi_scatter", "mpi_iscatter"]
func_one_type_arg = ["mpi_bcast", "mpi_reduce", "mpi_exscan", "mpi_scan", "mpi_ibcast", "mpi_ireduce", "mpi_iscan",
"mpi_allreduce", "mpi_iallreduce"]
functions_not_supported_yet = ["mpi_reduce_scatter_block", "mpi_allgatherv", "mpi_alltoallv", "mpi_alltoallw",
"mpi_gatherv", "mpi_reduce_scatter", "mpi_scatterv"]
def __init__(self):
pass
def get_feature(self):
return ["COLL"]
def generate(self, generate_level, real_world_score_table):
combinations_to_use = []
for f in self.functions_to_use:
for t1 in predefined_types + user_defined_types:
for t2 in predefined_types + user_defined_types:
for comm in predefined_comms + comm_creators + intercomms:
if is_combination_compatible(t1, t2, f):
combinations_to_use.append((t1, t2, f, comm))
if generate_level == REAL_WORLD_TEST_LEVEL:
combinations_to_use = [(t1, t2, f, comm) for (t1, t2, f, comm) in combinations_to_use if
is_combination_important(real_world_score_table,
f, datatype=t1.lower(),
communicator=comm) and
is_combination_important(real_world_score_table,
f, datatype=t2.lower(),
communicator=comm)]
if generate_level == SUFFICIENT_TEST_LEVEL:
types_checked = set()
combinations_to_use_filtered = []
for (t1, t2, f, c) in combinations_to_use:
if t1 not in types_checked and t2 not in types_checked:
types_checked.add(t1)
types_checked.add(t2)
combinations_to_use_filtered.append((t1, t2, f, c))
combinations_to_use = combinations_to_use_filtered
if generate_level == BASIC_TEST_LEVEL:
combinations_to_use = combinations_to_use[0:1]
correct_types_checked = set()
for type_1, type_2, func_to_use, comm in combinations_to_use:
if comm == "MPI_COMM_WORLD":
yield get_local_missmatch(type_1, type_2, func_to_use)
yield get_global_missmatch(type_1, type_2, 1, 1, func_to_use, comm)
# missmatch with matching sizes
yield get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
get_bytes_size_for_type(type_1),
func_to_use, comm)
if type_1 not in correct_types_checked:
correct_types_checked.add(type_1)
yield get_correct_case(type_1, func_to_use, comm)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment