Verified commit 20e62556 authored by Simon Schwitanski

Simplify LocalConcurrency and GlobalConcurrency generators, add level support

parent 1ab4db4a
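Both generators gain "level support" in `generate()`: the level now decides how access combinations are paired with synchronization modes. Up to `generate_level` 2, each sync mode is exercised once and the remaining access combinations are padded with the `fence` mode via `itertools.zip_longest`; from level 3 on, the full cross product is generated with `itertools.product`. A minimal, self-contained sketch of that selection logic (the combination tuples and mode names below are illustrative placeholders, not the generators' real data):

```python
import itertools

# Illustrative stand-ins for the real access combinations and bound
# sync-mode methods used by the generators.
access_combinations = [
    (["get"], ["bufread"], False),
    (["get"], ["bufwrite"], True),
    (["put"], ["bufread"], True),
    (["put"], ["bufwrite"], True),
    (["accumulate"], ["bufread"], True),
    (["accumulate"], ["bufwrite"], True),
]
sync_modes = ["fence", "lockall", "lockflush", "request", "pscw"]


def select_combos(generate_level):
    if generate_level <= 2:
        # one access combination per sync mode; once the sync modes run out,
        # the remaining combinations are paired with "fence" (the fillvalue)
        return itertools.zip_longest(access_combinations, sync_modes,
                                     fillvalue="fence")
    # higher levels: full cross product (the old nested for loops)
    return itertools.product(access_combinations, sync_modes)


print(len(list(select_combos(1))))  # 6: one case per access combination
print(len(list(select_combos(3))))  # 30: 6 combinations x 5 sync modes
```

With `zip_longest` the number of generated cases grows linearly with the number of access combinations, while `product` multiplies it by the number of sync modes; note that the padding assumes there are at least as many access combinations as sync modes.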
@@ -17,6 +17,9 @@ from typing import Tuple, List
 class GlobalConcurrencyErrorRMA(ErrorGenerator):
     def __init__(self):
+        self.cfmpi = CorrectMPICallFactory()
+
+        # RMA calls that perform a local buffer access
         localbufwrite = CorrectMPICallFactory().mpi_get()
         localbufwrite.set_arg(
             "origin_addr", CorrectParameterFactory().winbuf_var_name)
@@ -26,8 +29,6 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         localbufwrite.set_arg("target_rank", "0")
 
         localbufread = CorrectMPICallFactory().mpi_put()
-        # local buffer accesses
         localbufread.set_arg(
             "origin_addr", CorrectParameterFactory().winbuf_var_name)
         localbufread.set_rank_executing(1)
@@ -41,14 +42,13 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
             "localbufread": localbufread,
             "localbufwrite": localbufwrite
         }
-        pass
 
     def get_feature(self):
         return ["RMA"]
 
-    def fence_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+    def fence(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_fence())
+        tm.register_instruction(self.cfmpi.mpi_win_fence())
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
@@ -56,44 +56,44 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         # if accesses should be synced, add fence
         if shouldsync:
             tm.register_instruction(
-                CorrectMPICallFactory().mpi_win_fence(), rank_to_execute="all")
+                self.cfmpi.mpi_win_fence(), rank_to_execute="all")
 
         tm.register_instruction(op2, "OP2")
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_fence())
+        tm.register_instruction(self.cfmpi.mpi_win_fence())
 
-        return tm
+        return True
 
-    def lockall_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+    def lockall(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory(
-        ).mpi_win_lock_all(), rank_to_execute="all")
+        tm.register_instruction(
+            self.cfmpi.mpi_win_lock_all(), rank_to_execute="all")
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
-        tm.register_instruction(CorrectMPICallFactory(
-        ).mpi_win_flush_all(), rank_to_execute="all")
+        tm.register_instruction(
+            self.cfmpi.mpi_win_flush_all(), rank_to_execute="all")
 
         # if accesses should be synced, add barrier
         if shouldsync:
             tm.register_instruction(
-                CorrectMPICallFactory().mpi_barrier(), rank_to_execute="all")
+                self.cfmpi.mpi_barrier(), rank_to_execute="all")
 
         tm.register_instruction(op2, "OP2")
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory(
-        ).mpi_win_unlock_all(), rank_to_execute="all")
+        tm.register_instruction(
+            self.cfmpi.mpi_win_unlock_all(), rank_to_execute="all")
 
-        return tm
+        return True
 
-    def lock_flush_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
-        lock0 = CorrectMPICallFactory().mpi_win_lock()
-        unlock0 = CorrectMPICallFactory().mpi_win_unlock()
-        lock1 = CorrectMPICallFactory().mpi_win_lock()
-        unlock1 = CorrectMPICallFactory().mpi_win_unlock()
+    def lockflush(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        lock0 = self.cfmpi.mpi_win_lock()
+        unlock0 = self.cfmpi.mpi_win_unlock()
+        lock1 = self.cfmpi.mpi_win_lock()
+        unlock1 = self.cfmpi.mpi_win_unlock()
         lock0.set_arg("rank", "1")
         unlock0.set_arg("rank", "1")
         lock1.set_arg("rank", "1")
@@ -116,15 +116,19 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         tm.register_instruction(
             unlock1, rank_to_execute=op2[-1].get_rank_executing())
 
-        return tm
+        return True
+
+    def request(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        # only consider combination where the first operation is a request-based RMA call
+        if not isinstance(op1[-1], MPICall) or not op1[-1].has_arg("request"):
+            return False
 
-    def req_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # we assume that the first operation is request-based
-        wait = CorrectMPICallFactory().mpi_wait()
+        wait = self.cfmpi.mpi_wait()
         wait.set_arg("request", op1[-1].get_arg("request"))
 
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_lock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_lock_all())
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
@@ -136,9 +140,9 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         tm.register_instruction(op2, "OP2")
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_unlock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_unlock_all())
 
-        return tm
+        return True
 
     def pscw(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         tm.register_instruction("MPI_Group world_group;")
@@ -148,7 +152,7 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         tm.register_instruction(
             "int destrank = 1; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &destrank, &mpi_group_0);", rank_to_execute=0)
         tm.register_instruction(
-            CorrectMPICallFactory().mpi_win_start(), rank_to_execute=0)
+            self.cfmpi.mpi_win_start(), rank_to_execute=0)
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
@@ -156,23 +160,23 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         # if accesses should be synced, end access epoch here
         if shouldsync:
             tm.register_instruction(
-                CorrectMPICallFactory().mpi_win_complete(), rank_to_execute=0)
+                self.cfmpi.mpi_win_complete(), rank_to_execute=0)
 
         tm.register_instruction(op2, "OP2")
 
         # if accesses should not be synced, end access epoch here
         if not shouldsync:
             tm.register_instruction(
-                CorrectMPICallFactory().mpi_win_complete(), rank_to_execute=0)
+                self.cfmpi.mpi_win_complete(), rank_to_execute=0)
 
         tm.register_instruction(
             "int srcrank = 0; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &srcrank, &mpi_group_0);", rank_to_execute=1)
         tm.register_instruction(
-            CorrectMPICallFactory().mpi_win_post(), rank_to_execute=1)
+            self.cfmpi.mpi_win_post(), rank_to_execute=1)
         tm.register_instruction(
-            CorrectMPICallFactory().mpi_win_wait(), rank_to_execute=1)
+            self.cfmpi.mpi_win_wait(), rank_to_execute=1)
 
-        return tm
+        return True
 
     def get_mem_op(self, name: str, rank) -> Tuple[List[Instruction], List[Instruction]]:
         if name.startswith("mpi"):
@@ -183,10 +187,12 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
     def generate(self, generate_level):
         if generate_level == 1:
+            # only basic calls
             remote_read = ["mpi_get"]
             remote_write = ["mpi_put"]
             remote_atomic_update = ["mpi_accumulate"]
         else:
+            # everything
             remote_read = ["mpi_get", "mpi_rget"]
             remote_write = [
                 "mpi_put",
@@ -202,7 +208,6 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
         ]
 
         cf = CorrectParameterFactory()
-        cfmpi = CorrectMPICallFactory()
 
         # possible combinations of local buffer accesses (hasconflict = True | False)
         remote_access_combinations: List[Tuple[List[str], List[str], bool]] = [
@@ -221,8 +226,17 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
             (remote_atomic_update, ["bufwrite", "localbufwrite"], True),
         ]
 
-        for ops1, ops2, hasconflict in remote_access_combinations:
-            for sync_mode in ["fence", "lockall", "lock_flush", "request", "pscw"]:
+        sync_modes = [self.fence, self.lockall, self.lockflush, self.request, self.pscw]
+
+        if generate_level <= 2:
+            # go through all sync modes, but only one access combination per sync mode, fill with fence
+            combos = itertools.zip_longest(
+                remote_access_combinations, sync_modes, fillvalue=self.fence)
+        else:
+            # combine everything (= nested for loop)
+            combos = itertools.product(remote_access_combinations, sync_modes)
+
+        for (ops1, ops2, hasconflict), sync_mode in combos:
             for shouldsync in [False, True]:
                 for (op1, op2) in itertools.product(ops1, ops2):
                     self.tm = TemplateManager(min_ranks=3)
@@ -252,24 +266,12 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
                     inst1[-1].set_has_error(False)
                     inst2[-1].set_has_error(False)
 
-                    if sync_mode == "fence":
-                        self.fence_sync(self.tm, alloc_inst,
-                                        inst1, inst2, shouldsync)
-                    elif sync_mode == "lockall":
-                        self.lockall_sync(
-                            self.tm, alloc_inst, inst1, inst2, shouldsync)
-                    elif sync_mode == "lock_flush":
-                        self.lock_flush_sync(
-                            self.tm, alloc_inst, inst1, inst2, shouldsync)
-                    elif sync_mode == "request":
-                        if isinstance(inst1[-1], MPICall) and inst1[-1].has_arg("request"):
-                            self.req_sync(self.tm, alloc_inst,
-                                          inst1, inst2, shouldsync)
-                        else:
-                            continue
-                    elif sync_mode == "pscw":
-                        self.pscw(self.tm, alloc_inst,
-                                  inst1, inst2, shouldsync)
+                    # generate code for the given sync_mode
+                    valid_case = sync_mode(self.tm, alloc_inst, inst1, inst2, shouldsync)
+
+                    if not valid_case:
+                        # this case is not possible / redundant for this sync_mode, continue
+                        continue
 
                     # finalize RMA call (if needed)
                     self.tm.register_instruction(inst1_free)
@@ -281,7 +283,7 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
                     self.tm.set_description(
                         ("GlobalConcurrency" if hasconflict and not shouldsync else "Correct") +
                         "-"
-                        + sync_mode
+                        + sync_mode.__name__
                         + "-"
                        + op1_name
                         + "_"
@@ -289,6 +291,3 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
                         "full description",
                     )
                     yield self.tm
-
-                    # if generate_level <= BASIC_TEST_LEVEL:
-                    #     return
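Both files also replace the string-keyed `if`/`elif` dispatch in `generate()` with a list of bound sync-mode methods. Each method now returns `True` when it emitted a valid epoch around the two operations and `False` when the combination does not apply (for example, `request()` bails out unless the first call has a `request` argument), so the caller can simply `continue`. A toy sketch of that dispatch pattern with made-up operations, not the real `TemplateManager`/`Instruction` machinery:

```python
class TinyGenerator:
    """Toy dispatcher mirroring the sync-mode-as-bound-method pattern."""

    def fence(self, ops):
        # always applicable
        return True

    def request(self, ops):
        # only applicable if the first operation is request-based ("r..." call)
        return ops[0].startswith("r")

    def generate(self):
        for ops in (["get", "bufread"], ["rget", "bufwrite"]):
            for sync_mode in (self.fence, self.request):
                if not sync_mode(ops):
                    # case not possible / redundant for this sync mode -> skip
                    continue
                yield f"{sync_mode.__name__}-{'_'.join(ops)}"


print(list(TinyGenerator().generate()))
# ['fence-get_bufread', 'fence-rget_bufwrite', 'request-rget_bufwrite']
```

Because the sync modes are now plain methods, the test-case name can reuse `sync_mode.__name__` directly, which is what the `set_description()` change in both generators does.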
@@ -14,61 +14,112 @@ import itertools
 from typing import Tuple, List
 
 
 class LocalConcurrencyErrorRMA(ErrorGenerator):
     def __init__(self):
+        self.cfmpi = CorrectMPICallFactory()
+
+        # generate standard buffer access instructions
         self.buf_instructions = {
             "bufread": Instruction(f'printf("buf is %d\\n", {CorrectParameterFactory().buf_var_name}[1]);', 0, "bufread"),
             "bufwrite": Instruction(f'{CorrectParameterFactory().buf_var_name}[1] = 42;', 0, "bufwrite")
         }
-        pass
 
     def get_feature(self):
         return ["RMA"]
 
-    def fence_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+    def fence(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_fence())
+        tm.register_instruction(self.cfmpi.mpi_win_fence())
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
 
         # if accesses should be synced, add another fence (rank 0)
         if shouldsync:
-            tm.register_instruction(CorrectMPICallFactory().mpi_win_fence(), rank_to_execute=0)
+            tm.register_instruction(
+                self.cfmpi.mpi_win_fence(), rank_to_execute=0)
 
         tm.register_instruction(op2, "OP2")
 
         # if accesses should be synced, add another fence (rank 1)
         if shouldsync:
-            tm.register_instruction(CorrectMPICallFactory().mpi_win_fence(), rank_to_execute=1)
+            tm.register_instruction(
+                self.cfmpi.mpi_win_fence(), rank_to_execute=1)
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_fence())
+        tm.register_instruction(self.cfmpi.mpi_win_fence())
 
-        return tm
+        return True
 
-    def lockall_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+    def lockallflush(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_lock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_lock_all())
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
 
         # if accesses should be synced, add flush
         if shouldsync:
-            tm.register_instruction(CorrectMPICallFactory().mpi_win_flush_all(), rank_to_execute=0)
+            tm.register_instruction(
+                self.cfmpi.mpi_win_flush_all(), rank_to_execute=0)
+
+        tm.register_instruction(op2, "OP2")
+
+        # finish access epoch + sync
+        tm.register_instruction(self.cfmpi.mpi_win_unlock_all())
+
+        return True
+
+    def lockallflushlocal(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        # should sync: MPI_Win_lock_all - op1 - MPI_Win_flush_local_all - op2 - MPI_Win_unlock_all
+        # should not sync: MPI_Win_lock_all - op1 - op2 - MPI_Win_unlock_all
+
+        # open access epoch + sync
+        tm.register_instruction(self.cfmpi.mpi_win_lock_all())
+        tm.register_instruction(alloc_inst)
+        tm.register_instruction(op1, "OP1")
+
+        # if accesses should be synced, add flush_local
+        if shouldsync:
+            tm.register_instruction(
+                self.cfmpi.mpi_win_flush_local_all(), rank_to_execute=0)
 
         tm.register_instruction(op2, "OP2")
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_unlock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_unlock_all())
 
-        return tm
+        return True
 
-    def lock_flush_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
-        lock = CorrectMPICallFactory().mpi_win_lock()
-        flush = CorrectMPICallFactory().mpi_win_flush()
-        unlock = CorrectMPICallFactory().mpi_win_unlock()
+    def lockunlock(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        # should sync: MPI_Win_lock - op1 - MPI_Win_unlock - op2
+        # should not sync: MPI_Win_lock - op1 - op2 - MPI_Win_unlock
+        lock = self.cfmpi.mpi_win_lock()
+        unlock = self.cfmpi.mpi_win_unlock()
+        lock.set_arg("rank", "1")
+        unlock.set_arg("rank", "1")
+
+        tm.register_instruction(lock, rank_to_execute=0)
+        tm.register_instruction(alloc_inst)
+        tm.register_instruction(op1, "OP1")
+
+        # if accesses should be synced, unlock before op2
+        if shouldsync:
+            tm.register_instruction(unlock, rank_to_execute=0)
+            tm.register_instruction(op2, "OP2")
+        else:
+            tm.register_instruction(op2, "OP2")
+            tm.register_instruction(unlock, rank_to_execute=0)
+
+        return True
+
+    def lockflush(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        lock = self.cfmpi.mpi_win_lock()
+        flush = self.cfmpi.mpi_win_flush()
+        unlock = self.cfmpi.mpi_win_unlock()
         lock.set_arg("rank", "1")
         flush.set_arg("rank", "1")
         unlock.set_arg("rank", "1")
@@ -87,16 +138,43 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
         # finish access epoch + sync
         tm.register_instruction(unlock, rank_to_execute=0)
 
-        return tm
+        return True
+
+    def lockflushlocal(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        lock = self.cfmpi.mpi_win_lock()
+        flush_local = self.cfmpi.mpi_win_flush_local()
+        unlock = self.cfmpi.mpi_win_unlock()
+        lock.set_arg("rank", "1")
+        flush_local.set_arg("rank", "1")
+        unlock.set_arg("rank", "1")
+
+        tm.register_instruction(lock, rank_to_execute=0)
+        tm.register_instruction(alloc_inst)
+        tm.register_instruction(op1, "OP1")
+
+        # if accesses should be synced, add flush here
+        if shouldsync:
+            tm.register_instruction(flush_local, rank_to_execute=0)
+
+        tm.register_instruction(op2, "OP2")
+
+        # finish access epoch + sync
+        tm.register_instruction(unlock, rank_to_execute=0)
+
+        return True
+
+    def request(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
+        # only consider combination where the first operation is a request-based RMA call
+        if not isinstance(op1[-1], MPICall) or not op1[-1].has_arg("request"):
+            return False
 
-    def req_sync(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         # we assume that the first operation is request-based
-        wait = CorrectMPICallFactory().mpi_wait()
+        wait = self.cfmpi.mpi_wait()
        wait.set_arg("request", op1[-1].get_arg("request"))
 
         # open access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_lock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_lock_all())
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
@@ -108,37 +186,43 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
         tm.register_instruction(op2, "OP2")
 
         # finish access epoch + sync
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_unlock_all())
+        tm.register_instruction(self.cfmpi.mpi_win_unlock_all())
 
-        return tm
+        return True
 
     def pscw(self, tm: TemplateManager, alloc_inst: Instruction, op1: List[Instruction], op2: List[Instruction], shouldsync: bool):
         tm.register_instruction("MPI_Group world_group;")
-        tm.register_instruction("MPI_Comm_group(MPI_COMM_WORLD, &world_group);")
+        tm.register_instruction(
+            "MPI_Comm_group(MPI_COMM_WORLD, &world_group);")
 
-        tm.register_instruction("int destrank = 1; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &destrank, &mpi_group_0);", rank_to_execute=0)
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_start(), rank_to_execute=0)
+        tm.register_instruction(
+            "int destrank = 1; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &destrank, &mpi_group_0);", rank_to_execute=0)
+        tm.register_instruction(
+            self.cfmpi.mpi_win_start(), rank_to_execute=0)
         tm.register_instruction(alloc_inst)
         tm.register_instruction(op1, "OP1")
 
         # if accesses should be synced, end access epoch here
         if shouldsync:
-            tm.register_instruction(CorrectMPICallFactory().mpi_win_complete(), rank_to_execute=0)
+            tm.register_instruction(
+                self.cfmpi.mpi_win_complete(), rank_to_execute=0)
 
         tm.register_instruction(op2, "OP2")
 
         # if accesses should not be synced, end access epoch here
         if not shouldsync:
-            tm.register_instruction(CorrectMPICallFactory().mpi_win_complete(), rank_to_execute=0)
+            tm.register_instruction(
+                self.cfmpi.mpi_win_complete(), rank_to_execute=0)
 
-        tm.register_instruction("int srcrank = 0; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &srcrank, &mpi_group_0);", rank_to_execute=1)
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_post(), rank_to_execute=1)
-        tm.register_instruction(CorrectMPICallFactory().mpi_win_wait(), rank_to_execute=1)
+        tm.register_instruction(
+            "int srcrank = 0; MPI_Group mpi_group_0; MPI_Group_incl(world_group, 1, &srcrank, &mpi_group_0);", rank_to_execute=1)
+        tm.register_instruction(
+            self.cfmpi.mpi_win_post(), rank_to_execute=1)
+        tm.register_instruction(
+            self.cfmpi.mpi_win_wait(), rank_to_execute=1)
 
-        return tm
+        return True
 
     def get_mem_op(self, name: str, rank) -> Tuple[List[Instruction], List[Instruction]]:
         if name.startswith("mpi"):
@@ -149,25 +233,22 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
     def generate(self, generate_level):
         # build set of calls based on generate level, for level 1 just a few basic calls,
         # for level >= 2 all calls
-        local_origin_addr_read = ["mpi_put", "mpi_accumulate"]
-        local_origin_addr_write = ["mpi_get"]
-
-        if generate_level >= 2:
-            local_origin_addr_read.extend([
-                "mpi_rput",
-                "mpi_raccumulate",
-                "mpi_get_accumulate",
-                "mpi_rget_accumulate",
-                "mpi_fetch_and_op",
-                "mpi_compare_and_swap"
-            ])
-            local_origin_addr_write.extend([
-                "mpi_rget"
-            ])
+        if generate_level == 1:
+            # only basic calls
+            local_origin_addr_read = ["mpi_put", "mpi_accumulate"]
+            local_origin_addr_write = ["mpi_get"]
+        else:
+            # everything
+            local_origin_addr_read = ["mpi_put", "mpi_accumulate", "mpi_rput",
+                                      "mpi_raccumulate",
+                                      "mpi_get_accumulate",
+                                      "mpi_rget_accumulate",
+                                      "mpi_fetch_and_op",
+                                      "mpi_compare_and_swap"]
+            local_origin_addr_write = ["mpi_get", "mpi_rget"]
 
         cf = CorrectParameterFactory()
-        cfmpi = CorrectMPICallFactory()
 
         # possible combinations of local buffer accesses (hasconflict = True | False)
         local_access_combinations: List[Tuple[List[str], List[str], bool]] = [
@@ -188,17 +269,28 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
             (local_origin_addr_write, local_origin_addr_write, True),
         ]
 
-        for ops1, ops2, hasconflict in local_access_combinations:
-            for sync_mode in ["fence", "lockall", "lock_flush", "request", "pscw"]:
+        sync_modes = [self.fence, self.lockallflush, self.lockallflushlocal, self.lockflush, self.lockflushlocal, self.lockunlock, self.request, self.pscw]
+
+        if generate_level <= 2:
+            # go through all sync modes, but only one access combination per sync mode, fill with fence
+            combos = itertools.zip_longest(
+                local_access_combinations, sync_modes, fillvalue=self.fence)
+        else:
+            # combine everything (= nested for loop)
+            combos = itertools.product(local_access_combinations, sync_modes)
+
+        for (ops1, ops2, hasconflict), sync_mode in combos:
             for shouldsync in [False, True]:
                 for (op1, op2) in itertools.product(ops1, ops2):
                     self.tm = TemplateManager()
-                    (win_alloc, win_free) = get_allocated_window("mpi_win_create", cf.get("win"), cf.winbuf_var_name, "int", "10")
+                    (win_alloc, win_free) = get_allocated_window(
+                        "mpi_win_create", cf.get("win"), cf.winbuf_var_name, "int", "10")
 
                     # window allocation boilerplate
                     self.tm.register_instruction(win_alloc)
 
                     # local buffer allocation
-                    alloc_inst = AllocCall(cf.dtype[0], cf.buf_size, cf.buf_var_name, use_malloc=False, identifier="alloc", rank=0)
+                    alloc_inst = AllocCall(
+                        cf.dtype[0], cf.buf_size, cf.buf_var_name, use_malloc=False, identifier="alloc", rank=0)
 
                     op1_name = op1.replace("mpi_", "")
                     op2_name = op2.replace("mpi_", "")
@@ -217,19 +309,12 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
                     inst1[-1].set_has_error(False)
                     inst2[-1].set_has_error(False)
 
-                    if sync_mode == "fence":
-                        self.fence_sync(self.tm, alloc_inst, inst1, inst2, shouldsync)
-                    elif sync_mode == "lockall":
-                        self.lockall_sync(self.tm, alloc_inst, inst1, inst2, shouldsync)
-                    elif sync_mode == "lock_flush":
-                        self.lock_flush_sync(self.tm, alloc_inst, inst1, inst2, shouldsync)
-                    elif sync_mode == "request":
-                        if isinstance(inst1[-1], MPICall) and inst1[-1].has_arg("request"):
-                            self.req_sync(self.tm, alloc_inst, inst1, inst2, shouldsync)
-                        else:
-                            continue
-                    elif sync_mode == "pscw":
-                        self.pscw(self.tm, alloc_inst, inst1, inst2, shouldsync)
+                    # generate code for the given sync_mode
+                    valid_case = sync_mode(self.tm, alloc_inst, inst1, inst2, shouldsync)
+
+                    if not valid_case:
+                        # this case is not possible / redundant for this sync_mode, continue
+                        continue
 
                     # finalize RMA call (if needed)
                     self.tm.register_instruction(inst1_free)
@@ -241,7 +326,7 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
                     self.tm.set_description(
                         ("LocalConcurrency" if hasconflict and not shouldsync else "Correct") +
                         "-"
-                        + sync_mode
+                        + sync_mode.__name__
                         + "-"
                         + op1_name
                         + "_"
@@ -249,6 +334,3 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
                         "full description",
                     )
                     yield self.tm
-
-                    # if generate_level <= BASIC_TEST_LEVEL:
-                    #     return
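The label built in `set_description()` encodes the expected verdict: a case is tagged `GlobalConcurrency`/`LocalConcurrency` only when the two accesses conflict (`hasconflict`) and no synchronization separates them (`not shouldsync`); every other combination is emitted as a `Correct-...` reference case. A small sketch of that naming rule (`describe()` and its arguments are hypothetical helpers, not code from the repository):

```python
def describe(error_kind, sync_mode_name, op1_name, op2_name, hasconflict, shouldsync):
    # conflicting, unsynchronized accesses -> error case; everything else -> correct case
    kind = error_kind if hasconflict and not shouldsync else "Correct"
    return f"{kind}-{sync_mode_name}-{op1_name}_{op2_name}"


# conflicting accesses with no separating synchronization -> real error case
print(describe("LocalConcurrency", "lockunlock", "get", "bufwrite", True, False))
# LocalConcurrency-lockunlock-get_bufwrite

# same accesses, but synchronized in between -> correct reference case
print(describe("LocalConcurrency", "lockunlock", "get", "bufwrite", True, True))
# Correct-lockunlock-get_bufwrite
```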