Skip to content
Snippets Groups Projects
Verified Commit 683de0c2 authored by Simon Schwitanski's avatar Simon Schwitanski :slight_smile:
Browse files

Add initial version of LocalConcurrency RMA generator

parent 17730c7d
No related branches found
No related tags found
2 merge requests!9Infrastructure: Type Hints, Instruction class and lists of instructions,!8Draft: RMA
#! /usr/bin/python3
from scripts.Infrastructure.ErrorGenerator import ErrorGenerator
from scripts.Infrastructure.InstructionBlock import InstructionBlock
from scripts.Infrastructure.MPICallFactory import MPICallFactory, CorrectMPICallFactory
from scripts.Infrastructure.CorrectParameter import (
CorrectParameterFactory,
get_matching_recv,
)
from scripts.Infrastructure.Template import TemplateManager
from scripts.Infrastructure.TemplateFactory import get_allocated_window, get_rma_call
from scripts.Infrastructure.AllocCall import AllocCall
from scripts.Infrastructure.MPICall import MPI_Call
import itertools
from scripts.Infrastructure.Variables import ERROR_MARKER_COMMENT
from typing import Tuple, List
class LocalConcurrencyErrorRMA(ErrorGenerator):
local_origin_addr_write = ["mpi_get", "mpi_rget"]
local_origin_addr_read = [
"mpi_put",
"mpi_rput",
"mpi_accumulate",
"mpi_raccumulate",
"mpi_get_accumulate",
"mpi_rget_accumulate",
"mpi_fetch_and_op",
"mpi_compare_and_swap",
]
functions_to_check = ["mpi_put", "mpi_get", "mpi_rput", "mpi_rget"]
# recv_funcs = ["mpi_irecv", "mpi_recv_init", "mpi_precv_init"]
def __init__(self):
pass
def get_feature(self):
return ["RMA"]
def generate(self, generate_full_set):
cf = CorrectParameterFactory()
cfmpi = CorrectMPICallFactory()
mpi_buf_read = [
get_rma_call("mpi_put", 0),
get_rma_call("mpi_rput", 0),
get_rma_call("mpi_accumulate", 0),
get_rma_call("mpi_raccumulate", 0),
get_rma_call("mpi_get_accumulate", 0),
get_rma_call("mpi_rget_accumulate", 0),
get_rma_call("mpi_fetch_and_op", 0),
get_rma_call("mpi_compare_and_swap", 0),
]
mpi_buf_write = [get_rma_call("mpi_get", 0), get_rma_call("mpi_rget", 0)]
bufread = InstructionBlock("bufread")
bufread.register_operation(f'printf("buf is %d\\n", {cf.buf_var_name}[1]);', 0)
bufwrite = InstructionBlock("write")
bufwrite.register_operation(f'{cf.buf_var_name}[1] = 42;', 0)
# 7 possible combinations of local buffer accesses (hasconflict = True | False)
local_access_combinations: List[Tuple[List[InstructionBlock], List[InstructionBlock], bool]] = [
(mpi_buf_read, [bufread], False),
(mpi_buf_read, [bufwrite], True),
(mpi_buf_write, [bufread], True),
(mpi_buf_write, [bufwrite], True),
(mpi_buf_read, mpi_buf_read, False),
(mpi_buf_read, mpi_buf_write, True),
(mpi_buf_write, mpi_buf_write, True),
]
for ops1, ops2, hasconflict in local_access_combinations:
for (op1, op2) in itertools.product(ops1, ops2):
tm = TemplateManager()
# window allocation boilerplate
b = get_allocated_window("mpi_win_create", "win", "winbuf", "int", "2")
tm.register_instruction_block(b)
# local buffer allocation
alloc = InstructionBlock("alloc")
alloc.register_operation(
AllocCall(cf.dtype[0], cf.buf_size, cf.buf_var_name)
)
tm.register_instruction_block(alloc)
if hasconflict:
if isinstance(op1.get_operation(kind=0, index=-1), MPI_Call):
op1.get_operation(kind=0, index=-1).set_has_error()
else:
op1.replace_operation(op1.get_operation(kind=0, index=-1) + ERROR_MARKER_COMMENT, 0, 0)
if isinstance(op2.get_operation(kind=0, index=-1), MPI_Call):
op2.get_operation(kind=0, index=-1).set_has_error()
else:
op2.replace_operation(op2.get_operation(kind=0, index=-1) + ERROR_MARKER_COMMENT, 0, 0)
# fuse instructions blocks
# combined_ops = InstructionBlock("COMBINED")
# combined_ops.register_operations(op1.get_operations(kind=0), kind=0)
# combined_ops.register_operations(op2.get_operations(kind=0), kind=0)
tm.register_instruction_block(op1)
tm.register_instruction_block(op2)
tm.set_description(
("LocalConcurrency" if hasconflict else "Correct") +
"-"
+ op1.name
+ "_"
+ op2.name,
"full description",
)
yield tm
# get RMA call
# rmaop = get_rma_call(function_to_check, 0)
# tm.register_instruction_block(rmaop)
# bufstring = ""
# if bufop == "read": # local buffer access is read
# bufstring = f'printf("buf is %d\\n", {cf.buf_var_name}[1]);'
# # if RMA call performs local buffer write, this is a race, otherwise no race
# if function_to_check in local_origin_addr_write:
# bufstring += ERROR_MARKER_COMMENT
# # mark RMA call as erroneous
# tm.get_block("RMACALL").get_operation(
# kind=0, index=-1
# ).set_has_error()
# if bufop == "write":
# # a buffer write is always a race
# bufstring = f"{cf.buf_var_name}[1] = 42;" + ERROR_MARKER_COMMENT
# # mark RMA call as erroneous
# tm.get_block("RMACALL").get_operation(
# kind=0, index=-1
# ).set_has_error()
# # finally register buffer access
# tm.get_block("RMACALL").register_operation(bufstring, 0)
# if not generate_full_set:
# return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment