From a0eda710e156b1ba58287317ca8a105e80795a90 Mon Sep 17 00:00:00 2001 From: Tim Jammer <tim.jammer@tu-darmstadt.de> Date: Mon, 29 Apr 2024 13:19:06 +0200 Subject: [PATCH] Refactoring: simplified code further --- scripts/errors/dtypes/DtypeMissmatch.py | 172 ++++++++++++++---------- 1 file changed, 104 insertions(+), 68 deletions(-) diff --git a/scripts/errors/dtypes/DtypeMissmatch.py b/scripts/errors/dtypes/DtypeMissmatch.py index a763721f9..28b01b0ef 100644 --- a/scripts/errors/dtypes/DtypeMissmatch.py +++ b/scripts/errors/dtypes/DtypeMissmatch.py @@ -1,5 +1,6 @@ #! /usr/bin/python3 from copy import copy +from random import shuffle from scripts.Infrastructure.AllocCall import AllocCall from scripts.Infrastructure.ErrorGenerator import ErrorGenerator @@ -40,6 +41,32 @@ def get_local_missmatch(type_1, type_2, send_func, recv_func): return tm +def get_correct_case(type_1, size_1, send_func, recv_func, comm): + tm = get_send_recv_template(send_func, recv_func) + tm.set_description("Correct-" + send_func, + "") + comm_var_name = "MPI_COMM_WORLD" + if comm in comm_creators: + comm_var_name = get_communicator(comm, tm) + + if comm in intercomms: + comm_var_name = get_intercomm(comm, tm) + + type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, size_1, size_1) + call = tm.get_instruction(identifier="MPICALL", rank_excuting=0) + call.set_arg("buf", buf_name_1) + call.set_arg("datatype", type_var_1) + call.set_arg("count", size_1) + call.set_arg("comm", comm_var_name) + call = tm.get_instruction(identifier="MPICALL", rank_excuting=1) + call.set_arg("buf", buf_name_2) + call.set_arg("datatype", type_var_2) + call.set_arg("count", size_1) + call.set_arg("comm", comm_var_name) + + return tm + + def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm): tm = get_send_recv_template(send_func, recv_func) tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func, @@ -68,6 +95,25 @@ def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, c return tm +def is_combination_compatible(s, r): + t1, send_func, c1 = s + t2, recv_func, c2 = r + + if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv", + "mpi_recv_init", + "mpi_precv_init"]: + # leads to deadlock + return False + + if t1 in predefined_types and t2 in predefined_types and predefined_mpi_dtype_consants[ + t1] == predefined_mpi_dtype_consants[t2] and not (t1 == "MPI_BYTE" or t2 == "MPI_BYTE"): + # one type is just the alias of another, this is allowed + # but BYTE may not be mixed with other types see standard section 3.3.1 + return False + + return c1 == c2 and t1 != t2 + + class DtypeMissmatch(ErrorGenerator): invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"] send_funcs = ["mpi_send", @@ -86,72 +132,62 @@ class DtypeMissmatch(ErrorGenerator): return ["P2P"] def generate(self, generate_level, real_world_score_table): - for send_func in self.send_funcs: + + # (type,func,comm) + important_sends = [] + important_recvs = [] # + for type in predefined_types + user_defined_types: + for send_func in self.send_funcs: + for comm in predefined_comms + comm_creators + intercomms: + important_sends.append((type, send_func, comm)) + + for type in predefined_types + user_defined_types: for recv_func in self.recv_funcs: - if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"] and recv_func not in ["mpi_irecv", - "mpi_recv_init", - "mpi_precv_init"]: - # invalid combination resulting in deadlock - continue - - checked_types = set() - - for type_1 in predefined_types + user_defined_types: - for type_2 in predefined_types + user_defined_types: - if type_1 == type_2: - # skip: valid case - continue - if type_1 in predefined_types and type_2 in predefined_types and predefined_mpi_dtype_consants[ - type_1] == predefined_mpi_dtype_consants[type_2]: - # one type is just the alias of another, this is allowed - if not (type_2 == "MPI_BYTE" or type_1 == "MPI_BYTE"): - # but BYTE may not be mixed with other types see standard section 3.3.1 - continue - if generate_level < REAL_WORLD_TEST_LEVEL and ( - type_1 in checked_types or type_2 in checked_types): - # unnecessary repetition - continue - - if generate_level == REAL_WORLD_TEST_LEVEL: - if not is_combination_important(real_world_score_table, send_func, - datatype=type_1.lower()) or not is_combination_important( - real_world_score_table, recv_func, datatype=type_2.lower()): - # not relevant in real world - # print("irrelevant: %s %s -> %s %s"%(send_func,type_1,recv_func,type_2)) - continue - - checked_types.add(type_1) - checked_types.add(type_2) - - tm = get_local_missmatch(type_1, type_2, send_func, recv_func) - yield tm - - for comm in predefined_comms + comm_creators + intercomms: - if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL: - continue - if generate_level == REAL_WORLD_TEST_LEVEL: - if (not is_combination_important(real_world_score_table, send_func, - datatype=type_1.lower(), - communicator=comm) or not - is_combination_important(real_world_score_table, - recv_func, datatype=type_2.lower(), communicator=comm)): - # not relevant in real world - continue - tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm) - - yield tm - # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types - tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2), - get_bytes_size_for_type(type_1), - send_func, recv_func, comm) - - yield tm - if generate_level <= BASIC_TEST_LEVEL: - return - - # end for each pair of send/recv - if generate_level < REAL_WORLD_TEST_LEVEL: - return - - # TODO mrecv? - # TODO sendrecv? + for comm in predefined_comms + comm_creators + intercomms: + important_recvs.append((type, recv_func, comm)) + + # filter to only important ones + if generate_level == REAL_WORLD_TEST_LEVEL: + important_sends = [(t, f, c) for (t, f, c) in important_sends if + is_combination_important(real_world_score_table, f, + datatype=t.lower(), + communicator=c)] + important_recvs = [(t, f, c) for (t, f, c) in important_recvs if + is_combination_important(real_world_score_table, f, + datatype=t.lower(), + communicator=c)] + + print("number of important recvs:") + print(len(important_recvs)) + + print("number of important sends:") + print(len(important_sends)) + + # all possible combinations + combinations_to_use = [(s, r) for s in important_sends for r in important_recvs if + is_combination_compatible(s, r)] + # "re-format" + combinations_to_use = [(t1, t2, s, r, c) for (t1, s, c), (t2, r, _) in combinations_to_use] + + print("combinations:") + print(len(combinations_to_use)) + + correct_types_checked = set() + for type_1, type_2, send_func, recv_func, comm in combinations_to_use: + # local missmatch only for one communicator + if comm == "MPI_COMM_WORLD": + yield get_local_missmatch(type_1, type_2, send_func, recv_func) + + # global missmatch: communicator is important + yield get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm) + + # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types + yield get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2), + get_bytes_size_for_type(type_1), send_func, recv_func, comm) + + if type_1 not in correct_types_checked: + correct_types_checked.add(type_1) + yield get_correct_case(type_1, 1, send_func, recv_func, comm) + + # TODO mrecv? + # TODO sendrecv? -- GitLab