Skip to content
Snippets Groups Projects
Commit 20a74b67 authored by Oraji, Yussur Mustafa's avatar Oraji, Yussur Mustafa
Browse files

Port ComplexMismatch

parent b92c3d05
No related branches found
No related tags found
1 merge request!25Draft: Fortran Support
......@@ -38,20 +38,17 @@ class Branch:
class IfBranch(Branch):
@override
def __init__(self, cond: str):
def __init__(self, cond: str, elseif: bool = False):
self._cond = cond
self._elseif = elseif
@override
def header(self):
if infvars.generator_language == "c":
return if_template_c.replace("@COND@", self._cond)
res = if_template_c.replace("@COND@", self._cond)
else:
actual_cond = self._cond
# Replace array access [*] with (*)
actual_cond = re.sub(r"([A-z0-9]+)\s*\[(.+)\]", r"\1(\2)", actual_cond)
# Replace not-equal comparator
actual_cond = actual_cond.replace("!=", "/=")
return if_template_fort.replace("@COND@", actual_cond)
res = if_template_fort.replace("@COND@", adjust_var_language(self._cond)).replace("!=", "/=")
return res if not self._elseif else f"else {res}"
@override
@staticmethod
......
#! /usr/bin/python3
from __future__ import annotations
from typing_extensions import override
import re
from Infrastructure.Instruction import Instruction
from Infrastructure.Variables import ERROR_MARKER_COMMENT_BEGIN, ERROR_MARKER_COMMENT_END, ERROR_MARKER_COMMENT_BEGIN_FORT, ERROR_MARKER_COMMENT_END_FORT, adjust_var_language
import Infrastructure.Variables as infvars
print_template_c = "printf(\"@STRING@\"@ARGS@);"
print_template_fort = "print *, \"@STRING@\"@ARGS@"
"""
Class Overview:
The `PrintInst` class is a helper for creating print instructions
Methods:
- `__init__(self)`: Initializes a new print instruction
- `__str__(self)`: Converts the print instance to a string, replacing placeholders.
"""
class PrintInst(Instruction):
@override
def __init__(self, string: str, args: List[str] = []):
"""
Creates a new print instruction
Args:
string: String to print
args: List of variables to print (postfix)
"""
super().__init__("")
self._string = string
self._args = args
@override
def __str__(self):
actual_template = print_template_c if infvars.generator_language == "c" else print_template_fort
actual_string = self._string
arg_str = ""
for arg in self._args:
arg_str = ", " + adjust_var_language(arg)
if infvars.generator_language == "c":
actual_string += " %d"
if infvars.generator_language == "c":
actual_string += "\\n"
result = actual_template.replace("@STRING@", actual_string).replace("@ARGS@", arg_str)
if infvars.generator_language == "c":
error_begin = ERROR_MARKER_COMMENT_BEGIN
error_end = ERROR_MARKER_COMMENT_END
else:
error_begin = ERROR_MARKER_COMMENT_BEGIN_FORT
error_end = ERROR_MARKER_COMMENT_END_FORT
if self.has_error():
result = error_begin + result + error_end
return result
......@@ -5,6 +5,10 @@ from Infrastructure.Variables import *
from Infrastructure.ErrorGenerator import ErrorGenerator
from Infrastructure.Template import TemplateManager
from Infrastructure.TemplateFactory import get_send_recv_template
from Infrastructure.AllocCall import AllocCall
from Infrastructure.MPICallFactory import MPICallFactory
from Infrastructure.Branches import IfBranch, ForLoop
from Infrastructure.PrintInst import PrintInst
class UnmatchedP2Pcall(ErrorGenerator):
......@@ -79,7 +83,6 @@ class ComplexMissmach(ErrorGenerator):
def generate(self, generate_level, real_world_score_table):
tm = TemplateManager()
tm.add_stack_variable("int", "i")
tm.add_stack_variable("MPI_Request", "request")
tm.add_stack_variable("MPI_Status", "status")
tm.add_stack_variable("int", "countEvenNumbers")
......@@ -89,40 +92,37 @@ class ComplexMissmach(ErrorGenerator):
tm.register_instruction(Instruction("#define MSG_TAG_A 124523"))
tm.register_instruction(Instruction("#define N 10"))
tm.register_instruction(Instruction("#define EVEN 0"))
code = """
int buffer[N];
for (i = 0; i < 10; i++) {
if (rank == 0) {
tag_sender = i * N;
MPI_Isend(buffer, 1, MPI_INT, 1, tag_sender, MPI_COMM_WORLD, &request);
MPI_Wait(&request, &status);
}
else if (rank == 1) {
tag_receiver = i * N;
if (i % 2 == EVEN) {
(countEvenNumbers)++;
}
if ((countEvenNumbers) == (N / 2)) {
tag_receiver++; // mismatch
}
printf(\"Count Even Numbers: %d \\n\", countEvenNumbers);
MPI_Irecv(buffer, 1, MPI_INT, 0, tag_receiver, MPI_COMM_WORLD, &request);
MPI_Wait(&request, &status);
}
}
"""
i = Instruction(code)
i.set_has_error()
tm.register_instruction(i)
tm.set_can_deadlock()
tm.register_instruction(AllocCall("int", "N", "buffer"))
tm.register_instruction(ForLoop(0, 10).header())
tm.register_instruction(IfBranch("rank == 0").header())
tm.register_instruction("tag_sender = i*N")
tm.register_instruction(MPICallFactory.mpi_isend("buffer", 1, "MPI_INT", 1, "tag_sender", "MPI_COMM_WORLD", "&request"))
tm.register_instruction(MPICallFactory.mpi_wait("&request", "&status"))
tm.register_instruction(IfBranch("rank == 1", elseif=True).header())
#TODO SET_HAS_ERROR
tm.register_instruction("tag_receiver = i*N")
tm.register_instruction(IfBranch("i % 2 == EVEN").header())
tm.register_instruction("countEvenNumbers = countEvenNumbers + 1")
tm.register_instruction(IfBranch.trailer())
tm.register_instruction(IfBranch("(countEvenNumbers) == (N / 2)").header())
tm.register_instruction("tag_receiver = tag_receiver + 1")
tm.register_instruction(IfBranch.trailer())
tm.register_instruction(PrintInst("Count Even Numbers: ", ["countEvenNumbers"]))
recv_call = MPICallFactory.mpi_irecv("buffer", 1, "MPI_INT", 0, "tag_receiver", "MPI_COMM_WORLD", "&request")
recv_call.set_has_error()
tm.register_instruction(recv_call)
tm.register_instruction(MPICallFactory.mpi_wait("&request", "&status"))
tm.register_instruction(IfBranch.trailer())
tm.register_instruction(ForLoop.trailer())
tm.set_can_deadlock()
tm.set_description("GlobalParameterMissmatch-tag-mpi_send", "Missmatching message tags in iteration 10")
yield tm
......@@ -12,7 +12,7 @@ from Infrastructure.Template import TemplateManager
from Infrastructure.TemplateFactory import get_allocated_window, get_rma_call
from Infrastructure.AllocCall import AllocCall
from Infrastructure.ArrAsgn import ArrAsgn
import Infrastructure.Variables as infvars
from Infrastructure.PrintInst import PrintInst
import itertools
......@@ -40,12 +40,8 @@ class GlobalConcurrencyErrorRMA(ErrorGenerator):
localbufread.set_arg("target_count", "1")
localbufread.set_arg("target_rank", "0")
if infvars.generator_language == "c":
read_str = f'printf("winbuf is %d\\n", {CorrectParameterFactory().winbuf_var_name}[1]);'
else:
read_str = f'print *, "winbuf is ", {CorrectParameterFactory().winbuf_var_name}(1)'
self.buf_instructions = {
"bufread": Instruction(read_str, 1, "bufread"),
"bufread": PrintInst("winbuf is ", [f"{CorrectParameterFactory().winbuf_var_name}[1]"]),
"bufwrite": ArrAsgn(CorrectParameterFactory().winbuf_var_name, 1, 42, rank=1, identifier="bufwrite"),
"localbufread": localbufread,
"localbufwrite": localbufwrite
......
......@@ -12,7 +12,7 @@ from Infrastructure.Template import TemplateManager
from Infrastructure.TemplateFactory import get_allocated_window, get_rma_call
from Infrastructure.AllocCall import AllocCall
from Infrastructure.ArrAsgn import ArrAsgn
import Infrastructure.Variables as infvars
from Infrastructure.PrintInst import PrintInst
import itertools
......@@ -23,12 +23,8 @@ class LocalConcurrencyErrorRMA(ErrorGenerator):
def __init__(self):
self.cfmpi = CorrectMPICallFactory()
# generate standard buffer access instructions
if infvars.generator_language == "c":
read_str = f'printf("buf is %d\\n", {CorrectParameterFactory().buf_var_name}[0]);'
else:
read_str = f'print *, "buf is ", {CorrectParameterFactory().buf_var_name}(0)'
self.buf_instructions = {
"bufread": Instruction(read_str, 0, "bufread"),
"bufread": PrintInst("buf is ", [f"{CorrectParameterFactory().buf_var_name}[0]"]),
"bufwrite": ArrAsgn(CorrectParameterFactory().buf_var_name, 0, 42)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment