From eb977b6bf87d09aa2f649e313072089117339aa0 Mon Sep 17 00:00:00 2001 From: Tim Jammer <tim.jammer@tu-darmstadt.de> Date: Mon, 29 Apr 2024 12:36:01 +0200 Subject: [PATCH] Refactoring: extract method --- scripts/errors/dtypes/DtypeMissmatch.py | 130 +++++++++++------------- 1 file changed, 58 insertions(+), 72 deletions(-) diff --git a/scripts/errors/dtypes/DtypeMissmatch.py b/scripts/errors/dtypes/DtypeMissmatch.py index ae3e86e43..a763721f9 100644 --- a/scripts/errors/dtypes/DtypeMissmatch.py +++ b/scripts/errors/dtypes/DtypeMissmatch.py @@ -16,6 +16,58 @@ from itertools import chain from scripts.Infrastructure.Variables import * +# TODO refactoring into different file +# test if the tool chan deal with messages send over different communicators +predefined_comms = ["MPI_COMM_WORLD"] +comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup", + "mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split", + "mpi_comm_split_type", "mpi_comm_create_from_group" + ] +intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"] + + +def get_local_missmatch(type_1, type_2, send_func, recv_func): + tm = get_send_recv_template(send_func, recv_func) + tm.set_description("LocalParameterMissmatch-Dtype-" + send_func, + "datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2) + type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10) + # local missmatch + for call in tm.get_instruction(identifier="MPICALL", return_list=True): + call.set_has_error() + call.set_arg("buf", buf_name_1) + call.set_arg("datatype", type_var_2) + + return tm + + +def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm): + tm = get_send_recv_template(send_func, recv_func) + tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func, + "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2) + comm_var_name = "MPI_COMM_WORLD" + if comm in comm_creators: + comm_var_name = get_communicator(comm, tm) + + if comm in intercomms: + comm_var_name = get_intercomm(comm, tm) + + type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, size_1, size_2) + call = tm.get_instruction(identifier="MPICALL", rank_excuting=0) + call.set_has_error() + call.set_arg("buf", buf_name_1) + call.set_arg("datatype", type_var_1) + call.set_arg("count", size_1) + call.set_arg("comm", comm_var_name) + call = tm.get_instruction(identifier="MPICALL", rank_excuting=1) + call.set_has_error() + call.set_arg("buf", buf_name_2) + call.set_arg("datatype", type_var_2) + call.set_arg("count", size_2) + call.set_arg("comm", comm_var_name) + + return tm + + class DtypeMissmatch(ErrorGenerator): invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"] send_funcs = ["mpi_send", @@ -27,14 +79,6 @@ class DtypeMissmatch(ErrorGenerator): sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"] - # test if the tool chan deal with messages send over different communicators - predefined_comms = ["MPI_COMM_WORLD"] - comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup", - "mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split", - "mpi_comm_split_type", "mpi_comm_create_from_group" - ] - intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"] - def __init__(self): pass @@ -79,20 +123,10 @@ class DtypeMissmatch(ErrorGenerator): checked_types.add(type_1) checked_types.add(type_2) - tm = get_send_recv_template(send_func, recv_func) - tm.set_description("LocalParameterMissmatch-Dtype-" + send_func, - "datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2) - type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10) - - # local missmatch - for call in tm.get_instruction(identifier="MPICALL", return_list=True): - call.set_has_error() - call.set_arg("buf", buf_name_1) - call.set_arg("datatype", type_var_2) - + tm = get_local_missmatch(type_1, type_2, send_func, recv_func) yield tm - for comm in self.predefined_comms + self.comm_creators + self.intercomms: + for comm in predefined_comms + comm_creators + intercomms: if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL: continue if generate_level == REAL_WORLD_TEST_LEVEL: @@ -103,61 +137,13 @@ class DtypeMissmatch(ErrorGenerator): recv_func, datatype=type_2.lower(), communicator=comm)): # not relevant in real world continue - tm = get_send_recv_template(send_func, recv_func) - tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func, - "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2) - comm_var_name = "MPI_COMM_WORLD" - if comm in self.comm_creators: - comm_var_name = get_communicator(comm, tm) - - if comm in self.intercomms: - comm_var_name = get_intercomm(comm, tm) - - type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1) - - # global missmatch with size 1 - call = tm.get_instruction(identifier="MPICALL", rank_excuting=0) - call.set_has_error() - call.set_arg("buf", buf_name_1) - call.set_arg("datatype", type_var_1) - call.set_arg("count", 1) - call.set_arg("comm", comm_var_name) - call = tm.get_instruction(identifier="MPICALL", rank_excuting=1) - call.set_has_error() - call.set_arg("buf", buf_name_2) - call.set_arg("datatype", type_var_2) - call.set_arg("count", 1) - call.set_arg("comm", comm_var_name) + tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm) yield tm - - tm = get_send_recv_template(send_func, recv_func) - tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func, - "datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2) - comm_var_name = "MPI_COMM_WORLD" - if comm in self.comm_creators: - comm_var_name = get_communicator(comm, tm) - - if comm in self.intercomms: - comm_var_name = get_intercomm(comm, tm) # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types - type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, - get_bytes_size_for_type( - type_2), - get_bytes_size_for_type( - type_1)) - call = tm.get_instruction(identifier="MPICALL", rank_excuting=0) - call.set_has_error() - call.set_arg("buf", buf_name_1) - call.set_arg("datatype", type_var_1) - call.set_arg("count", get_bytes_size_for_type(type_2)) - call.set_arg("comm", comm_var_name) - call = tm.get_instruction(identifier="MPICALL", rank_excuting=1) - call.set_has_error() - call.set_arg("buf", buf_name_2) - call.set_arg("datatype", type_var_2) - call.set_arg("count", get_bytes_size_for_type(type_1)) - call.set_arg("comm", comm_var_name) + tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2), + get_bytes_size_for_type(type_1), + send_func, recv_func, comm) yield tm if generate_level <= BASIC_TEST_LEVEL: -- GitLab