Skip to content
Snippets Groups Projects
Commit 6d7f543a authored by Jammer, Tim's avatar Jammer, Tim
Browse files

fix correct cases allocating 2 buffers

parent 0e1c1acf
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,8 @@ from scripts.Infrastructure.Variables import *
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.TemplateFactory import get_collective_template, predefined_types, user_defined_types, \
predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm
predefined_mpi_dtype_consants, get_type_buffers, get_bytes_size_for_type, get_communicator, get_intercomm, \
get_buffer_for_type, get_buffer_for_usertype
# TODO refactor into different file
# test if the tool chan deal with messages send over different communicators
......@@ -105,7 +106,12 @@ def get_correct_case(type_1, count_1, func_to_use, comm):
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, count_1, count_1)
if type_1 in predefined_types:
buf_name = get_buffer_for_type(type_1, count_1)
type_var = type_1
else:
buf_name, type_var = get_buffer_for_usertype(type_1, tm, tm._instructions[0], count_1)
tm.set_description("Correct-" + func_to_use, "")
......@@ -113,15 +119,15 @@ def get_correct_case(type_1, count_1, func_to_use, comm):
call.set_rank_executing(0)
call.set_arg("comm", comm_var_name)
if call.has_arg("recvbuf"):
call.set_arg("recvbuf", buf_name_1)
call.set_arg("sendbuf", buf_name_1)
call.set_arg("recvbuf", buf_name)
call.set_arg("sendbuf", buf_name)
else:
call.set_arg("buffer", buf_name_1)
call.set_arg("buffer", buf_name)
if call.has_arg("recvtype"):
call.set_arg("recvtype", type_var_1)
call.set_arg("sendtype", type_var_1)
call.set_arg("recvtype", type_var)
call.set_arg("sendtype", type_var)
else:
call.set_arg("datatype", type_var_1)
call.set_arg("datatype", type_var)
if call.has_arg("recvcount"):
call.set_arg("recvcount", count_1)
call.set_arg("sendcount", count_1)
......@@ -132,15 +138,15 @@ def get_correct_case(type_1, count_1, func_to_use, comm):
c.set_rank_executing('not0')
c.set_arg("comm", comm_var_name)
if c.has_arg("recvbuf"):
c.set_arg("recvbuf", buf_name_2)
c.set_arg("sendbuf", buf_name_2)
c.set_arg("recvbuf", buf_name)
c.set_arg("sendbuf", buf_name)
else:
c.set_arg("buffer", buf_name_2)
c.set_arg("buffer", buf_name)
if c.has_arg("recvtype"):
c.set_arg("recvtype", type_var_2)
c.set_arg("sendtype", type_var_2)
c.set_arg("recvtype", type_var)
c.set_arg("sendtype", type_var)
else:
c.set_arg("datatype", type_var_2)
c.set_arg("datatype", type_var)
if c.has_arg("recvcount"):
c.set_arg("recvcount", count_1)
c.set_arg("sendcount", count_1)
......
......@@ -11,7 +11,7 @@ from scripts.Infrastructure.ScoingModule.ScoringTable import is_combination_impo
from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template, get_invalid_param_p2p_case, get_communicator, \
get_intercomm, predefined_types, user_defined_types, predefined_mpi_dtype_consants, get_type_buffers, \
get_bytes_size_for_type
get_bytes_size_for_type, get_buffer_for_type, get_buffer_for_usertype
from itertools import chain
......@@ -52,15 +52,21 @@ def get_correct_case(type_1, size_1, send_func, recv_func, comm):
if comm in intercomms:
comm_var_name = get_intercomm(comm, tm)
type_var_1, buf_name_1, type_var_2, buf_name_2 = get_type_buffers(tm, type_1, type_1, size_1, size_1)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=1)
call.set_arg("buf", buf_name_1)
call.set_arg("datatype", type_var_1)
if type_1 in predefined_types:
buf_name = get_buffer_for_type(type_1, size_1)
type_var = type_1
else:
buf_name, type_var = get_buffer_for_usertype(type_1, tm, tm._instructions[0], size_1)
call.set_arg("buf", buf_name)
call.set_arg("datatype", type_var)
call.set_arg("count", size_1)
call.set_arg("comm", comm_var_name)
call = tm.get_instruction(identifier="MPICALL", rank_excuting=0)
call.set_arg("buf", buf_name_2)
call.set_arg("datatype", type_var_2)
call.set_arg("buf", buf_name)
call.set_arg("datatype", type_var)
call.set_arg("count", size_1)
call.set_arg("comm", comm_var_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment