Skip to content
Snippets Groups Projects
Commit eb977b6b authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Refactoring: extract method

parent 8c67cc97
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,58 @@ from itertools import chain ...@@ -16,6 +16,58 @@ from itertools import chain
from scripts.Infrastructure.Variables import * from scripts.Infrastructure.Variables import *
# TODO refactoring into different file
# test if the tool chan deal with messages send over different communicators
predefined_comms = ["MPI_COMM_WORLD"]
comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
"mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
"mpi_comm_split_type", "mpi_comm_create_from_group"
]
intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]
def get_local_missmatch(type_1, type_2, send_func, recv_func):
tm = get_send_recv_template(send_func, recv_func)
tm.set_description("LocalParameterMissmatch-Dtype-" + send_func,
"datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10)
# local missmatch
for call in tm.get_instruction(identifier="MPICALL", return_list=True):
call.set_has_error()
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_2)
return tm
def get_global_missmatch(type_1, type_2, size_1, size_2, send_func, recv_func, comm):
tm = get_send_recv_template(send_func, recv_func)
tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
"datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
comm_var_name = "MPI_COMM_WORLD"
if comm in comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, size_1, size_2)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_has_error()
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_1)
call.set_arg("count", size_1)
call.set_arg("comm", comm_var_name)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
call.set_has_error()
call.set_arg("buf", buf_name_2)
call.set_arg("datatype", type_var_2)
call.set_arg("count", size_2)
call.set_arg("comm", comm_var_name)
return tm
class DtypeMissmatch(ErrorGenerator): class DtypeMissmatch(ErrorGenerator):
invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"] invalid_bufs = [CorrectParameterFactory().buf_var_name, "NULL"]
send_funcs = ["mpi_send", send_funcs = ["mpi_send",
...@@ -27,14 +79,6 @@ class DtypeMissmatch(ErrorGenerator): ...@@ -27,14 +79,6 @@ class DtypeMissmatch(ErrorGenerator):
sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"] sendrecv_funcs = ["mpi_sendrecv", "mpi_sendrecv_replace"]
# test if the tool chan deal with messages send over different communicators
predefined_comms = ["MPI_COMM_WORLD"]
comm_creators = ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
"mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
"mpi_comm_split_type", "mpi_comm_create_from_group"
]
intercomms = ["mpi_intercomm_create", "mpi_intercomm_merge", "mpi_intercomm_create_from_groups"]
def __init__(self): def __init__(self):
pass pass
...@@ -79,20 +123,10 @@ class DtypeMissmatch(ErrorGenerator): ...@@ -79,20 +123,10 @@ class DtypeMissmatch(ErrorGenerator):
checked_types.add(type_1) checked_types.add(type_1)
checked_types.add(type_2) checked_types.add(type_2)
tm = get_send_recv_template(send_func, recv_func) tm = get_local_missmatch(type_1, type_2, send_func, recv_func)
tm.set_description("LocalParameterMissmatch-Dtype-" + send_func,
"datatype missmatch: Buffer: " + type_1 + " MPI_Call: " + type_2)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 10, 10)
# local missmatch
for call in tm.get_instruction(identifier="MPICALL", return_list=True):
call.set_has_error()
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_2)
yield tm yield tm
for comm in self.predefined_comms + self.comm_creators + self.intercomms: for comm in predefined_comms + comm_creators + intercomms:
if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL: if comm != "MPI_COMM_WORLD" and generate_level < REAL_WORLD_TEST_LEVEL:
continue continue
if generate_level == REAL_WORLD_TEST_LEVEL: if generate_level == REAL_WORLD_TEST_LEVEL:
...@@ -103,61 +137,13 @@ class DtypeMissmatch(ErrorGenerator): ...@@ -103,61 +137,13 @@ class DtypeMissmatch(ErrorGenerator):
recv_func, datatype=type_2.lower(), communicator=comm)): recv_func, datatype=type_2.lower(), communicator=comm)):
# not relevant in real world # not relevant in real world
continue continue
tm = get_send_recv_template(send_func, recv_func) tm = get_global_missmatch(type_1, type_2, 1, 1, send_func, recv_func, comm)
tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
"datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, 1, 1)
# global missmatch with size 1
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_has_error()
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_1)
call.set_arg("count", 1)
call.set_arg("comm", comm_var_name)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
call.set_has_error()
call.set_arg("buf", buf_name_2)
call.set_arg("datatype", type_var_2)
call.set_arg("count", 1)
call.set_arg("comm", comm_var_name)
yield tm yield tm
tm = get_send_recv_template(send_func, recv_func)
tm.set_description("GlobalParameterMissmatch-Dtype-" + send_func,
"datatype missmatch: Rank0: " + type_1 + " Rank1: " + type_2)
comm_var_name = "MPI_COMM_WORLD"
if comm in self.comm_creators:
comm_var_name = get_communicator(comm, tm)
if comm in self.intercomms:
comm_var_name = get_intercomm(comm, tm)
# global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types # global missmatch with size = sizeof(a)* sizeof(b) so that total size match both types
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_2, tm = get_global_missmatch(type_1, type_2, get_bytes_size_for_type(type_2),
get_bytes_size_for_type( get_bytes_size_for_type(type_1),
type_2), send_func, recv_func, comm)
get_bytes_size_for_type(
type_1))
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_has_error()
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_1)
call.set_arg("count", get_bytes_size_for_type(type_2))
call.set_arg("comm", comm_var_name)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
call.set_has_error()
call.set_arg("buf", buf_name_2)
call.set_arg("datatype", type_var_2)
call.set_arg("count", get_bytes_size_for_type(type_1))
call.set_arg("comm", comm_var_name)
yield tm yield tm
if generate_level <= BASIC_TEST_LEVEL: if generate_level <= BASIC_TEST_LEVEL:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment