Skip to content
Snippets Groups Projects
Commit a0eda710 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Refactoring: simplified code further

parent eb977b6b
No related branches found
No related tags found
No related merge requests found
#! /usr/bin/python3 #! /usr/bin/python3
from copy import copy from copy import copy
from random import shuffle
from scripts.Infrastructure.AllocCall import AllocCall from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
...@@ -40,6 +41,32 @@ def get_local_missmatch(type_1, type_2, send_func, recv_func): ...@@ -40,6 +41,32 @@ def get_local_missmatch(type_1, type_2, send_func, recv_func):
return tm return tm
def get_correct_case(type_1, size_1, send_func, recv_func, comm):
tm = get_send_recv_template(send_func, recv_func)
tm.set_description("Correct-" + send_func,
"")
comm_var_name = "MPI_COMM_WORLD"
if comm in comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, size_1, size_1)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_1)
call.set_arg("count", size_1)
call.set_arg("comm", comm_var_name)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
call.set_arg("buf", buf_name_2)
call.set_arg("datatype", type_var_2)
call.set_arg("count", size_1)
call.set_arg("comm", comm_var_name)
return tm
def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm): def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm):
tm = get_send_recv_template(send_func, recv_func) tm = get_send_recv_template(send_func, recv_func)
tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func, tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
...@@ -68,6 +95,25 @@ def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, c ...@@ -68,6 +95,25 @@ def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, c
return tm return tm
def is_combination_compatible(s, r):
t1, send_func, c1 = s
t2, recv_func, c2 = r
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv",
"mpi_recv_init",
"mpi_precv_init"]:
# leads to deadlock
return False
if t1 in predefined_types and t2 in predefined_types and predefined_mpi_dtype_consants[
t1] == predefined_mpi_dtype_consants[t2] and not (t1 == "MPI_BYTE" or t2 == "MPI_BYTE"):
# one type is just the alias of another, this is allowed
# but BYTE may not be mixed with other types see standard section 3.3.1
return False
return c1 == c2 and t1 != t2
class DtypeMissmatch(ErrorGenerator): class DtypeMissmatch(ErrorGenerator):
invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"] invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"]
send_funcs = ["mpi_send", send_funcs = ["mpi_send",
...@@ -86,72 +132,62 @@ class DtypeMissmatch(ErrorGenerator): ...@@ -86,72 +132,62 @@ class DtypeMissmatch(ErrorGenerator):
return ["P2P"] return ["P2P"]
def generate(self, generate_level, real_world_score_table): def generate(self, generate_level, real_world_score_table):
# (type,func,comm)
important_sends = []
important_recvs = [] #
for type in predefined_types + user_defined_types:
for send_func in self.send_funcs: for send_func in self.send_funcs:
for comm in predefined_comms + comm_creators + intercomms:
important_sends.append((type, send_func, comm))
for type in predefined_types + user_defined_types:
for recv_func in self.recv_funcs: for recv_func in self.recv_funcs:
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv", for comm in predefined_comms + comm_creators + intercomms:
"mpi_recv_init", important_recvs.append((type, recv_func, comm))
"mpi_precv_init"]:
# invalid combination resulting in deadlock
continue
checked_types = set()
for type_1 in predefined_types + user_defined_types:
for type_2 in predefined_types + user_defined_types:
if type_1 == type_2:
# skip: valid case
continue
if type_1 in predefined_types and type_2 in predefined_types and predefined_mpi_dtype_consants[
type_1] == predefined_mpi_dtype_consants[type_2]:
# one type is just the alias of another, this is allowed
if not (type_2 == "MPI_BYTE" or type_1 == "MPI_BYTE"):
# but BYTE may not be mixed with other types see standard section 3.3.1
continue
if generate_level < REAL_WORLD_TEST_LEVEL and (
type_1 in checked_types or type_2 in checked_types):
# unnecessary repetition
continue
# filter to only important ones
if generate_level == REAL_WORLD_TEST_LEVEL: if generate_level == REAL_WORLD_TEST_LEVEL:
if not is_combination_important(real_world_score_table, send_func, important_sends = [(t, f, c) for (t, f, c) in important_sends if
datatype=type_1.lower()) or not is_combination_important( is_combination_important(real_world_score_table, f,
real_world_score_table, recv_func, datatype=type_2.lower()): datatype=t.lower(),
# not relevant in real world communicator=c)]
# print("irrelevant: %s %s -> %s %s"%(send_func,type_1,recv_func,type_2)) important_recvs = [(t, f, c) for (t, f, c) in important_recvs if
continue is_combination_important(real_world_score_table, f,
datatype=t.lower(),
communicator=c)]
print("number of important recvs:")
print(len(important_recvs))
print("number of important sends:")
print(len(important_sends))
# all possible combinations
combinations_to_use = [(s, r) for s in important_sends for r in important_recvs if
is_combination_compatible(s, r)]
# "re-format"
combinations_to_use = [(t1, t2, s, r, c) for (t1, s, c), (t2, r, _) in combinations_to_use]
print("combinations:")
print(len(combinations_to_use))
correct_types_checked = set()
for type_1, type_2, send_func, recv_func, comm in combinations_to_use:
# local missmatch only for one communicator
if comm == "MPI_COMM_WORLD":
yield get_local_missmatch(type_1, type_2, send_func, recv_func)
# global missmatch: communicator is important
yield get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
checked_types.add(type_1)
checked_types.add(type_2)
tm = get_local_missmatch(type_1, type_2, send_func, recv_func)
yield tm
for comm in predefined_comms + comm_creators + intercomms:
if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL:
continue
if generate_level == REAL_WORLD_TEST_LEVEL:
if (not is_combination_important(real_world_score_table, send_func,
datatype=type_1.lower(),
communicator=comm) or not
is_combination_important(real_world_score_table,
recv_func, datatype=type_2.lower(), communicator=comm)):
# not relevant in real world
continue
tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
yield tm
# global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2), yield get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
get_bytes_size_for_type(type_1), get_bytes_size_for_type(type_1), send_func, recv_func, comm)
send_func, recv_func, comm)
yield tm
if generate_level <= BASIC_TEST_LEVEL:
return
# end for each pair of send/recv if type_1 not in correct_types_checked:
if generate_level < REAL_WORLD_TEST_LEVEL: correct_types_checked.add(type_1)
return yield get_correct_case(type_1, 1, send_func, recv_func, comm)
# TODO mrecv? # TODO mrecv?
# TODO sendrecv? # TODO sendrecv?
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment