# this file implements the postprocessing of prediction results (tubule dilation, hole filling, small area removal, instance extration, ...)

import numpy as np
import torch
import torch.nn as nn
import math
from scipy.ndimage.measurements import label
from scipy.ndimage.morphology import binary_dilation, binary_fill_holes

from utils import getChannelSmootingConvLayer


structure = np.zeros((3, 3), dtype=np.int)
structure[1, :] = 1
structure[:, 1] = 1

# selected colors for label/class visualization
colors = np.array([    [  0,   0,   0], # Black
                       [255,   0,   0], # Red
                       [  0, 128,   0], # Green
                       [  0,   0, 255], # Blue
                       [  0, 255, 255], # Cyan
                       [255,   0, 255], # Magenta
                       [255, 255,   0], # Yellow
                       [139,  69,  19], # Brown (saddlebrown)
                       [128,   0, 128], # Purple
                       [255, 140,   0], # Orange
                       [255, 255, 255]], dtype=np.uint8) # White

# get random color for tubules instances that is not too similar to colors of other classes
def getRandomTubuliColor():
    while(True):
        candidateColor = np.random.randint(low=0, high=256, size=3, dtype=np.uint8)
        if not ((np.abs((candidateColor-colors[0:7])).sum(1)<50).any()):
            return candidateColor

# this method gets postprocessed prediction results as well as the ground-truth label map and extract all instance channels of each 
# label for further performance computation as well as instance visualization, and further applies the last postprocessing step of tubules dilation
# yielding final (instance) results
def extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=True):

    postprocessedPredictionRGB = np.zeros(shape=(preprocessedGT.shape[0], preprocessedGT.shape[1], 3), dtype=np.uint8)
    preprocessedGTrgb = postprocessedPredictionRGB.copy()
    for i in range(2, 7):
        postprocessedPredictionRGB[postprocessedPrediction == i] = colors[i]
        preprocessedGTrgb[preprocessedGT == i] = colors[i]

    labeledTubuli, numberTubuli = label(np.asarray(postprocessedPrediction == 1, np.uint8), structure)
    labeledGlom, _ = label(np.asarray(np.logical_or(postprocessedPrediction == 2, postprocessedPrediction == 3), np.uint8), structure)
    labeledTuft, _ = label(np.asarray(postprocessedPrediction == 3, np.uint8), structure)
    labeledVeins, _ = label(np.asarray(postprocessedPrediction == 4, np.uint8), structure)
    labeledArtery, _ = label(np.asarray(np.logical_or(postprocessedPrediction == 5, postprocessedPrediction == 6), np.uint8), structure)
    labeledArteryLumen, _ = label(np.asarray(postprocessedPrediction == 6, np.uint8), structure)

    for i in range(1, numberTubuli + 1):
        if tubuliDilation:
            tubuliSelection = binary_dilation(labeledTubuli == i)
            labeledTubuli[tubuliSelection] = i
        else:
            tubuliSelection = labeledTubuli == i
        postprocessedPredictionRGB[tubuliSelection] = getRandomTubuliColor()


    labeledTubuliGT, numberTubuliGT = label(np.asarray(preprocessedGT == 1, np.uint8), structure)
    labeledGlomGT, _ = label(np.asarray(np.logical_or(preprocessedGT == 2, preprocessedGT == 3), np.uint8), structure)
    labeledTuftGT, _ = label(np.asarray(preprocessedGT == 3, np.uint8), structure)
    labeledVeinsGT, _ = label(np.asarray(preprocessedGT == 4, np.uint8), structure)
    labeledArteryGT, _ = label(np.asarray(np.logical_or(preprocessedGT == 5, preprocessedGT == 6), np.uint8), structure)
    labeledArteryLumenGT, _ = label(np.asarray(preprocessedGT == 6, np.uint8), structure)

    for i in range(1, numberTubuliGT + 1):
        tubuliSelectionGT = labeledTubuliGT == i
        preprocessedGTrgb[tubuliSelectionGT] = getRandomTubuliColor()


    return [labeledTubuli, labeledGlom, labeledTuft, labeledVeins, labeledArtery, labeledArteryLumen], [labeledTubuliGT, labeledGlomGT, labeledTuftGT, labeledVeinsGT, labeledArteryGT, labeledArteryLumenGT], postprocessedPredictionRGB, preprocessedGTrgb



def postprocessPredictionAndGT(prediction, GT, device, predictionsmoothing, holefilling):
    """
    :param prediction: Torch FloatTensor of size 1xCxHxW stored in VRAM/on GPU
    :param GT: HxW ground-truth label map, numpy long tensor
    :return: 1.postprocessed labelmap result (prediction smoothing, removal of small areas, hole filling)
             2.network output prediction (w/o postprocessing)
    """
    ################# PREDICTION SMOOTHING ################
    if predictionsmoothing:
        smoothingKernel = getChannelSmootingConvLayer(8).to(device)
        prediction = smoothingKernel(prediction)

    # labelMap contains following labels: 0/1/2/3/4/5/6/7 => Background/tubuli/glom_full/glom_tuft/veins/artery_full/artery_lumen/border
    labelMap = torch.argmax(prediction, dim=1).squeeze(0).to("cpu").numpy() # Label 0/1/2/3/4/5/6/7: Background/tubuli/glom_full/glom_tuft/veins/artery_full/artery_lumen/border

    netOutputPrediction = labelMap.copy()

    ################# REMOVING TOO SMALL CONNECTED REGIONS ################
    # Tuft
    labeledTubuli, numberTubuli = label(np.asarray(labelMap == 3, np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 500:  # remove too small noisy regions
            labelMap[tubuliSelection] = 2

    # Glomeruli
    labeledTubuli, numberTubuli = label(np.asarray(np.logical_or(labelMap == 3, labelMap==2), np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 1500:  # remove too small noisy regions
            labelMap[tubuliSelection] = 0

    # Artery lumen
    labeledTubuli, numberTubuli = label(np.asarray(labelMap == 6, np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 20:  # remove too small noisy regions
            labelMap[tubuliSelection] = 5

    # Full artery
    labeledTubuli, numberTubuli = label(np.asarray(np.logical_or(labelMap == 5, labelMap==6), np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 400:  # remove too small noisy regions
            labelMap[tubuliSelection] = 0

    # Veins
    labeledTubuli, numberTubuli = label(np.asarray(labelMap == 4, np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 3000:  # remove too small noisy regions
            labelMap[tubuliSelection] = 0

    # Tubuli
    labeledTubuli, numberTubuli = label(np.asarray(labelMap == 1, np.uint8), structure)  # datatype of 'labeledTubuli': int32
    for i in range(1, numberTubuli + 1):
        tubuliSelection = (labeledTubuli == i)
        if tubuliSelection.sum() < 400:  # remove too small noisy regions
            labelMap[tubuliSelection] = 0


    ################# HOLE FILLING ################
    if holefilling:
        labelMap[binary_fill_holes(labelMap==1)] = 1 #tubuli
        labelMap[binary_fill_holes(labelMap==4)] = 4 #veins
        tempTuftMask = binary_fill_holes(labelMap==3) #tuft
        labelMap[binary_fill_holes(np.logical_or(labelMap==3, labelMap==2))] = 2 #glom
        labelMap[tempTuftMask] = 3 #tuft
        tempArteryLumenMask = binary_fill_holes(labelMap == 6)  #artery_lumen
        labelMap[binary_fill_holes(np.logical_or(labelMap == 5, labelMap == 6))] = 5  #full_artery
        labelMap[tempArteryLumenMask] = 6  #artery_lumen


    return labelMap, netOutputPrediction, GT