Skip to content
Snippets Groups Projects

P2P

Open
Jammer, Timrequested to merge
p2p into main
2 files
+ 41
6
Compare changes
  • Side-by-side
  • Inline

Files

@@ -17,18 +17,19 @@ class CorrectParameterFactory:
@@ -17,18 +17,19 @@ class CorrectParameterFactory:
def __init__(self):
def __init__(self):
pass
pass
def get_buffer_alloc(self)-> AllocCall:
def get_buffer_alloc(self) -> AllocCall:
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
def get_buffer_free(self)->Instruction:
def get_buffer_free(self) -> Instruction:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def get(self, param:str)->str:
def get(self, param: str) -> str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
return self.buf_var_name
return self.buf_var_name
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
return str(self.buf_size)
return str(self.buf_size)
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype", "result_datatype"]:
if param in ["DATATYPE", "datatype", "sendtype", "recvtype", "origin_datatype", "target_datatype",
 
"result_datatype"]:
return self.dtype[1]
return self.dtype[1]
if param in ["DEST", "dest", "target_rank"]:
if param in ["DEST", "dest", "target_rank"]:
return "0"
return "0"
@@ -99,9 +100,18 @@ class CorrectParameterFactory:
@@ -99,9 +100,18 @@ class CorrectParameterFactory:
print("Not Implemented: " + param)
print("Not Implemented: " + param)
assert False, "Param not known"
assert False, "Param not known"
 
def get_initializer(self, variable_type: str) -> str:
 
if variable_type == "int":
 
return "0"
 
if variable_type == "MPI_Request":
 
return "MPI_REQUEST_NULL"
 
# TODO implement other types
 
print("Not Implemented: " + variable_type)
 
assert False, "Param not known"
 
# todo also for send and non default args
# todo also for send and non default args
def get_matching_recv(call:MPICall)->MPICall:
def get_matching_recv(call: MPICall) -> MPICall:
correct_params = CorrectParameterFactory()
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
correct_params.get("BUFFER"),
@@ -113,4 +123,4 @@ def get_matching_recv(call:MPICall)->MPICall:
@@ -113,4 +123,4 @@ def get_matching_recv(call:MPICall)->MPICall:
correct_params.get("STATUS", "MPI_Recv"),
correct_params.get("STATUS", "MPI_Recv"),
)
)
return recv
return recv
\ No newline at end of file
Loading