Skip to content
Snippets Groups Projects

RMA Test Cases

Open Simon Schwitanski requested to merge rma into main
Files
13
#! /usr/bin/python3
#! /usr/bin/python3
from scripts.Infrastructure import MPICall
from scripts.Infrastructure import MPICall
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.AllocCall import AllocCall, get_free
from scripts.Infrastructure.AllocCall import AllocCall, get_free
@@ -31,8 +30,10 @@ class CorrectParameterFactory:
@@ -31,8 +30,10 @@ class CorrectParameterFactory:
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
"result_datatype"]:
"result_datatype"]:
return self.dtype[1]
return self.dtype[1]
if param in ["DEST", "dest", "target_rank"]:
if param in ["DEST", "dest", "rank"]:
return "0"
return "0"
 
if param in ["target_rank"]:
 
return "1"
if param in ["SRC", "source"]:
if param in ["SRC", "source"]:
return "1"
return "1"
if param in ["RANK", "root"]:
if param in ["RANK", "root"]:
@@ -58,7 +59,7 @@ class CorrectParameterFactory:
@@ -58,7 +59,7 @@ class CorrectParameterFactory:
if param in ["REQUEST", "request"]:
if param in ["REQUEST", "request"]:
return "&mpi_request_0"
return "&mpi_request_0"
if param in ["GROUP", "group"]:
if param in ["GROUP", "group"]:
return "&mpi_group_0"
return "mpi_group_0"
if param in ["color"]:
if param in ["color"]:
return "1"
return "1"
if param in ["message"]:
if param in ["message"]:
@@ -83,6 +84,10 @@ class CorrectParameterFactory:
@@ -83,6 +84,10 @@ class CorrectParameterFactory:
return "0"
return "0"
if param in ["win"]:
if param in ["win"]:
return "mpi_win_0"
return "mpi_win_0"
 
if param in ["lock_type"]:
 
return "MPI_LOCK_EXCLUSIVE"
 
if param in ["assert"]:
 
return "0"
if param in ["baseptr"]:
if param in ["baseptr"]:
return "&" + self.winbuf_var_name
return "&" + self.winbuf_var_name
if param in ["base"]:
if param in ["base"]:
@@ -112,19 +117,3 @@ class CorrectParameterFactory:
@@ -112,19 +117,3 @@ class CorrectParameterFactory:
# TODO implement other types
# TODO implement other types
print("Not Implemented: " + variable_type)
print("Not Implemented: " + variable_type)
assert False, "Param not known"
assert False, "Param not known"
# todo also for send and non default args
def get_matching_recv(call: MPICall) -> MPICall:
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
correct_params.get("COUNT"),
correct_params.get("DATATYPE"),
correct_params.get("SRC"),
correct_params.get("TAG"),
correct_params.get("COMM"),
correct_params.get("STATUS", "MPI_Recv"),
)
return recv
Loading