Skip to content
Snippets Groups Projects
Commit 52c50f1b authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Extract method and adjusted when a combnation is labled important

parent cf1ef29f
No related branches found
No related tags found
No related merge requests found
......@@ -240,6 +240,68 @@ def get_type_buffers(tm, type_1, type_2,size_1,size_2):
return type_1_variable, buf_name_1, type_2_variable, buf_name_2
# TODO: move to other file(scoring_table?)
def is_combination_important(real_world_score_table, call,
rank=None, tag=None, count=None, datatype=None, communicator=None, op=None):
# the scoreboard has other values == wildcard, we treat the given standard values as matching
standard_rank = 0
standard_tag = 0
standard_count = 1
standard_datatype = "MPI_INT".lower()
standard_communicator = "MPI_COMM_WORLD".lower()
standard_op = "MPI_SUM"
# Filter based on the 'call' column
relevant = real_world_score_table[real_world_score_table["call"] == call]
# Filter by 'rank'
if rank is not None:
if rank == standard_rank:
relevant = relevant[(relevant["RANK"] == rank) | (relevant["RANK"] == "other")]
else:
relevant = relevant[relevant["RANK"] == rank]
# Filter by 'tag'
if tag is not None:
if tag == standard_tag:
relevant = relevant[(relevant["TAG"] == tag) | (relevant["TAG"] == "other")]
else:
relevant = relevant[relevant["TAG"] == tag]
# Filter by 'count'
if count is not None:
if count == standard_count:
relevant = relevant[(relevant["COUNT"] == count) | (relevant["COUNT"] == "other")]
else:
relevant = relevant[relevant["COUNT"] == count]
# Filter by 'datatype'
if datatype is not None:
if datatype.lower() == standard_datatype:
relevant = relevant[
(relevant["DATATYPE"] == datatype) | (relevant["DATATYPE"] == "other")]
else:
relevant = relevant[relevant["DATATYPE"] == datatype]
# Filter by 'communicator'
if communicator is not None:
if communicator.lower() == standard_communicator:
relevant = relevant[
(relevant["COMMUNICATOR"] == communicator) | (relevant["COMMUNICATOR"] == "other")]
else:
relevant = relevant[relevant["COMMUNICATOR"] == communicator]
# Filter by 'op'
if op is not None:
if op == standard_op:
relevant = relevant[(relevant["OP"] == op) | (relevant["OP"] == "other")]
else:
relevant = relevant[relevant["OP"] == op]
return len(relevant) > 0
class DtypeMissmatch(ErrorGenerator):
invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"]
send_funcs = ["mpi_send",
......@@ -293,13 +355,9 @@ class DtypeMissmatch(ErrorGenerator):
continue
if generation_level == REAL_WORLD_TEST_LEVEL:
this_case_scores_send = (real_world_score_table[
(real_world_score_table['call'] == send_func) &
(real_world_score_table['DATATYPE'] == type_1.lower())])
this_case_scores_recv = (real_world_score_table[
(real_world_score_table['call'] == recv_func) &
(real_world_score_table['DATATYPE'] == type_2.lower())])
if len(this_case_scores_send) == 0 or len(this_case_scores_recv) == 0:
if not is_combination_important(real_world_score_table, send_func,
datatype=type_1.lower()) or not is_combination_important(
real_world_score_table, recv_func, datatype=type_2.lower()):
# not relevant in real world
# print("irrelevant: %s %s -> %s %s"%(send_func,type_1,recv_func,type_2))
continue
......@@ -324,11 +382,11 @@ class DtypeMissmatch(ErrorGenerator):
if comm != "MPI_COMM_WORLD" and generation_level < REAL_WORLD_TEST_LEVEL:
continue
if generation_level == REAL_WORLD_TEST_LEVEL:
this_case_scores_send_comm = this_case_scores_send[
this_case_scores_send['COMMUNICATOR'] == comm.lower()]
this_case_scores_recv_comm = this_case_scores_recv[
this_case_scores_recv['COMMUNICATOR'] == comm.lower()]
if len(this_case_scores_send_comm) == 0 or len(this_case_scores_recv_comm) == 0:
if (not is_combination_important(real_world_score_table, send_func,
datatype=type_1.lower(),
communicator=comm) or not
is_combination_important(real_world_score_table,
recv_func, datatype=type_2.lower(), communicator=comm)):
# not relevant in real world
continue
tm = get_send_recv_template(send_func, recv_func)
......@@ -369,7 +427,11 @@ class DtypeMissmatch(ErrorGenerator):
if comm in self.intercomms:
comm_var_name = get_intercomm(comm, tm)
# global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, get_bytes_size_for_type(type_2), get_bytes_size_for_type(type_1))
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2,
get_bytes_size_for_type(
type_2),
get_bytes_size_for_type(
type_1))
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_has_error()
call.set_arg("buf", buf_name_1)
......@@ -393,4 +455,3 @@ class DtypeMissmatch(ErrorGenerator):
# TODO mrecv?
# TODO sendrecv?
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment