Skip to content
Snippets Groups Projects
Verified Commit d396c711 authored by Simon Schwitanski's avatar Simon Schwitanski :slight_smile:
Browse files

Improve distribution of sync patterns in RMA race tests

parent b124dffe
No related branches found
No related tags found
1 merge request!20Parsing and tools updates
...@@ -126,7 +126,7 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator): ...@@ -126,7 +126,7 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
return True return True
def request(self, tm: TemplateManager, alloc_inst: Instruction, alloc1: List[Instruction], op1: Instruction, alloc2: List[Instruction], op2: Instruction, shouldsync: bool): def rmarequest(self, tm: TemplateManager, alloc_inst: Instruction, alloc1: List[Instruction], op1: Instruction, alloc2: List[Instruction], op2: Instruction, shouldsync: bool):
# only consider combination where the first operation is a request-based RMA call # only consider combination where the first operation is a request-based RMA call
if not isinstance(op1, MPICall) or not op1.has_arg("request"): if not isinstance(op1, MPICall) or not op1.has_arg("request"):
return False return False
...@@ -240,12 +240,11 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator): ...@@ -240,12 +240,11 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
(remote_atomic_update, ["bufwrite", "localbufwrite"], True), (remote_atomic_update, ["bufwrite", "localbufwrite"], True),
] ]
sync_modes = [self.fence, self.lockall, self.lock, self.request] sync_modes = [self.fence, self.lockall, self.lock, self.rmarequest]
if generate_level <= SUFFICIENT_TEST_LEVEL: if generate_level <= SUFFICIENT_TEST_LEVEL:
# go through all sync modes, but only one access combination per sync mode, fill with fence # go through all sync modes, but only one access combination per sync mode
combos = itertools.zip_longest( combos = [(comb, sync_modes[i % len(sync_modes)]) for (i, comb) in enumerate(remote_access_combinations)]
remote_access_combinations, sync_modes, fillvalue=self.fence)
else: else:
# combine everything (= nested for loop) # combine everything (= nested for loop)
combos = itertools.product(remote_access_combinations, sync_modes) combos = itertools.product(remote_access_combinations, sync_modes)
......
...@@ -179,7 +179,7 @@ class LocalConcurrencyErrorRMA(ErrorGenerator): ...@@ -179,7 +179,7 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
return True return True
def request(self, tm: TemplateManager, alloc_inst: Instruction, alloc1: List[Instruction], op1: Instruction, alloc2: List[Instruction], op2: Instruction, shouldsync: bool): def rmarequest(self, tm: TemplateManager, alloc_inst: Instruction, alloc1: List[Instruction], op1: Instruction, alloc2: List[Instruction], op2: Instruction, shouldsync: bool):
# only consider combination where the first operation is a request-based RMA call # only consider combination where the first operation is a request-based RMA call
if not isinstance(op1, MPICall) or not op1.has_arg("request"): if not isinstance(op1, MPICall) or not op1.has_arg("request"):
return False return False
...@@ -288,12 +288,11 @@ class LocalConcurrencyErrorRMA(ErrorGenerator): ...@@ -288,12 +288,11 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
(local_origin_addr_write, local_origin_addr_write, True), (local_origin_addr_write, local_origin_addr_write, True),
] ]
sync_modes = [self.fence, self.lockallflush, self.lockallflushlocal, self.lockflush, self.lockflushlocal, self.lockunlock, self.request] sync_modes = [self.fence, self.lockallflush, self.lockallflushlocal, self.lockflush, self.lockflushlocal, self.lockunlock, self.rmarequest]
if generate_level <= SUFFICIENT_TEST_LEVEL: if generate_level <= SUFFICIENT_TEST_LEVEL:
# go through all sync modes, but only one access combination per sync mode, fill with fence # go through all sync modes, but only one access combination per sync mode
combos = itertools.zip_longest( combos = [(comb, sync_modes[i % len(sync_modes)]) for (i, comb) in enumerate(local_access_combinations)]
local_access_combinations, sync_modes, fillvalue=self.fence)
else: else:
# combine everything (= nested for loop) # combine everything (= nested for loop)
combos = itertools.product(local_access_combinations, sync_modes) combos = itertools.product(local_access_combinations, sync_modes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment