diff --git a/scripts/errors/coll/ParamMatchingType.py b/scripts/errors/coll/ParamMatchingType.py index e03e18167731ac6012b005752e671a105b0a51f3..a415a42bcdb9621f59677e67de3c4200ff4009a5 100644 --- a/scripts/errors/coll/ParamMatchingType.py +++ b/scripts/errors/coll/ParamMatchingType.py @@ -5,7 +5,8 @@ from scripts.Infrastructure.Variables import * from scripts.Infrastructure.ErrorGenerator import ErrorGenerator from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \ - predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm + predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm, \ + get_buffer_for_type, get_buffer_for_usertype # TODO refactor into different file # test if the tool chan deal with messages send over different communicators @@ -105,7 +106,12 @@ def get_correct_case(type_1, count_1, func_to_use, comm): if comm in intercomms: comm_var_name = get_intercomm(comm, tm) - type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, count_1, count_1) + + if type_1 in predefined_types: + buf_name = get_buffer_for_type(type_1, count_1) + type_var = type_1 + else: + buf_name, type_var = get_buffer_for_usertype(type_1, tm, tm._instructions[0], count_1) tm.set_description("Correct-" + func_to_use, "") @@ -113,15 +119,15 @@ def get_correct_case(type_1, count_1, func_to_use, comm): call.set_rank_executing(0) call.set_arg("comm", comm_var_name) if call.has_arg("recvbuf"): - call.set_arg("recvbuf", buf_name_1) - call.set_arg("sendbuf", buf_name_1) + call.set_arg("recvbuf", buf_name) + call.set_arg("sendbuf", buf_name) else: - call.set_arg("buffer", buf_name_1) + call.set_arg("buffer", buf_name) if call.has_arg("recvtype"): - call.set_arg("recvtype", type_var_1) - call.set_arg("sendtype", type_var_1) + call.set_arg("recvtype", type_var) + call.set_arg("sendtype", type_var) else: - call.set_arg("datatype", type_var_1) + call.set_arg("datatype", type_var) if call.has_arg("recvcount"): call.set_arg("recvcount", count_1) call.set_arg("sendcount", count_1) @@ -132,15 +138,15 @@ def get_correct_case(type_1, count_1, func_to_use, comm): c.set_rank_executing('not0') c.set_arg("comm", comm_var_name) if c.has_arg("recvbuf"): - c.set_arg("recvbuf", buf_name_2) - c.set_arg("sendbuf", buf_name_2) + c.set_arg("recvbuf", buf_name) + c.set_arg("sendbuf", buf_name) else: - c.set_arg("buffer", buf_name_2) + c.set_arg("buffer", buf_name) if c.has_arg("recvtype"): - c.set_arg("recvtype", type_var_2) - c.set_arg("sendtype", type_var_2) + c.set_arg("recvtype", type_var) + c.set_arg("sendtype", type_var) else: - c.set_arg("datatype", type_var_2) + c.set_arg("datatype", type_var) if c.has_arg("recvcount"): c.set_arg("recvcount", count_1) c.set_arg("sendcount", count_1) diff --git a/scripts/errors/dtypes/DtypeMissmatch.py b/scripts/errors/dtypes/DtypeMissmatch.py index 9adfbb19e9d70ca30cf0e9de164cb426d43e44bc..febd745830129fe417d6d65653247bd8db9e765f 100644 --- a/scripts/errors/dtypes/DtypeMissmatch.py +++ b/scripts/errors/dtypes/DtypeMissmatch.py @@ -11,7 +11,7 @@ from scripts.Infrastructure.ScoingModule.ScoringTable import is_combination_impo from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_invalid_param_p2p_case, get_communicator, \ get_intercomm, predefined_types, user_defined_types, predefined_mpi_dtype_consants, get_type_buffers, \ - get_bytes_size_for_type + get_bytes_size_for_type, get_buffer_for_type, get_buffer_for_usertype from itertools import chain @@ -52,15 +52,21 @@ def get_correct_case(type_1, size_1, send_func, recv_func, comm): if comm in intercomms: comm_var_name = get_intercomm(comm, tm) - type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, size_1, size_1) call = tm.get_instruction(identifier="MPICALL", rank_excuting=1) - call.set_arg("buf", buf_name_1) - call.set_arg("datatype", type_var_1) + + if type_1 in predefined_types: + buf_name = get_buffer_for_type(type_1, size_1) + type_var = type_1 + else: + buf_name, type_var = get_buffer_for_usertype(type_1, tm, tm._instructions[0], size_1) + + call.set_arg("buf", buf_name) + call.set_arg("datatype", type_var) call.set_arg("count", size_1) call.set_arg("comm", comm_var_name) call = tm.get_instruction(identifier="MPICALL", rank_excuting=0) - call.set_arg("buf", buf_name_2) - call.set_arg("datatype", type_var_2) + call.set_arg("buf", buf_name) + call.set_arg("datatype", type_var) call.set_arg("count", size_1) call.set_arg("comm", comm_var_name)