Skip to content
Snippets Groups Projects
Commit e4e8bacf authored by Jammer, Tim's avatar Jammer, Tim
Browse files

Used the functionality to replace Parameters in InvalidRankError

parent ffb18c09
No related branches found
No related tags found
1 merge request!3more work on infrastructure II
...@@ -15,7 +15,7 @@ class CorrectParameterFactory: ...@@ -15,7 +15,7 @@ class CorrectParameterFactory:
pass pass
def get_buffer_alloc(self): def get_buffer_alloc(self):
b = InstructionBlock() b = InstructionBlock("alloc")
b.register_operation(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False), kind='all') b.register_operation(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False), kind='all')
return b return b
......
#! /usr/bin/python3 #! /usr/bin/python3
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory
from scripts.Infrastructure.InstructionBlock import InstructionBlock from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICall import MPI_Call from scripts.Infrastructure.MPICall import MPI_Call
from scripts.Infrastructure.MPICallFactory import CorrectMPICallFactory
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
def get_default_template(mpi_func):
""" """
Contructs a default template for the given mpi function Contructs a default template for the given mpi function
Returns: Returns:
...@@ -10,7 +14,26 @@ Returns: ...@@ -10,7 +14,26 @@ Returns:
The function is contained in a block named MPICALL with seperate calls for rank 1 and 2) The function is contained in a block named MPICALL with seperate calls for rank 1 and 2)
""" """
def get_default_template(mpi_func): pass
def get_send_recv_template(send_func, recv_func):
tm = TemplateManager() tm = TemplateManager()
cf = Co cf = CorrectParameterFactory()
tm.register_instruction_block(cf.get_buffer_alloc())
cmpicf = CorrectMPICallFactory()
send_func_creator_function = getattr(cmpicf, send_func)
s = send_func_creator_function()
recv_func_creator_function = getattr(cmpicf, recv_func)
r = recv_func_creator_function()
b = InstructionBlock("MPICALL")
b.register_operation(s, 1)
b.register_operation(r, 0)
tm.register_instruction_block(b)
return tm
...@@ -5,10 +5,11 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock ...@@ -5,10 +5,11 @@ from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv from scripts.Infrastructure.CorrectParameter import CorrectParameterFactory, get_matching_recv
from scripts.Infrastructure.Template import TemplateManager from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_send_recv_template
class Invalid_negative_rank_error(ErrorGenerator): class InvalidRankError(ErrorGenerator):
invalid_ranks = ["-1", "size", "NULL", "MPI_PROC_NULL"] invalid_ranks = ["-1", "size", "MPI_PROC_NULL"]
def __init__(self): def __init__(self):
pass pass
...@@ -24,30 +25,12 @@ class Invalid_negative_rank_error(ErrorGenerator): ...@@ -24,30 +25,12 @@ class Invalid_negative_rank_error(ErrorGenerator):
return ["P2P"] return ["P2P"]
def generate(self, i): def generate(self, i):
tm = TemplateManager() rank_to_use = self.invalid_ranks[i]
correct_params = CorrectParameterFactory() tm = get_send_recv_template("mpi_send", "mpi_recv")
tm.set_description("InvalidParam-Rank-MPI_Send", "Invalid Rank: %s" % self.invalid_ranks[i]) tm.set_description("InvalidParam-Rank-MPI_Send", "Invalid Rank: %s" % self.invalid_ranks[i])
# include the buffer allocation in the template (all ranks execute it) tm.get_block("MPICALL").get_operation(kind=0, index=0).set_arg("source",rank_to_use)
tm.register_instruction_block(correct_params.get_buffer_alloc()) tm.get_block("MPICALL").get_operation(kind=0, index=0).set_has_error()
send = MPICallFactory().mpi_send(
correct_params.get("buf"),
correct_params.get("count"),
correct_params.get("datatype"),
self.invalid_ranks[i], # invalid rank
correct_params.get("tag"),
correct_params.get("comm"),
)
send.set_has_error()
recv = CorrectMPICallFactory().mpi_recv()
b = InstructionBlock()
# only rank 0 execute the send
b.register_operation(send, 0)
# only rank 1 execute the recv
b.register_operation(recv, 1)
tm.register_instruction_block(b)
return tm return tm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment