Skip to content
Snippets Groups Projects
Commit 8fb97451 authored by Jammer, Tim's avatar Jammer, Tim
Browse files

added more type hints

parent b3054538
Branches
No related tags found
1 merge request!9Infrastructure: Type Hints, Instruction class and lists of instructions
#! /usr/bin/python3
from scripts.Infrastructure import MPICall
from scripts.Infrastructure.Instruction import Instruction
from scripts.Infrastructure.MPICallFactory import MPICallFactory
from scripts.Infrastructure.Template import InstructionBlock
from scripts.Infrastructure.AllocCall import AllocCall, get_free
......@@ -16,13 +18,13 @@ class CorrectParameterFactory:
def __init__(self):
pass
def get_buffer_alloc(self):
def get_buffer_alloc(self)-> AllocCall:
return AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False)
def get_buffer_free(self):
def get_buffer_free(self)->Instruction:
return get_free(AllocCall(self.dtype[0], self.buf_size, self.buf_var_name, use_malloc=False))
def get(self, param, func=None):
def get(self, param:str)->str:
if param in ["BUFFER", "buf", "buffer", "sendbuf", "recvbuf", "origin_addr"]:
return self.buf_var_name
if param in ["COUNT", "count", "sendcount", "recvcount", "origin_count", "target_count", "result_count"]:
......@@ -96,7 +98,7 @@ class CorrectParameterFactory:
# todo also for send and non default args
def get_matching_recv(call):
def get_matching_recv(call:MPICall)->MPICall:
correct_params = CorrectParameterFactory()
recv = MPICallFactory().mpi_recv(
correct_params.get("BUFFER"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment