Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
Loading items

Target

Select target project
  • hpc-public/mpi-bugbench
1 result
Select Git revision
Loading items
Show changes

Commits on Source 24

#! /usr/bin/python3
from __future__ import annotations
from typing_extensions import override
from scripts.Infrastructure.Instruction import Instruction
......@@ -25,7 +27,8 @@ alloc_template = """
class AllocCall(Instruction):
@override
def __init__(self, type: str, num_elements: str, name: str = "buf", use_malloc: bool = False):
def __init__(self, type: str, num_elements: str, name: str = "buf", use_malloc: bool = False,
rank: str | int = 'all', identifier: str = None):
"""
Creates a New allocation Call
......@@ -35,7 +38,7 @@ class AllocCall(Instruction):
name: name of buffer variable
use_malloc: True: use Malloc, False: use calloc for allocation
"""
super().__init__("")
super().__init__("", rank,identifier)
self._use_malloc = use_malloc
self._type = type
self._num_elements = num_elements
......
......@@ -2,7 +2,6 @@
from scripts.Infrastructure import MPICall
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.Template import InstructionBlock
from scripts.Infrastructure.AllocCall import AllocCall, get_free
......@@ -29,7 +28,8 @@ class CorrectParameterFactory:
return self.buf_var_name
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
return str(self.buf_size)
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype", "result_datatype"]:
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
"result_datatype"]:
return self.dtype[1]
if param in ["DEST", "dest", "target_rank"]:
return "0"
......@@ -44,7 +44,7 @@ class CorrectParameterFactory:
if param in ["COMM", "comm"]:
return "MPI_COMM_WORLD"
if param in ["newcomm", "newintercomm"]:
return "newcomm"
return "mpi_comm_0"
if param in ["STATUS", "status"]:
return "MPI_STATUS_IGNORE"
if param in ["OPERATION", "op"]:
......@@ -56,17 +56,21 @@ class CorrectParameterFactory:
if param in ["PARTITION", "partition"]:
return "0"
if param in ["REQUEST", "request"]:
return "&request"
return "&mpi_request_0"
if param in ["GROUP", "group"]:
return "&group"
return "&mpi_group_0"
if param in ["color"]:
return "1"
if param in ["message"]:
return "&mpi_message_0"
if param in ["flag"]:
return "&int_0"
if param in ["split_type"]:
return "MPI_COMM_TYPE_SHARED"
if param in ["key"]:
return "rank"
if param in ["errhandler"]:
return "MPI_ERRORS_ARE_FATAL" #
return "MPI_ERRORS_ARE_FATAL"
if param in ["local_comm"]:
return "MPI_COMM_SELF"
if param in ["local_leader"]:
......@@ -78,7 +82,7 @@ class CorrectParameterFactory:
if param in ["target_disp"]:
return "0"
if param in ["win"]:
return "win"
return "mpi_win_0"
if param in ["baseptr"]:
return "&" + self.winbuf_var_name
if param in ["base"]:
......@@ -96,6 +100,19 @@ class CorrectParameterFactory:
print("Not Implemented: " + param)
assert False, "Param not known"
def get_initializer(self, variable_type: str) -> str:
if variable_type == "int":
return "0"
if variable_type == "MPI_Request":
return "MPI_REQUEST_NULL"
if variable_type == "MPI_Comm":
return "MPI_COMM_NULL"
if variable_type == "MPI_Message":
return "MPI_MESSAGE_NULL"
# TODO implement other types
print("Not Implemented: " + variable_type)
assert False, "Param not known"
# todo also for send and non default args
def get_matching_recv(call: MPICall) -> MPICall:
......
......@@ -69,7 +69,8 @@ class GeneratorManager:
return case_name + "-" + str(num).zfill(digits_to_use) + suffix
def generate(self, outpath: str | Path | os.PathLike[str], filterlist_:typing.Sequence[str]=None, print_progress_bar:bool=True, overwrite:bool=True, generate_full_set:bool=False,
def generate(self, outpath: str | Path | os.PathLike[str], filterlist_: typing.Sequence[str] = None,
print_progress_bar: bool = True, overwrite: bool = True, generate_full_set: bool = False,
try_compile: bool = False, max_mpi_version: str = "4.0", use_clang_format: bool = True):
"""
Generates test cases based on the specified parameters.
......@@ -115,6 +116,9 @@ class GeneratorManager:
cases_generated = 0
for generator in generators_to_use:
# use first feature as category if generator has multiple
category_path = os.path.join(outpath, generator.get_feature()[0])
os.makedirs(category_path, exist_ok=True)
for result_error in generator.generate(generate_full_set):
assert isinstance(result_error, TemplateManager)
......@@ -122,7 +126,7 @@ class GeneratorManager:
if not float(result_error.get_version()) > float(max_mpi_version):
case_name = result_error.get_short_descr()
fname = self.get_filename(case_name)
full_name = os.path.join(outpath, fname)
full_name = os.path.join(category_path, fname)
if not overwrite and os.path.isfile(full_name):
assert False and "File Already exists"
......
#! /usr/bin/python3
from __future__ import annotations
from scripts.Infrastructure.Variables import ERROR_MARKER_COMMENT_BEGIN, ERROR_MARKER_COMMENT_END
class Instruction(object):
"""
Base class to represent an Instruction
the identifier is used, in order to reference that instruction in the Template Manager (e.g. to change it). can be None
"""
def __init__(self, str_representation):
def __init__(self, str_representation: str, rank: str | int = 'all', identifier: str = None):
self._str_representation = str_representation
self._has_error = False
self._identifier = identifier
if isinstance(rank, str):
assert rank in ['all', 'not0']
self._rank = rank
def set_has_error(self, has_error: bool = True):
self._has_error = has_error
def get_identifier(self) -> str:
return self._identifier
def set_identifier(self, identifier: str):
self._identifier = identifier
def get_rank_executing(self) -> str | int:
return self._rank
def set_rank_executing(self, rank: str | int):
if isinstance(rank, str):
assert rank in ['all', 'not0']
self._rank = rank
def __str__(self):
if self._has_error:
return ERROR_MARKER_COMMENT_BEGIN + self._str_representation + ERROR_MARKER_COMMENT_END
......
from __future__ import annotations
import typing
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICall import MPICall
class InstructionBlock:
"""
Class Overview:
The `InstructionBlock` class represents a block of instructions in a Testcase (to be registered for a template).
First, the Instructions for all ranks are executed (in the order they are registered)
Then each thread executes the instructions registered to this specific rank
If one need a different Order: use multiple Instruction Blocks
Methods:
- `__init__(self)`: Initializes a new instance of the InstructionBlock class.
- `register_operation(self, op, kind='all')`: Registers an operation based on rank.
- `get_version(self)`: Retrieves required MPI version
- `__str__(self)`: Converts the InstructionBlock instance to a string, replacing placeholders.
"""
def __init__(self, name: str = None):
"""
Initialize an empty InstructionBlock
Parameters:
- name (str): The name of the block (for referencing this block with the template Manager)
May be None, does not influence the code generated
"""
self.operations = {'all': [], 'not0': [], }
assert not isinstance(name, int)
self.name = name
def register_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all'):
"""
Registers an operation based on rank.
Parameters:
- op: The operation (or list of Operations) to register.
- kind: Rank to execute the operation ('all', 'not0', or integer).
- all: all Ranks execute this operation
- not0: all Ranks but the Root (rank 0) execute
- Or the integer of the rank that should execute
Note: if a str is passed as the operation, it will create a new Instruction from the given string
"""
if isinstance(op, str):
op = Instruction(op)
if kind == 'all':
if isinstance(op, list):
self.operations['all'].extend(op)
else:
self.operations['all'].append(op)
elif kind == 'not0':
if isinstance(op, list):
self.operations['not0'].extend(op)
else:
self.operations['not0'].append(op)
else:
as_int = int(kind) # will Raise ValueError if not integer
if as_int not in self.operations:
self.operations[as_int] = []
if isinstance(op, list):
self.operations[as_int].extend(op)
else:
self.operations[as_int].append(op)
def get_version(self) -> str:
"""
Retrieves the minimum required MPI version.
Returns:
str: The MPI version used.
"""
max_v = "0.0"
for k, v in self.operations.items():
for op in v:
if isinstance(op, MPICall):
max_v = max(op.get_version(), max_v)
return max_v
def __str__(self):
"""
Converts the InstructionBlock instance to a string, replacing placeholders.
Returns:
str: The string representation of the InstructionBlock.
"""
result_str = ""
for key, val in self.operations.items():
if key == 'all':
for op in val:
result_str += str(op) + "\n"
elif key == 'not0':
if len(val) > 0:
result_str += "if (rank != 0) {\n"
for op in val:
result_str += str(op) + "\n"
result_str += "}\n"
else:
if len(val) > 0:
result_str += "if (rank == %d) {\n" % int(key)
for op in val:
result_str += str(op) + "\n"
result_str += "}\n"
return result_str
def has_instruction(self, kind: int | str = 'all', index: int = 0) -> bool:
"""
Checks if the Block has an operation with the given index and kind
Parameters:
- kind ('all','not0' or integer): which ranks should execute the operation
- index (int ): the index of the operation within the given kind
Returns:
boolean
"""
try:
result = self.operations[kind][index]
return True
except (KeyError, IndexError) as e:
return False
def get_instruction(self, kind: int | str = 'all', index: str | int = 0) -> Instruction | typing.List[Instruction]:
"""
Retrieve the operation registered. will Raise IndexError if not present
Parameters:
- kind ('all','not0' or integer): which ranks should execute the operation
- index ('all' or int): the index of the operation within the given kind; 'all' means that the list of all operations for the kind is returned
Returns:
str: The operation specified by kind and index
"""
if index == 'all':
if kind not in self.operations:
return []
return self.operations[kind]
else:
as_int = int(index) # will Raise ValueError if not integer
return self.operations[kind][as_int]
def replace_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
index: str | int = 0):
"""
Replace the operation registered. will Raise IndexError if not present
Parameters:
- op the new operation or list of operations
- kind ('all','not0' or integer): which ranks should execute the operation
- index ('all' or int): the index of the operation within the given kind; 'all' means all operations will be replaced with the given list
Notes : if one wants to replace all operations one needs to provide a list
if one only wants to replace one operation: no list of operations is allowed
if a string is passed as the operation, it will create a new Instruction
"""
if isinstance(op, str):
op = Instruction(op)
if index == 'all':
if not isinstance(op, list):
raise ValueError('Provide List for replacement')
self.operations[kind] = op
else:
as_int = int(index) # will Raise ValueError if not integer
if not isinstance(op, Instruction):
raise ValueError('Provide Instruction')
if len(self.operations[kind]) < as_int:
raise IndexError("Operation Not Found")
self.operations[kind][as_int] = op
def insert_instruction(self, op: str | Instruction | typing.List[Instruction], kind: str | int = 'all',
before_index: int = 0):
"""
Inserts an operation before the specified one. will Raise IndexError if not present
Parameters:
- op the new operation or list of operations
- kind ('all','not0' or integer): which ranks should execute the operation
- index (int): the index of the operation within the given kind
note: if str is passed as the operation, it will Create a New Instruction
"""
if isinstance(op, str):
op = Instruction(op)
as_int = int(before_index) # will Raise ValueError if not integer
if len(self.operations[kind]) < before_index:
raise IndexError("Operation Not Found")
if isinstance(op, list):
self.operations[kind] = (
self.operations[kind][0:before_index - 1] + op + self.operations[kind][before_index:])
else:
self.operations[kind].insert(before_index, op)
def remove_instruction(self, kind: str | int = 'all', index: str | int = 0):
"""
Removes the operation registered. will Raise IndexError if not present
Parameters:
- kind ('all','not0' or integer): which ranks should execute the operation
- index ('all' or int): the index of the operation within the given kind
"""
if index == 'all':
self.operations[kind] = []
else:
as_int = int(index) # will Raise ValueError if not integer
if len(self.operations[kind]) < index:
raise IndexError("Operation Not Found")
del self.operations[kind][index]
#! /usr/bin/python3
from __future__ import annotations
import typing
from typing_extensions import override
......@@ -9,12 +11,12 @@ from scripts.Infrastructure.Variables import ERROR_MARKER_COMMENT_BEGIN, ERROR_M
class MPICall(Instruction):
@override
def __init__(self, function: str, args: typing.OrderedDict[str, str], version: str):
super().__init__("")
def __init__(self, function: str, args: typing.OrderedDict[str, str], version: str, rank: str | int = 'all',
identifier: str = None):
super().__init__("", rank, identifier)
self._function = function
self._args = args
self._version = version
self._has_error = False
@override
def __str__(self):
......@@ -33,8 +35,18 @@ class MPICall(Instruction):
assert self.has_arg(arg)
self._args[arg] = value
def get_arg(self, arg: str) -> str:
assert self.has_arg(arg)
return self._args[arg]
def has_arg(self, arg: str) -> bool:
return arg in self._args
def get_function(self) -> str:
return self._function.lower()
def get_version(self) -> str:
return self._version
def __copy__(self) -> MPICall:
return MPICall(self._function, self._args.copy(), self._version, self.get_rank_executing(), self.get_identifier())
This diff is collapsed.
......@@ -5,7 +5,7 @@ import typing
from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICall import MPICall
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory, MPICallFactory
from scripts.Infrastructure.Template import TemplateManager
......@@ -24,7 +24,7 @@ def get_default_template(mpi_func):
def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing.Tuple[str, str] = "mpi_irecv"):
"""
Contructs a default template for the given mpi send recv function pair
Contructs a default template for the given mpi send recv function pair it contains a send from rank 1 to rank 0
Returns:
TemplateManager Initialized with a default template
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
......@@ -58,48 +58,40 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing
tm = TemplateManager()
cf = CorrectParameterFactory()
alloc_block = InstructionBlock("alloc")
alloc_block.register_instruction(cf.get_buffer_alloc())
tm.register_instruction(cf.get_buffer_alloc(), identifier="ALLOC")
if send_func in sendrecv_funcs:
# spilt send and recv buf
alloc = cf.get_buffer_alloc()
alloc.set_identifier("ALLOC")
alloc.set_name("recv_buf")
alloc_block.register_instruction(alloc)
tm.register_instruction(alloc)
if recv_func in probe_pairs:
alloc_block.register_instruction("MPI_Message msg;")
tm.add_stack_variable("MPI_Message")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
buf_size = "sizeof(int)*10 + MPI_BSEND_OVERHEAD"
alloc_block.register_instruction(AllocCall("char", buf_size, "mpi_buf"))
alloc_block.register_instruction(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size))
tm.register_instruction(AllocCall("char", buf_size, "mpi_buf"), identifier="ALLOC")
tm.register_instruction(MPICallFactory().mpi_buffer_attach("mpi_buf", buf_size), identifier="ALLOC")
if (send_func in isend_funcs + persistent_send_funcs or
recv_func in persistent_recv_funcs + irecv_funcs + probe_pairs):
alloc_block.register_instruction("MPI_Request request;", 'all')
tm.add_stack_variable("MPI_Request")
if recv_func in probe_pairs:
alloc_block.register_instruction("int flag=0;")
tm.register_instruction_block(alloc_block)
flag_name = tm.add_stack_variable("int")
# end preperation of all local variables
# before the send/recv block
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
# sender needs to wait until recv is started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 1)
tm.register_instruction_block(b)
cmpicf = CorrectMPICallFactory()
tm.register_instruction(CorrectMPICallFactory.mpi_barrier(), identifier="SYNC", rank_to_execute=1)
# get the send and recv block
recv_to_use = recv_func
if recv_func in probe_pairs:
recv_to_use = recv_func[0]
s = cmpicf.get(send_func)
r = cmpicf.get(recv_to_use)
s = CorrectMPICallFactory.get(send_func)
r = CorrectMPICallFactory.get(recv_to_use)
if send_func in sendrecv_funcs:
# sending the second msg
......@@ -110,68 +102,75 @@ def get_send_recv_template(send_func: str = "mpi_isend", recv_func: str | typing
if r.has_arg("recvbuf"):
r.set_arg("recvbuf", "recv_buf")
b = InstructionBlock("MPICALL")
b.register_instruction(s, 1)
b.register_instruction(r, 0)
s.set_identifier("MPICALL")
s.set_rank_executing(1)
r.set_identifier("MPICALL")
r.set_rank_executing(0)
tm.register_instruction(s)
tm.register_instruction(r)
if recv_func in probe_pairs:
if recv_func in [["mpi_improbe", "mpi_mrecv"],
["mpi_improbe", "mpi_imrecv"]]:
b.insert_instruction("while (!flag){", 0, 0)
tm.insert_instruction(Instruction("while (!" + flag_name + "){", rank=0), before_instruction=r)
# insertion before the improbe call
b.register_instruction("}", 0) # end while
b.register_instruction(CorrectMPICallFactory().get(recv_func[1]), 0)
tm.register_instruction("}", rank_to_execute=0) # end while
# the matched recv
tm.register_instruction(CorrectMPICallFactory.get(recv_func[1]), rank_to_execute=0)
if send_func in persistent_send_funcs:
b.register_instruction(cmpicf.mpi_start(), 1)
tm.register_instruction(CorrectMPICallFactory.mpi_start(), rank_to_execute=1, identifier="START")
if send_func == "mpi_psend_init":
# the pready takes a Request NOt a request*
b.register_instruction(MPICallFactory().mpi_pready("0", cf.get("request")[1:]), 1)
tm.register_instruction(MPICallFactory.mpi_pready("0", cf.get("request")[1:]), rank_to_execute=1)
if recv_func in persistent_recv_funcs:
b.register_instruction(cmpicf.mpi_start(), 0) #
tm.register_instruction(CorrectMPICallFactory.mpi_start(), rank_to_execute=0, identifier="START")
# parrived is not necessary
tm.register_instruction_block(b)
# after send/recv
if send_func in ["mpi_rsend", "mpi_irsend", "mpi_rsend_init"]:
b = InstructionBlock("SYNC")
# barrier indicating recv has started
b.register_instruction(CorrectMPICallFactory().mpi_barrier(), 0)
tm.register_instruction_block(b)
tm.register_instruction(CorrectMPICallFactory.mpi_barrier(), rank_to_execute=0, identifier="SYNC")
# wait for op to complete
if send_func in isend_funcs + persistent_send_funcs:
b = InstructionBlock("WAIT")
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 1)
tm.register_instruction_block(b)
if send_func in isend_funcs + persistent_send_funcs:
tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute=1, identifier="WAIT")
if recv_func in irecv_funcs + persistent_recv_funcs + [["mpi_mprobe", "mpi_imrecv"],
["mpi_improbe", "mpi_imrecv"]]:
b = InstructionBlock("WAIT")
b.register_instruction(CorrectMPICallFactory().mpi_wait(), 0)
tm.register_instruction_block(b)
tm.register_instruction(CorrectMPICallFactory.mpi_wait(), rank_to_execute=0, identifier="WAIT")
# end MPI operation
# cleanup
free_block = InstructionBlock("buf_free")
if send_func in ["mpi_bsend", "mpi_ibsend", "mpi_bsend_init"]:
free_block = InstructionBlock("buf_detach")
free_block.register_instruction("int freed_size;")
free_block.register_instruction(MPICallFactory().mpi_buffer_detach("mpi_buf", "&freed_size"))
free_block.register_instruction("free(mpi_buf);")
tm.register_instruction("int freed_size;", identifier="FREE")
tm.register_instruction(MPICallFactory.mpi_buffer_detach("mpi_buf", "&freed_size"), identifier="FREE")
tm.register_instruction("free(mpi_buf);", identifier="FREE")
free_block.register_instruction(cf.get_buffer_free())
tm.register_instruction(cf.get_buffer_free(), identifier="FREE")
if send_func in sendrecv_funcs:
# spilt send and recv buf
free_block.register_instruction("free(recv_buf);")
tm.register_instruction("free(recv_buf);", identifier="FREE")
if send_func in persistent_send_funcs:
free_block.register_instruction(cmpicf.mpi_request_free())
tm.register_instruction(CorrectMPICallFactory.mpi_request_free(), identifier="FREE")
return tm
tm.register_instruction_block(free_block)
def get_invalid_param_p2p_case(param, value, check_receive, send_func, recv_func):
tm = get_send_recv_template(send_func, recv_func)
rank = 1
if check_receive:
rank = 0
for call in tm.get_instruction(identifier="MPICALL", return_list=True):
if call.get_rank_executing() == rank:
assert call.has_arg(param)
call.set_arg(param, value)
return tm
......@@ -281,7 +280,7 @@ def get_rma_call(rma_func, rank):
return b
def get_communicator(comm_create_func, name):
def get_communicator(comm_create_func, name, identifier="COMM"):
"""
:param comm_create_func: teh function used to create the new communicator
:param name: name of the communicator variable
......@@ -290,30 +289,36 @@ def get_communicator(comm_create_func, name):
assert comm_create_func in ["mpi_comm_dup", "mpi_comm_dup_with_info", "mpi_comm_idup",
"mpi_comm_idup_with_info", "mpi_comm_create", "mpi_comm_create_group", "mpi_comm_split",
"mpi_comm_split_type", "mpi_comm_create_from_group"]
b = InstructionBlock("comm_create")
b.register_instruction("MPI_Comm " + name + ";")
inst_list = []
inst_list.append(Instruction("MPI_Comm " + name + ";", identifier=identifier))
if comm_create_func.startswith("mpi_comm_i"):
b.register_instruction("MPI_Request comm_create_req;")
inst_list.append(Instruction("MPI_Request comm_create_req;", identifier=identifier))
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_instruction("MPI_Group group;")
b.register_instruction(CorrectMPICallFactory().mpi_comm_group())
cmpicf = CorrectMPICallFactory()
call_creator_function = getattr(cmpicf, comm_create_func)
call = call_creator_function()
inst_list.append(Instruction("MPI_Group group;", identifier=identifier))
group = CorrectMPICallFactory.mpi_comm_group()
group.set_identifier(identifier)
inst_list.append(group)
call = CorrectMPICallFactory.get(comm_create_func)
call.set_arg("newcomm", "&" + name)
if comm_create_func.startswith("mpi_comm_i"):
call.set_arg("request", "&comm_create_req")
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
call.set_arg("group", "group") # not &group
b.register_instruction(call)
call.set_identifier(identifier)
inst_list.append(call)
if comm_create_func.startswith("mpi_comm_i"):
b.register_instruction(MPICallFactory().mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE"))
wait = MPICallFactory.mpi_wait("&comm_create_req", "MPI_STATUS_IGNORE")
wait.set_identifier(identifier)
inst_list.append(wait)
if comm_create_func in ["mpi_comm_create", "mpi_comm_create_group"]:
b.register_instruction(cmpicf.mpi_group_free())
return b
group_free = CorrectMPICallFactory.mpi_group_free()
group_free.set_identifier(identifier)
inst_list.append(group_free)
return inst_list
def get_intercomm(comm_create_func, name):
def get_intercomm(comm_create_func, name, identifier="COMM"):
"""
:param comm_create_func: the function used to create the new communicator
:param name: name of the communicator variable
......@@ -324,54 +329,91 @@ def get_intercomm(comm_create_func, name):
"""
assert comm_create_func in ["mpi_intercomm_create", "mpi_intercomm_create_from_groups", "mpi_intercomm_merge"]
assert name != "intercomm_base_comm"
if comm_create_func == "mpi_intercomm_create":
b = InstructionBlock("comm_create")
b.register_instruction("MPI_Comm intercomm_base_comm;")
b.register_instruction(
MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&" + name))
b.register_instruction(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
inst_list = []
inst_list.append(Instruction("MPI_Comm intercomm_base_comm;", identifier=identifier))
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm")
call.set_identifier(identifier)
inst_list.append(call)
inst_list.append(Instruction("MPI_Comm " + name + ";", identifier=identifier))
call = MPICallFactory.mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&" + name)
call.set_identifier(identifier)
inst_list.append(call)
call = MPICallFactory.mpi_comm_free("&intercomm_base_comm")
call.set_identifier(identifier)
inst_list.append(call)
return inst_list
if comm_create_func == "mpi_intercomm_create_from_groups":
b = InstructionBlock("comm_create")
b.register_instruction("MPI_Group world_group,even_group,odd_group;")
b.register_instruction(MPICallFactory().mpi_comm_group("MPI_COMM_WORLD", "&world_group"))
b.register_instruction(
MPICallFactory().mpi_comm_group("intercomm_base_comm", "&intercomm_base_comm_group"))
b.register_instruction("int[3] triplet;")
b.register_instruction("triplet[0] =0;")
b.register_instruction("triplet[1] =size;")
b.register_instruction("triplet[2] =2;")
b.register_instruction(MPICallFactory().mpi_group_incl("world_group", "1", "&triplet", "even_group"))
b.register_instruction("triplet[0] =1;")
b.register_instruction(MPICallFactory().mpi_group_incl("world_group", "1", "&triplet", "odd_group"))
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create_from_groups("(rank % 2 ? even_group:odd_group)", "0",
inst_list = []
inst_list.append(Instruction("MPI_Group world_group,even_group,odd_group;", identifier=identifier))
call = MPICallFactory.mpi_comm_group("MPI_COMM_WORLD", "&world_group")
call.set_identifier(identifier)
inst_list.append(call)
call = MPICallFactory.mpi_comm_group("intercomm_base_comm", "&intercomm_base_comm_group")
call.set_identifier(identifier)
inst_list.append(call)
Instruction("int[3] triplet;"
"triplet[0] =0;"
"triplet[1] =size;"
"triplet[2] =2;", identifier=identifier)
call = MPICallFactory.mpi_group_incl("world_group", "1", "&triplet", "even_group")
call.set_identifier(identifier)
inst_list.append(call)
inst_list.append(Instruction("triplet[0] =1;", identifier=identifier))
call = MPICallFactory.mpi_group_incl("world_group", "1", "&triplet", "odd_group")
call.set_identifier(identifier)
inst_list.append(call)
inst_list.append(Instruction("MPI_Comm " + name + ";", identifier=identifier))
call = MPICallFactory.mpi_intercomm_create_from_groups("(rank % 2 ? even_group:odd_group)", "0",
"(!(rank % 2) ? even_group:odd_group)", "0",
CorrectParameterFactory().get("stringtag"),
CorrectParameterFactory().get("INFO"),
CorrectParameterFactory().get("errhandler"),
"&" + name), )
return b
"&" + name)
call.set_identifier(identifier)
inst_list.append(call)
return inst_list
if comm_create_func == "mpi_intercomm_merge":
b = InstructionBlock("comm_create")
b.register_instruction("MPI_Comm intercomm_base_comm;")
b.register_instruction("MPI_Comm to_merge_intercomm_comm;")
b.register_instruction(
MPICallFactory().mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm"))
b.register_instruction("MPI_Comm " + name + ";")
b.register_instruction(
MPICallFactory().mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&to_merge_intercomm_comm"))
b.register_instruction(MPICallFactory().mpi_intercomm_merge("to_merge_intercomm_comm", "rank %2", "&" + name))
b.register_instruction(MPICallFactory().mpi_comm_free("&to_merge_intercomm_comm"))
b.register_instruction(MPICallFactory().mpi_comm_free("&intercomm_base_comm"))
return b
inst_list = []
inst_list.append(Instruction("MPI_Comm intercomm_base_comm;", identifier=identifier))
inst_list.append(Instruction("MPI_Comm to_merge_intercomm_comm;", identifier=identifier))
call = MPICallFactory.mpi_comm_split("MPI_COMM_WORLD", "rank % 2", "rank", "&intercomm_base_comm")
call.set_identifier(identifier)
inst_list.append(call)
inst_list.append(Instruction("MPI_Comm " + name + ";", identifier=identifier))
call = MPICallFactory.mpi_intercomm_create("intercomm_base_comm", "0", "MPI_COMM_WORLD", "!(rank %2)",
CorrectParameterFactory().get("tag"), "&to_merge_intercomm_comm")
call.set_identifier(identifier)
inst_list.append(call)
call = MPICallFactory.mpi_intercomm_merge("to_merge_intercomm_comm", "rank %2", "&" + name)
call.set_identifier(identifier)
inst_list.append(call)
call = MPICallFactory.mpi_comm_free("&to_merge_intercomm_comm")
call.set_identifier(identifier)
inst_list.append(call)
call = MPICallFactory.mpi_comm_free("&intercomm_base_comm")
call.set_identifier(identifier)
inst_list.append(call)
return inst_list
return None
# todo also for send and non default args
def get_matching_recv(call: MPICall) -> MPICall:
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
correct_params.get("COUNT"),
correct_params.get("DATATYPE"),
correct_params.get("SRC"),
correct_params.get("TAG"),
correct_params.get("COMM"),
correct_params.get("STATUS", "MPI_Recv"),
)
return recv
......@@ -8,12 +8,15 @@ if __name__ == "__main__":
#gm = GeneratorManager("./errors")
gm = GeneratorManager("./errors/devel")
# remove all testcases from previous execution (ease of debugging)
filelist = [f for f in os.listdir(gencodes_dir) if f.endswith(".c")]
for f in filelist:
os.remove(os.path.join(gencodes_dir, f))
for root, dirs, files in os.walk(gencodes_dir):
for file in files:
if file.endswith(".c"):
os.remove(os.path.join(root, file))
# gm.generate(gencodes_dir, try_compile=True, generate_full_set=False) # default
gm.generate(gencodes_dir, try_compile=True, generate_full_set=True, max_mpi_version="3.1") #all cases that can compile for my local mpi installation
gm.generate(gencodes_dir, try_compile=True, generate_full_set=True,
max_mpi_version="3.1") # all cases that can compile for my local mpi installation
pass