Skip to content
Snippets Groups Projects

Draft: RMA

Merged Simon Schwitanski requested to merge rma into main
4 files
+ 262
4
Compare changes
  • Side-by-side
  • Inline

Files

@@ -8,8 +8,10 @@ class CorrectParameterFactory:
@@ -8,8 +8,10 @@ class CorrectParameterFactory:
# default params
# default params
buf_size = 10
buf_size = 10
dtype = ['int', 'MPI_INT']
dtype = ['int', 'MPI_INT']
 
buf_size_bytes = f"{buf_size}*sizeof({dtype[0]})"
tag = 0
tag = 0
buf_var_name = "buf"
buf_var_name = "buf"
 
winbuf_var_name = "winbuf"
def __init__(self):
def __init__(self):
pass
pass
@@ -21,13 +23,13 @@ class CorrectParameterFactory:
@@ -21,13 +23,13 @@ class CorrectParameterFactory:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def get(self, param, func=None):
def get(self, param, func=None):
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf"]:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
return self.buf_var_name
return self.buf_var_name
if param in ["COUNT", "count", "sendcount", "recvcount"]:
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
return str(self.buf_size)
return str(self.buf_size)
if param in ["DATATYPE", "datatype", "sendtype", "recvtype"]:
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype", "result_datatype"]:
return self.dtype[1]
return self.dtype[1]
if param in ["DEST", "dest"]:
if param in ["DEST", "dest", "target_rank"]:
return "0"
return "0"
if param in ["SRC", "source"]:
if param in ["SRC", "source"]:
return "1"
return "1"
@@ -71,6 +73,24 @@ class CorrectParameterFactory:
@@ -71,6 +73,24 @@ class CorrectParameterFactory:
return "MPI_COMM_WORLD"
return "MPI_COMM_WORLD"
if param in ["remote_leader"]:
if param in ["remote_leader"]:
return "0"
return "0"
 
if param in ["target_disp"]:
 
return "0"
 
if param in ["win"]:
 
return "win"
 
if param in ["baseptr"]:
 
return "&" + self.winbuf_var_name
 
if param in ["base"]:
 
return self.winbuf_var_name
 
if param in ["size"]:
 
return self.buf_size_bytes
 
if param in ["disp_unit"]:
 
return "sizeof(int)"
 
if param in ["info"]:
 
return "MPI_INFO_NULL"
 
if param in ["result_addr"]:
 
return "resultbuf"
 
if param in ["compare_addr"]:
 
return "comparebuf"
print("Not Implemented: " + param)
print("Not Implemented: " + param)
assert False, "Param not known"
assert False, "Param not known"
Loading