Skip to content
Snippets Groups Projects

RMA Test Cases

Open Simon Schwitanski requested to merge rma into main
+ 140
0
 
#! /usr/bin/python3
 
 
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
 
from scripts.Infrastructure.InstructionBlock import InstructionBlock
 
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
 
from scripts.Infrastructure.CorrectParameter import (
 
CorrectParameterFactory,
 
get_matching_recv,
 
)
 
from scripts.Infrastructure.Template import TemplateManager
 
from scripts.Infrastructure.TemplateFactory import get_allocated_window, get_rma_call
 
from scripts.Infrastructure.AllocCall import AllocCall
 
from scripts.Infrastructure.MPICall import MPI_Call
 
 
import itertools
 
 
from scripts.Infrastructure.Variables import ERROR_MARKER_COMMENT
 
 
from typing import Tuple, List
 
 
 
class LocalConcurrencyErrorRMA(ErrorGenerator):
 
local_origin_addr_write = ["mpi_get", "mpi_rget"]
 
local_origin_addr_read = [
 
"mpi_put",
 
"mpi_rput",
 
"mpi_accumulate",
 
"mpi_raccumulate",
 
"mpi_get_accumulate",
 
"mpi_rget_accumulate",
 
"mpi_fetch_and_op",
 
"mpi_compare_and_swap",
 
]
 
functions_to_check = ["mpi_put", "mpi_get", "mpi_rput", "mpi_rget"]
 
 
# recv_funcs = ["mpi_irecv", "mpi_recv_init", "mpi_precv_init"]
 
 
def __init__(self):
 
pass
 
 
def get_feature(self):
 
return ["RMA"]
 
 
def generate(self, generate_full_set):
 
 
cf = CorrectParameterFactory()
 
cfmpi = CorrectMPICallFactory()
 
 
mpi_buf_read = [
 
get_rma_call("mpi_put", 0),
 
get_rma_call("mpi_rput", 0),
 
get_rma_call("mpi_accumulate", 0),
 
get_rma_call("mpi_raccumulate", 0),
 
get_rma_call("mpi_get_accumulate", 0),
 
get_rma_call("mpi_rget_accumulate", 0),
 
get_rma_call("mpi_fetch_and_op", 0),
 
get_rma_call("mpi_compare_and_swap", 0),
 
]
 
mpi_buf_write = [get_rma_call("mpi_get", 0), get_rma_call("mpi_rget", 0)]
 
 
bufread = InstructionBlock("bufread")
 
bufread.register_instruction(f'printf("buf is %d\\n", {cf.buf_var_name}[1]);', 0)
 
bufwrite = InstructionBlock("write")
 
bufwrite.register_instruction(f'{cf.buf_var_name}[1] = 42;', 0)
 
 
# 7 possible combinations of local buffer accesses (hasconflict = True | False)
 
local_access_combinations: List[Tuple[List[InstructionBlock], List[InstructionBlock], bool]] = [
 
(mpi_buf_read, [bufread], False),
 
(mpi_buf_read, [bufwrite], True),
 
(mpi_buf_write, [bufread], True),
 
(mpi_buf_write, [bufwrite], True),
 
(mpi_buf_read, mpi_buf_read, False),
 
(mpi_buf_read, mpi_buf_write, True),
 
(mpi_buf_write, mpi_buf_write, True),
 
]
 
 
for ops1, ops2, hasconflict in local_access_combinations:
 
for (op1, op2) in itertools.product(ops1, ops2):
 
tm = TemplateManager()
 
# window allocation boilerplate
 
b = get_allocated_window("mpi_win_create", "win", "winbuf", "int", "2")
 
tm.register_instruction_block(b)
 
 
# local buffer allocation
 
alloc = InstructionBlock("alloc")
 
alloc.register_instruction(
 
AllocCall(cf.dtype[0], cf.buf_size, cf.buf_var_name)
 
)
 
tm.register_instruction_block(alloc)
 
 
if hasconflict:
 
op1.get_instruction(kind=0, index=-1).set_has_error()
 
op2.get_instruction(kind=0, index=-1).set_has_error()
 
 
# fuse instructions blocks
 
# combined_ops = InstructionBlock("COMBINED")
 
# combined_ops.register_operations(op1.get_operations(kind=0), kind=0)
 
# combined_ops.register_operations(op2.get_operations(kind=0), kind=0)
 
tm.register_instruction_block(op1)
 
tm.register_instruction_block(op2)
 
 
tm.set_description(
 
("LocalConcurrency" if hasconflict else "Correct") +
 
"-"
 
+ op1.name
 
+ "_"
 
+ op2.name,
 
"full description",
 
)
 
yield tm
 
 
# get RMA call
 
# rmaop = get_rma_call(function_to_check, 0)
 
 
# tm.register_instruction_block(rmaop)
 
 
# bufstring = ""
 
# if bufop == "read": # local buffer access is read
 
# bufstring = f'printf("buf is %d\\n", {cf.buf_var_name}[1]);'
 
# # if RMA call performs local buffer write, this is a race, otherwise no race
 
# if function_to_check in local_origin_addr_write:
 
# bufstring += ERROR_MARKER_COMMENT
 
# # mark RMA call as erroneous
 
# tm.get_block("RMACALL").get_operation(
 
# kind=0, index=-1
 
# ).set_has_error()
 
 
# if bufop == "write":
 
# # a buffer write is always a race
 
# bufstring = f"{cf.buf_var_name}[1] = 42;" + ERROR_MARKER_COMMENT
 
# # mark RMA call as erroneous
 
# tm.get_block("RMACALL").get_operation(
 
# kind=0, index=-1
 
# ).set_has_error()
 
 
# # finally register buffer access
 
# tm.get_block("RMACALL").register_operation(bufstring, 0)
 
 
# if not generate_full_set:
 
# return
Loading