Skip to content
Snippets Groups Projects

RMA Test Cases

Open Simon Schwitanski requested to merge rma into main
Files
3
#! /usr/bin/python3
from scripts.Infrastructure import MPICall
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICallFactory import MPICallFactory
# from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.AllocCall import AllocCall, get_free
class CorrectParameterFactory:
# default params
buf_size = "10"
@@ -13,15 +12,26 @@ class CorrectParameterFactory:
tag = 0
buf_var_name = "buf"
winbuf_var_name = "winbuf"
request_counter = 0
_instance = None
def __init__(self):
pass
@classmethod
def instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
return cls._instance
def get_buffer_alloc(self) -> AllocCall:
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
def get_buffer_free(self) -> Instruction:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def consume_request(self):
self.request_counter += 1
def get(self, param: str) -> str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
@@ -56,7 +66,7 @@ class CorrectParameterFactory:
if param in ["PARTITION", "partition"]:
return "0"
if param in ["REQUEST", "request"]:
return "&mpi_request_0"
return "&mpi_request_" + str(self.request_counter)
if param in ["GROUP", "group"]:
return "&mpi_group_0"
if param in ["color"]:
@@ -114,17 +124,17 @@ class CorrectParameterFactory:
assert False, "Param not known"
# todo also for send and non default args
def get_matching_recv(call: MPICall) -> MPICall:
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
correct_params.get("COUNT"),
correct_params.get("DATATYPE"),
correct_params.get("SRC"),
correct_params.get("TAG"),
correct_params.get("COMM"),
correct_params.get("STATUS", "MPI_Recv"),
)
# # todo also for send and non default args
# def get_matching_recv(call: MPICall) -> MPICall:
# correct_params = CorrectParameterFactory()
# recv = MPICallFactory().mpi_recv(
# correct_params.get("BUFFER"),
# correct_params.get("COUNT"),
# correct_params.get("DATATYPE"),
# correct_params.get("SRC"),
# correct_params.get("TAG"),
# correct_params.get("COMM"),
# correct_params.get("STATUS", "MPI_Recv"),
# )
return recv
# return recv
Loading