Skip to content
Snippets Groups Projects
Commit b92c3d05 authored by Oraji, Yussur Mustafa's avatar Oraji, Yussur Mustafa
Browse files

Add generic for loops and if branches

parent 0cff18bb
Branches
No related tags found
1 merge request!25Draft: Fortran Support
#! /usr/bin/python3
from __future__ import annotations
from typing_extensions import override
import re
from Infrastructure.Instruction import Instruction
from Infrastructure.Variables import ERROR_MARKER_COMMENT_BEGIN, ERROR_MARKER_COMMENT_END, ERROR_MARKER_COMMENT_BEGIN_FORT, ERROR_MARKER_COMMENT_END_FORT, adjust_var_language
import Infrastructure.Variables as infvars
for_template_c = "for (int i = @START@; i < @END@; ++i)"
for_template_fort = "do i=@START@, @END@"
if_template_c = "if (@COND@) {"
if_template_fort = "if (@COND@) then"
"""
Class Overview:
The `Branch` class is a prototype for the specific IfBranch/ForBranch helpers for creating if/for branches/loops.
It should not be used directly
Methods:
- `__init__(self)`: Initializes a new branch
- `header(self)`: Returns a string representing the branch header
- `trailer(self)`: Returns a string representing the branch trailer
"""
class Branch:
@override
def __init__(self):
pass
def header(self):
return ""
@staticmethod
def trailer():
return ""
class IfBranch(Branch):
@override
def __init__(self, cond: str):
self._cond = cond
@override
def header(self):
if infvars.generator_language == "c":
return if_template_c.replace("@COND@", self._cond)
else:
actual_cond = self._cond
# Replace array access [*] with (*)
actual_cond = re.sub(r"([A-z0-9]+)\s*\[(.+)\]", r"\1(\2)", actual_cond)
# Replace not-equal comparator
actual_cond = actual_cond.replace("!=", "/=")
return if_template_fort.replace("@COND@", actual_cond)
@override
@staticmethod
def trailer():
if infvars.generator_language == "c":
return "}"
else:
return "end if"
class ForLoop(Branch):
@override
def __init__(self, start: str, end: str):
self._start = start
self._end = end
@override
def header(self):
if infvars.generator_language == "c":
return for_template_c.replace("@START@", str(self._start)).replace("@END@", str(self._end))
else:
return for_template_fort.replace("@START@", str(self._end)).replace("@END@", str(self._end))
@override
@staticmethod
def trailer():
if infvars.generator_language == "c":
return "}"
else:
return "end do"
......@@ -8,6 +8,7 @@ from Infrastructure.AllocCall import AllocCall
from Infrastructure.MPICall import MPICall
from Infrastructure.CorrectParameter import CorrectParameterFactory
from Infrastructure.Variables import ERROR_CLASSES
from Infrastructure.Branches import IfBranch
import Infrastructure.Variables as infvars
deadlock_marker = "\n!\n! This testcase can result in a Deadlock\n!"
......@@ -100,6 +101,7 @@ program main
integer :: double_size
integer :: integer_size
integer :: logical_size
integer :: i ! Loop index used by some tests
@{stack_vars}@
@{mpi_init}@
......@@ -126,17 +128,6 @@ if (@{thread_level}@ < provided)
printf("MBI ERROR: The MPI Implementation does not provide the required thread level!\\n");
"""
def get_if_cond(eq: bool, var: str, val: str):
if infvars.generator_language == "fort":
comp = ".eq." if eq else ".ne."
suffix = "then"
else:
comp = "==" if eq else "!="
suffix = "{"
return f"if ({var} {comp} {val}) {suffix}"
def get_if_postfix():
return "}" if infvars.generator_language == "c" else "end if"
def get_call(func: str):
if infvars.generator_language == "fort":
return f"call {func}(ierr)"
......@@ -256,19 +247,19 @@ class TemplateManager:
alloc_vars_fort.append(f" {inst.get_type()}, pointer :: {inst.get_name()}(:)")
if inst.get_rank_executing() != current_rank:
if current_rank != 'all':
code_string = code_string + get_if_postfix() + "\n"
code_string = code_string + IfBranch.trailer() + "\n"
# end previous if
current_rank = inst.get_rank_executing()
if current_rank == 'not0':
code_string = code_string + get_if_cond(False, "rank", "0") + "\n"
code_string = code_string + IfBranch("rank != 0").header() + "\n"
elif current_rank != 'all':
code_string = code_string + get_if_cond(True, "rank", str(current_rank)) + "\n"
code_string = code_string + IfBranch(f"rank == {current_rank}").header() + "\n"
code_string += str(inst) + "\n"
# end for inst
if current_rank != 'all':
code_string = code_string + get_if_postfix() + "\n" # end previous if
code_string = code_string + IfBranch.trailer() + "\n" # end previous if
init_string = ""
if self._has_init:
......
......@@ -45,6 +45,8 @@ def adjust_var_language(var) -> str:
actual_value = re.sub(r"^\*", r"", actual_value)
# Need to replace % with MODULO
actual_value = re.sub(r"([A-z0-9]+)\s*\%\s*([A-z0-9]+)", r"modulo(\1, \2)", actual_value)
# Replace array access [*] with (*)
actual_value = re.sub(r"([A-z0-9]+)\s*\[(.+)\]", r"\1(\2)", actual_value)
return actual_value
BASIC_TEST_LEVEL = 1
......
......@@ -6,6 +6,7 @@ from Infrastructure.ErrorGenerator import ErrorGenerator
from Infrastructure.MPICallFactory import CorrectMPICallFactory
from Infrastructure.Template import TemplateManager
from Infrastructure.ArrAsgn import ArrAsgn
from Infrastructure.Branches import IfBranch, ForLoop
class MessageRaceErrorAnyTag(ErrorGenerator):
......@@ -26,30 +27,30 @@ class MessageRaceErrorAnyTag(ErrorGenerator):
# send part
tm.register_instruction("for(int i =0; i < 10; ++i) {", rank_to_execute=1)
tm.register_instruction(ForLoop(0, 10).header(), rank_to_execute=1)
tm.register_instruction(ArrAsgn("buf", 0, "i"), rank_to_execute=1)
send_call = CorrectMPICallFactory().mpi_send()
send_call.set_arg("tag", "i")
tm.register_instruction(send_call, rank_to_execute=1)
tm.register_instruction("}", rank_to_execute=1)
tm.register_instruction(ForLoop.trailer(), rank_to_execute=1)
# the final msg after the loop
send_call = CorrectMPICallFactory().mpi_send()
tm.register_instruction(send_call, rank_to_execute=1)
# recv part
tm.register_instruction("for(int i =0; i < 10; ++i) {", rank_to_execute=0)
tm.register_instruction(ForLoop(0, 10).header(), rank_to_execute=0)
recv_call = CorrectMPICallFactory().mpi_recv()
recv_call.set_arg("tag", "MPI_ANY_TAG")
recv_call.set_rank_executing(0)
tm.register_instruction(recv_call)
tm.register_instruction("if(buf[0]!=i){", rank_to_execute=0)
tm.register_instruction(IfBranch("buf[0]!=i").header(), rank_to_execute=0)
additional_recv = CorrectMPICallFactory().mpi_recv()
additional_recv.set_has_error() # additional recv may lead to deadlock
tm.register_instruction(additional_recv,rank_to_execute=0)
tm.register_instruction(" }", rank_to_execute=0) # end if
tm.register_instruction("}", rank_to_execute=0) # end for
tm.register_instruction(IfBranch.trailer(), rank_to_execute=0)
tm.register_instruction(ForLoop.trailer(), rank_to_execute=0)
tm.register_instruction(CorrectParameterFactory().get_buffer_free())
......@@ -81,16 +82,16 @@ class MessageRaceErrorAnysource(ErrorGenerator):
tm.register_instruction(send_call, rank_to_execute=1)
# recv part
tm.register_instruction("for(int i =1; i < nprocs; ++i) {", rank_to_execute=0)
tm.register_instruction(ForLoop(1, "nprocs").header(), rank_to_execute=0)
recv_call = CorrectMPICallFactory().mpi_recv()
recv_call.set_arg("source", "MPI_ANY_SOURCE")
tm.register_instruction(recv_call, rank_to_execute=0)
tm.register_instruction("if(buf[0]!=i){", rank_to_execute=0)
tm.register_instruction(IfBranch("buf[0]!=i").header(), rank_to_execute=0)
additional_recv = CorrectMPICallFactory().mpi_recv()
additional_recv.set_has_error() # additional recv leads to deadlock
tm.register_instruction(additional_recv, rank_to_execute=0)
tm.register_instruction(" }", rank_to_execute=0) # end if
tm.register_instruction("}", rank_to_execute=0) # end for
tm.register_instruction(IfBranch.trailer(), rank_to_execute=0) # end if
tm.register_instruction(ForLoop.trailer(), rank_to_execute=0)
tm.register_instruction(CorrectParameterFactory().get_buffer_free())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment