Skip to content
Snippets Groups Projects

RMA Test Cases

Open Simon Schwitanski requested to merge rma into main

Files

#! /usr/bin/python3
#! /usr/bin/python3
from scripts.Infrastructure import MPICall
from scripts.Infrastructure import MPICall
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.Instruction import Instruction
# from scripts.Infrastructure.MPICallFactory import MPICallFactory
#from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.AllocCall import AllocCall, get_free
from scripts.Infrastructure.AllocCall import AllocCall, get_free
 
class CorrectParameterFactory:
class CorrectParameterFactory:
# default params
# default params
buf_size = "10"
buf_size = "10"
@@ -12,26 +13,15 @@ class CorrectParameterFactory:
@@ -12,26 +13,15 @@ class CorrectParameterFactory:
tag = 0
tag = 0
buf_var_name = "buf"
buf_var_name = "buf"
winbuf_var_name = "winbuf"
winbuf_var_name = "winbuf"
request_counter = 0
_instance = None
def __init__(self):
def __init__(self):
pass
pass
@classmethod
def instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
return cls._instance
def get_buffer_alloc(self) -> AllocCall:
def get_buffer_alloc(self) -> AllocCall:
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
def get_buffer_free(self) -> Instruction:
def get_buffer_free(self) -> Instruction:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def consume_request(self):
self.request_counter += 1
def get(self, param: str) -> str:
def get(self, param: str) -> str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
@@ -41,8 +31,10 @@ class CorrectParameterFactory:
@@ -41,8 +31,10 @@ class CorrectParameterFactory:
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
"result_datatype"]:
"result_datatype"]:
return self.dtype[1]
return self.dtype[1]
if param in ["DEST", "dest", "target_rank"]:
if param in ["DEST", "dest", "rank"]:
return "0"
return "0"
 
if param in ["target_rank"]:
 
return "1"
if param in ["SRC", "source"]:
if param in ["SRC", "source"]:
return "1"
return "1"
if param in ["RANK", "root"]:
if param in ["RANK", "root"]:
@@ -66,9 +58,9 @@ class CorrectParameterFactory:
@@ -66,9 +58,9 @@ class CorrectParameterFactory:
if param in ["PARTITION", "partition"]:
if param in ["PARTITION", "partition"]:
return "0"
return "0"
if param in ["REQUEST", "request"]:
if param in ["REQUEST", "request"]:
return "&mpi_request_" + str(self.request_counter)
return "&mpi_request_0"
if param in ["GROUP", "group"]:
if param in ["GROUP", "group"]:
return "&mpi_group_0"
return "mpi_group_0"
if param in ["color"]:
if param in ["color"]:
return "1"
return "1"
if param in ["message"]:
if param in ["message"]:
@@ -93,6 +85,10 @@ class CorrectParameterFactory:
@@ -93,6 +85,10 @@ class CorrectParameterFactory:
return "0"
return "0"
if param in ["win"]:
if param in ["win"]:
return "mpi_win_0"
return "mpi_win_0"
 
if param in ["lock_type"]:
 
return "MPI_LOCK_EXCLUSIVE"
 
if param in ["assert"]:
 
return "0"
if param in ["baseptr"]:
if param in ["baseptr"]:
return "&" + self.winbuf_var_name
return "&" + self.winbuf_var_name
if param in ["base"]:
if param in ["base"]:
@@ -124,17 +120,17 @@ class CorrectParameterFactory:
@@ -124,17 +120,17 @@ class CorrectParameterFactory:
assert False, "Param not known"
assert False, "Param not known"
# # todo also for send and non default args
# todo also for send and non default args
# def get_matching_recv(call: MPICall) -> MPICall:
def get_matching_recv(call: MPICall) -> MPICall:
# correct_params = CorrectParameterFactory()
correct_params = CorrectParameterFactory()
# recv = MPICallFactory().mpi_recv(
recv = MPICallFactory().mpi_recv(
# correct_params.get("BUFFER"),
correct_params.get("BUFFER"),
# correct_params.get("COUNT"),
correct_params.get("COUNT"),
# correct_params.get("DATATYPE"),
correct_params.get("DATATYPE"),
# correct_params.get("SRC"),
correct_params.get("SRC"),
# correct_params.get("TAG"),
correct_params.get("TAG"),
# correct_params.get("COMM"),
correct_params.get("COMM"),
# correct_params.get("STATUS", "MPI_Recv"),
correct_params.get("STATUS", "MPI_Recv"),
# )
)
# return recv
return recv
Loading