From b63d39e7605130f70686f14104871e6627f4749d Mon Sep 17 00:00:00 2001
From: Nassim Bouteldja <nbouteldja@ukaachen.de>
Date: Wed, 13 Apr 2022 21:38:06 +0200
Subject: [PATCH] Upload New File

---
 utils.py | 701 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 701 insertions(+)
 create mode 100644 utils.py

diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..693c847
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,701 @@
+#### THIS UTILITY PYTHON SCRIPT CONTAINS LOTS OF HELPFUL FUNCTIONS ####
+
+import numpy as np
+import torch
+from subprocess import check_output
+import os
+import psutil
+import math
+from scipy.ndimage import zoom
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+import torch.nn as nn
+
+from scipy.ndimage.measurements import label
+from scipy.ndimage.morphology import binary_dilation, binary_fill_holes
+
+
+colors = torch.tensor([[  0,   0,   0], # Black
+                       [255,   0,   0], # Red
+                       [  0, 255,   0], # Green
+                       [  0,   0, 255], # Blue
+                       [  0, 255, 255], # Cyan
+                       [255,   0, 255], # Magenta
+                       [255, 255,   0], # Yellow
+                       [139,  69,  19], # Brown (saddlebrown)
+                       [128,   0, 128], # Purple
+                       [255, 140,   0], # Orange
+                       [255, 255, 255]], dtype=torch.uint8) # White
+
+# Takes a binary mask image and outputs its bounding box
+def getBoundingBox(img):
+    rows = np.any(img, axis=1)
+    cols = np.any(img, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    return rmin, rmax, cmin, cmax
+
+# Generates a 2d ball of a specified radius representing a structuring element for morphological operations
+def generate_ball(radius):
+    structure = np.zeros((3, 3), dtype=np.int)
+    structure[1, :] = 1
+    structure[:, 1] = 1
+
+    ball = np.zeros((radius * 2 + 1, radius * 2 + 1), dtype=np.uint8)
+    ball[radius, radius] = 1
+    for i in range(radius):
+        ball = binary_dilation(ball, structure=structure)
+    return np.asarray(ball, dtype=np.int)
+
+
+def convert_labelmap_to_rgb(labelmap):
+    """
+    Method used to generate rgb label maps for tensorboard visualization
+    :param labelmap: HxW label map tensor containing values from 0 to n_classes
+    :return: 3xHxW RGB label map containing colors in the following order: Black (background), Red, Green, Blue, Cyan, Magenta, Yellow, Brown, Orange, Purple
+    """
+    n_classes = labelmap.max()
+
+    result = torch.zeros(size=(labelmap.size()[0], labelmap.size()[1], 3), dtype=torch.uint8)
+    for i in range(1, n_classes+1):
+        result[labelmap == i] = colors[i]
+
+    return result.permute(2, 0, 1)
+
+def convert_labelmap_to_rgb_with_instance_first_class(labelmap):
+    """
+    Method used to generate rgb label maps for tensorboard visualization
+    :param labelmap: HxW label map tensor containing values from 0 to n_classes
+    :return: 3xHxW RGB label map containing colors in the following order: Black (background), Red, Green, Blue, Cyan, Magenta, Yellow, Brown, Orange, Purple
+    """
+    n_classes = labelmap.max()
+
+    result = np.zeros(shape=(labelmap.shape[0], labelmap.shape[1], 3), dtype=np.uint8)
+    for i in range(2, n_classes+1):
+        result[labelmap == i] = colors[i].numpy()
+
+    structure = np.ones((3, 3), dtype=np.int)
+
+    labeledTubuli, numberTubuli = label(np.asarray(labelmap == 1, np.uint8))  # datatype of 'labeledTubuli': int32
+    for i in range(1, numberTubuli + 1):
+        result[binary_dilation(binary_dilation(binary_dilation(labeledTubuli == i)))] = np.random.randint(low=0, high=256, size=3, dtype=np.uint8)  # assign random colors to tubuli
+
+    return result
+
+def convert_labelmap_to_rgb_except_first_class(labelmap):
+    """
+    Method used to generate rgb label maps for tensorboard visualization
+    :param labelmap: HxW label map tensor containing values from 0 to n_classes
+    :return: 3xHxW RGB label map containing colors in the following order: Black (background), Red, Green, Blue, Cyan, Magenta, Yellow, Brown, Orange, Purple
+    """
+    n_classes = labelmap.max()
+
+    result = torch.zeros(size=(labelmap.size()[0], labelmap.size()[1], 3), dtype=torch.uint8)
+    for i in range(2, n_classes+1):
+        result[labelmap == i] = colors[i]
+
+    return result.permute(2, 0, 1)
+
+
+def getColorMapForLabelMap():
+    return ['black', 'red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'brown', 'purple', 'orange', 'white','red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'brown', 'purple', 'orange', 'white','red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'brown', 'purple', 'orange', 'white','red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'brown', 'purple', 'orange', 'white']
+
+# saves final results in a figure containing 2 rows and 4 coloums: From first to last: 1. Image, 2. network prediction, 
+# 3. postprocessed prediction, 4. prediction on instance-level for tubules, 5. overlay of images and postprocessed prediction
+# 6. Ground-Truth, 7. preprocessed Ground-Truth, 8. preprocessed Ground-Truth on instance-level for tubules
+def saveFigureResults(img, outputPrediction, postprocessedPrediction, finalPredictionRGB, GT, preprocessedGT, preprocessedGTrgb, fullResultPath, alpha=0.4):
+    customColors = getColorMapForLabelMap()
+    max_number_of_labels = len(customColors)
+    assert outputPrediction.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(getColorMapForLabelMap())
+
+    # avoid brown color (border visualization) in output for final GT and prediction
+    postprocessedPrediction[postprocessedPrediction==7] = 0
+    preprocessedGT[preprocessedGT==7] = 0
+
+    predictionMask = np.ma.masked_where(postprocessedPrediction == 0, postprocessedPrediction)
+
+    plt.figure(figsize=(16, 8.1))
+    plt.subplot(241)
+    plt.imshow(img)
+    plt.axis('off')
+    plt.subplot(242)
+    plt.imshow(outputPrediction, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.axis('off')
+    plt.subplot(243)
+    plt.imshow(postprocessedPrediction, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.axis('off')
+    plt.subplot(244)
+    plt.imshow(finalPredictionRGB)
+    plt.axis('off')
+    plt.subplot(245)
+    plt.imshow(img[(img.shape[0]-outputPrediction.shape[0])//2:(img.shape[0]-outputPrediction.shape[0])//2+outputPrediction.shape[0],(img.shape[1]-outputPrediction.shape[1])//2:(img.shape[1]-outputPrediction.shape[1])//2+outputPrediction.shape[1],:])
+    plt.imshow(predictionMask, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1, alpha = alpha)
+    plt.axis('off')
+    plt.subplot(246)
+    plt.imshow(GT, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.axis('off')
+    plt.subplot(247)
+    plt.imshow(preprocessedGT, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.axis('off')
+    plt.subplot(248)
+    plt.imshow(preprocessedGTrgb)
+    plt.axis('off')
+
+    plt.subplots_adjust(wspace=0, hspace=0)
+
+    plt.savefig(fullResultPath)
+    plt.close()
+
+# Visualizes prediction after applying tubules dilation
+def savePredictionResults(predictionWithoutTubuliDilation, fullResultPath, figSize):
+    prediction = predictionWithoutTubuliDilation.copy()
+    prediction[binary_dilation(binary_dilation(binary_dilation(binary_dilation(prediction == 1))))] = 1
+
+    customColors = getColorMapForLabelMap()
+    max_number_of_labels = len(customColors)
+    assert prediction.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(getColorMapForLabelMap())
+
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(prediction, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+# Visualizes prediction without applying tubules dilation
+def savePredictionResultsWithoutDilation(prediction, fullResultPath, figSize):
+    customColors = getColorMapForLabelMap()
+    max_number_of_labels = len(customColors)
+    assert prediction.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(getColorMapForLabelMap())
+
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(prediction, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+# Visualizes prediction and image overlay after dilating tubules
+def savePredictionOverlayResults(img, prediction, fullResultPath, figSize, alpha=0.4):
+    predictionMask = np.ma.masked_where(prediction == 0, prediction)
+
+    colorMap = np.array([[0, 0, 0],  # Black
+                         [255, 0, 0],  # Red
+                         [0, 255, 0],  # Green
+                         [0, 0, 255],  # Blue
+                         [0, 255, 255],  # Cyan
+                         [255, 0, 255],  # Magenta
+                         [255, 255, 0],  # Yellow
+                         [139, 69, 19],  # Brown (saddlebrown)
+                         [128, 0, 128],  # Purple
+                         [255, 140, 0],  # Orange
+                         [255, 255, 255]], dtype=np.uint8)  # White
+
+    newRandomColors = np.random.randint(low=0, high=256, dtype=np.uint8, size=(prediction.max(), 3))
+    colorMap = np.concatenate((colorMap, newRandomColors))
+    colorMap = colorMap / 255.
+
+    max_number_of_labels = len(colorMap)
+    assert prediction.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(colorMap)
+
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(img)
+    ax.imshow(predictionMask, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1, alpha = alpha)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+def saveImage(img, fullResultPath, figSize):
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(img)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+# Visualizes prediction and image overlay efficiently by strongly downsampling both
+def savePredictionOverlayResults_Fast(img, prediction, fullResultPath, figHeight, alpha=0.4):
+    downscaleFactor = 5.
+    img = np.asarray(zoom(img, zoom=(1 / downscaleFactor, 1 / downscaleFactor, 1), order=0), img.dtype)
+    prediction = np.asarray(zoom(prediction, 1 / downscaleFactor, order=0), prediction.dtype)
+
+    predictionMask = np.ma.masked_where(prediction == 0, prediction)
+
+    colorMap = np.array([[0, 0, 0],  # Black
+                         [255, 0, 0],  # Red
+                         [0, 255, 0],  # Green
+                         [0, 0, 255],  # Blue
+                         [0, 255, 255],  # Cyan
+                         [255, 0, 255],  # Magenta
+                         [255, 255, 0],  # Yellow
+                         [139, 69, 19],  # Brown (saddlebrown)
+                         [128, 0, 128],  # Purple
+                         [255, 140, 0],  # Orange
+                         [255, 255, 255]], dtype=np.uint8)  # White
+
+    newRandomColors = np.random.randint(low=0, high=256, dtype=np.uint8, size=(int(prediction.max()), 3))
+    colorMap = np.concatenate((colorMap, newRandomColors))
+    colorMap = colorMap / 255.
+
+    max_number_of_labels = len(colorMap)
+    assert prediction.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(colorMap)
+
+    fig = plt.figure(figsize=(figHeight*prediction.shape[1]/prediction.shape[0], figHeight))
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(img)
+    ax.imshow(predictionMask, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1, alpha = alpha)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+def saveOverlayResults(img, seg, fullResultPath, figHeight, alpha=0.4):
+    segMask = np.ma.masked_where(seg == 0, seg)
+
+    customColors = getColorMapForLabelMap()
+    max_number_of_labels = len(customColors)
+    assert seg.max() < max_number_of_labels, 'Too many labels -> Not enough colors available in custom colormap! Add some colors!'
+    customColorMap = mpl.colors.ListedColormap(getColorMapForLabelMap())
+
+    fig = plt.figure(figsize=(figHeight*seg.shape[1]/seg.shape[0], figHeight))
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(img)
+    ax.imshow(segMask, cmap=customColorMap, vmin = 0, vmax = max_number_of_labels-1, alpha = alpha)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+# Visualizes prediction and image RGB overlay
+def saveRGBPredictionOverlayResults(img, prediction, fullResultPath, figSize, alpha=0.4):
+    predictionMask = prediction.sum(2)==0
+    predictionCopy = prediction.copy()
+    predictionCopy[predictionMask] = img[predictionMask]
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(np.asarray(np.round(predictionCopy*alpha+(1-alpha)*img), np.uint8))
+    plt.savefig(fullResultPath)
+    plt.close()
+
+def saveImage(img, fullResultPath, figSize):
+    fig = plt.figure(figsize=figSize)
+    ax = plt.Axes(fig, [0., 0., 1., 1., ])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(img)
+    plt.savefig(fullResultPath)
+    plt.close()
+
+
+
+def getCrossValSplits(dataIDX, amountFolds, foldNo, setting):
+    """
+    Cross-Validation-Split of indices according to fold number and setting
+    Usage:
+        dataIDX = np.arange(dataset.__len__())
+        # np.random.shuffle(dataIDX)
+        for i in range(amountFolds):
+            train_idx, val_idx, test_idx = getCrossFoldSplits(dataIDX=dataIDX, amountFolds=amountFolds, foldNo=i+1, setting=setting)
+    :param dataIDX: All data indices stored in numpy array
+    :param amountFolds: Total amount of folds
+    :param foldNo: Fold number, # CARE: Fold numbers start with 1 and go up to amountFolds ! #
+    :param setting: Train / Train+Test / Train+Val / Train+Test+Val
+    :return: tuple consisting of 3 numpy arrays (trainIDX, valIDX, testIDX) containing indices according to split
+    """
+    assert (setting in ['train_val_test', 'train_test', 'train_val', 'train']), 'Given setting >'+setting+'< is incorrect!'
+
+    num_total_data = dataIDX.__len__()
+
+    if setting == 'train':
+        return dataIDX, None, None
+
+    elif setting == 'train_val':
+        valIDX = dataIDX[num_total_data * (foldNo - 1) // amountFolds: num_total_data * foldNo // amountFolds]
+        trainIDX = np.setxor1d(dataIDX, valIDX)
+        return trainIDX, valIDX, None
+
+    elif setting == 'train_test':
+        testIDX = dataIDX[num_total_data * (foldNo - 1) // amountFolds: num_total_data * foldNo // amountFolds]
+        trainIDX = np.setxor1d(dataIDX, testIDX)
+        return trainIDX, None, testIDX
+
+    elif setting == 'train_val_test':
+        testIDX = dataIDX[num_total_data * (foldNo - 1) // amountFolds: num_total_data * foldNo // amountFolds]
+        if foldNo != amountFolds:
+            valIDX = dataIDX[num_total_data * foldNo // amountFolds: num_total_data * (foldNo+1) // amountFolds]
+        else:
+            valIDX = dataIDX[0 : num_total_data // amountFolds]
+        trainIDX = np.setxor1d(np.setxor1d(dataIDX, testIDX), valIDX)
+        return trainIDX, valIDX, testIDX
+
+    else:
+        raise ValueError('Given setting >'+str(setting)+'< is invalid!')
+
+
+def parse_nvidia_smi(unit=0):
+    result = check_output(["nvidia-smi", "-i", str(unit),]).decode('utf-8').split('\n')
+    return 'Current GPU usage: ' + result[0] + '\r\n' + result[5] + '\r\n' + result[8]
+
+
+def parse_RAM_info():
+    return 'Current RAM usage: '+str(round(psutil.Process(os.getpid()).memory_info().rss / 1E6, 2))+' MB'
+
+
+def countParam(model):
+    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+    params = sum([np.prod(p.size()) for p in model_parameters])
+    return params
+
+
+def getOneHotEncoding(imgBatch, labelBatch):
+    """
+    :param imgBatch: image minibatch (FloatTensor) to extract shape and device info for output
+    :param labelBatch: label minibatch (LongTensor) to be converted to one-hot encoding
+    :return: One-hot encoded label minibatch with equal size as imgBatch and stored on same device
+    """
+    if imgBatch.size()[1] != 1: # Multi-label segmentation otherwise binary segmentation
+        labelBatch = labelBatch.unsqueeze(1)
+        onehotEncoding = torch.zeros_like(imgBatch)
+        onehotEncoding.scatter_(1, labelBatch, 1)
+        return onehotEncoding
+    return labelBatch
+
+
+def getWeightsForCEloss(dataset, train_idx, areLabelsOnehotEncoded, device, logger):
+    # Choice 1) Manually set custom weights
+    weights = torch.tensor([1,2,4,6,2,3], dtype=torch.float32, device=device)
+    weights = weights / weights.sum()
+
+    # Choice 2) Compute weights as "np.mean(histogram) / histogram"
+    dataloader = DataLoader(dataset=dataset, batch_size=6, sampler=SubsetRandomSampler(train_idx), num_workers=6)
+
+    if areLabelsOnehotEncoded:
+        histograms = 0
+        for batch in dataloader:
+            imgBatch, segBatch = batch
+            amountLabels = segBatch.size()[1]
+            if amountLabels == 1: # binary segmentation
+                histograms = histograms + torch.tensor([(segBatch==0).sum(),(segBatch==1).sum()])
+            else: # multi-label segmentation
+                if imgBatch.dim() == 4: #2D data
+                    histograms = histograms + segBatch.sum(3).sum(2).sum(0)
+                else: #3D data
+                    histograms = histograms + segBatch.sum(4).sum(3).sum(2).sum(0)
+
+        histograms = histograms.numpy()
+    else:
+        histograms = np.array([0])
+        for batch in dataloader:
+            _, segBatch = batch
+
+            segHistogram = np.histogram(segBatch.numpy(), segBatch.numpy().max()+1)[0]
+
+            if len(histograms) >= len(segHistogram): #(segHistogram could have different size than histograms)
+                histograms[:len(segHistogram)] += segHistogram
+            else:
+                segHistogram[:len(histograms)] += histograms
+                histograms = segHistogram
+
+    weights = np.mean(histograms) / histograms
+    weights = torch.from_numpy(weights).float().to(device)
+
+    logger.info('=> Weights for CE-loss: '+str(weights))
+
+    return weights
+
+
+
+def getMeanDiceScores(diceScores, logger):
+    """
+    Compute mean label dice scores of numpy dice score array (2d) (and its mean)
+    :return: mean label dice scores with '-1' representing totally missing label (meanLabelDiceScores), mean overall dice score (meanOverallDice)
+    """
+    meanLabelDiceScores = np.ma.masked_where(diceScores == -1, diceScores).mean(0).data
+    label_GT_occurrences = (diceScores != -1).sum(0)
+    if (label_GT_occurrences == 0).any():
+        logger.info('[# WARNING #] Label(s): ' + str(np.argwhere(label_GT_occurrences == 0).flatten() + 1) + ' not present at all in predictions and ground-truth of current dataset!')
+        meanLabelDiceScores[label_GT_occurrences == 0] = -1
+    meanOverallDice = meanLabelDiceScores[meanLabelDiceScores != -1].mean()
+
+    return meanLabelDiceScores, meanOverallDice
+
+
+def getDiceScores(prediction, segBatch):
+    """
+    Compute mean dice scores of predicted foreground labels.
+    NOTE: Dice scores of missing gt labels will be excluded and are thus represented by -1 value entries in returned dice score matrix!
+    NOTE: Method changes prediction to 0/1 values in the binary case!
+    :param prediction: BxCxHxW (if 2D) or BxCxHxWxD (if 3D) FloatTensor (care: prediction has not undergone any final activation!) (note: C=1 for binary segmentation task)
+    :param segBatch: BxCxHxW (if 2D) or BxCxHxWxD (if 3D) FloatTensor (Onehot-Encoding) or Bx1xHxW (if 2D) or Bx1xHxWxD (if 3D) LongTensor
+    :return: Numpy array containing BxC-1 (background excluded) dice scores
+    """
+    batchSize, amountClasses = prediction.size()[0], prediction.size()[1]
+
+    if amountClasses == 1: # binary segmentation task => simulate sigmoid to get label results
+        prediction[prediction >= 0] = 1
+        prediction[prediction < 0] = 0
+        prediction = prediction.squeeze(1)
+        segBatch = segBatch.squeeze(1)
+        amountClasses += 1
+    else: # multi-label segmentation task
+        prediction = prediction.argmax(1) # LongTensor without C-channel
+        if segBatch.dtype == torch.float32:  # segBatch is onehot-encoded
+            segBatch = segBatch.argmax(1)
+        else:
+            segBatch = segBatch.squeeze(1)
+
+    prediction = prediction.view(batchSize, -1)
+    segBatch = segBatch.view(batchSize, -1)
+
+    labelDiceScores = np.zeros((batchSize, amountClasses-1), dtype=np.float32) - 1 #ignore background class!
+    for b in range(batchSize):
+        currPred = prediction[b,:]
+        currGT = segBatch[b,:]
+
+        for c in range(1,amountClasses):
+            classPred = (currPred == c).float()
+            classGT = (currGT == c).float()
+
+            if classGT.sum() + classPred.sum() != 0: # only evaluate label prediction when is also present in ground-truth
+                labelDiceScores[b, c-1] = ((2. * (classPred * classGT).sum()) / (classGT.sum() + classPred.sum())).item()
+
+    return labelDiceScores
+
+def getDiceScoresSinglePair(prediction, segBatch, tubuliDilation=True):
+    if tubuliDilation:
+        prediction[binary_dilation(prediction == 1)] = 1
+
+    batchSize, amountClasses = 1, 8
+
+    prediction = torch.from_numpy(prediction)
+    segBatch = torch.from_numpy(segBatch)
+
+    prediction = prediction.view(batchSize, -1)
+    segBatch = segBatch.view(batchSize, -1)
+
+    labelDiceScores = np.zeros((batchSize, amountClasses-1), dtype=np.float32) - 1 #ignore background class!
+    for b in range(batchSize):
+        currPred = prediction[b,:]
+        currGT = segBatch[b,:]
+
+        for c in range(1,amountClasses):
+            if c == 2 or c == 5:
+                classPred = np.logical_or(currPred == c, currPred == c+1).float()
+                classGT = np.logical_or(currGT == c, currGT == c+1).float()
+            else:
+                classPred = (currPred == c).float()
+                classGT = (currGT == c).float()
+
+            # if classGT.sum() + classPred.sum() != 0: # only evaluate label prediction when is also present in ground-truth
+            if classGT.sum() != 0: # only evaluate label prediction when is also present in ground-truth
+                labelDiceScores[b, c-1] = ((2. * (classPred * classGT).sum()) / (classGT.sum() + classPred.sum())).item()
+
+    return labelDiceScores
+
+def printResults(allClassEvaluators, applyTestTimeAugmentation, printOnlyTTAresults, logger, saveNumpyResults, resultsPath):
+    if not printOnlyTTAresults:
+        logger.info('########## Detection (average precision + fscores) and segmentation accuracies (object-level dice): ##########')
+        precisionsAPTub, avg_precisionTub, precisionsTub, recallsTub, fscoresTub, avg_dice_scoreTub, std_dice_scoreTub, min_dice_scoreTub, max_dice_scoreTub = allClassEvaluators[0][0].score()
+        precisionsAPGlom, avg_precisionGlom, precisionsGlom, recallsGlom, fscoresGlom, avg_dice_scoreGlom, std_dice_scoreGlom, min_dice_scoreGlom, max_dice_scoreGlom = allClassEvaluators[0][1].score()
+        precisionsAPTuft, avg_precisionTuft, precisionsTuft, recallsTuft, fscoresTuft, avg_dice_scoreTuft, std_dice_scoreTuft, min_dice_scoreTuft, max_dice_scoreTuft = allClassEvaluators[0][2].score()
+        precisionsAPVeins, avg_precisionVeins, precisionsVeins, recallsVeins, fscoresVeins, avg_dice_scoreVeins, std_dice_scoreVeins, min_dice_scoreVeins, max_dice_scoreVeins = allClassEvaluators[0][3].score()
+        precisionsAPArtery, avg_precisionArtery, precisionsArtery, recallsArtery, fscoresArtery, avg_dice_scoreArtery, std_dice_scoreArtery, min_dice_scoreArtery, max_dice_scoreArtery = allClassEvaluators[0][4].score()
+        precisionsAPLumen, avg_precisionLumen, precisionsLumen, recallsLumen, fscoresLumen, avg_dice_scoreLumen, std_dice_scoreLumen, min_dice_scoreLumen, max_dice_scoreLumen = allClassEvaluators[0][5].score()
+
+        logger.info('DETECTION RESULTS MEASURED BY AVERAGE PRECISION:')
+        logger.info('0.5    0.55    0.6    0.65    0.7    0.75    0.8    0.85    0.9 <- Thresholds')
+        logger.info(str(np.round(precisionsAPTub, 4)) + ', Mean: ' + str(np.round(avg_precisionTub, 4)) + '  <-- Tubuli')
+        logger.info(str(np.round(precisionsAPGlom, 4)) + ', Mean: ' + str(np.round(avg_precisionGlom, 4)) + '  <-- Glomeruli (incl. tuft)')
+        logger.info(str(np.round(precisionsAPTuft, 4)) + ', Mean: ' + str(np.round(avg_precisionTuft, 4)) + '  <-- Tuft')
+        logger.info(str(np.round(precisionsAPVeins, 4)) + ', Mean: ' + str(np.round(avg_precisionVeins, 4)) + '  <-- Veins')
+        logger.info(str(np.round(precisionsAPArtery, 4)) + ', Mean: ' + str(np.round(avg_precisionArtery, 4)) + '  <-- Artery (incl. lumen)')
+        logger.info(str(np.round(precisionsAPLumen, 4)) + ', Mean: ' + str(np.round(avg_precisionLumen, 4)) + '  <-- Artery lumen')
+        logger.info('DETECTION RESULTS MEASURED BY F-SCORES, RECALL, PRECISION (MIN. 50% IoU):')
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Tubuli'.format(str(np.round(fscoresTub[0], 4)), str(np.round(recallsTub[0], 4)), str(np.round(precisionsTub[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Glomeruli (incl. tuft)'.format(str(np.round(fscoresGlom[0], 4)), str(np.round(recallsGlom[0], 4)), str(np.round(precisionsGlom[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Tuft'.format(str(np.round(fscoresTuft[0], 4)), str(np.round(recallsTuft[0], 4)), str(np.round(precisionsTuft[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Veins'.format(str(np.round(fscoresVeins[0], 4)), str(np.round(recallsVeins[0], 4)), str(np.round(precisionsVeins[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Artery (incl. lumen)'.format(str(np.round(fscoresArtery[0], 4)), str(np.round(recallsArtery[0], 4)), str(np.round(precisionsArtery[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Artery lumen'.format(str(np.round(fscoresLumen[0], 4)), str(np.round(recallsLumen[0], 4)), str(np.round(precisionsLumen[0], 4))))
+
+        logger.info('SEGMENTATION RESULTS MEASURED BY OBJECT-LEVEL DICE SCORES:')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreTub, 4)) + ', Std: ' + str(np.round(std_dice_scoreTub, 4)) + ', Min: ' + str(np.round(min_dice_scoreTub, 4)) + ', Max: ' + str(np.round(max_dice_scoreTub, 4)) + '  <-- Tubuli')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreGlom, 4)) + ', Std: ' + str(np.round(std_dice_scoreGlom, 4)) + ', Min: ' + str(np.round(min_dice_scoreGlom, 4)) + ', Max: ' + str(np.round(max_dice_scoreGlom, 4)) + '  <-- Glomeruli (incl. tuft)')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreTuft, 4)) + ', Std: ' + str(np.round(std_dice_scoreTuft, 4)) + ', Min: ' + str(np.round(min_dice_scoreTuft, 4)) + ', Max: ' + str(np.round(max_dice_scoreTuft, 4)) + '  <-- Tuft')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreVeins, 4)) + ', Std: ' + str(np.round(std_dice_scoreVeins, 4)) + ', Min: ' + str(np.round(min_dice_scoreVeins, 4)) + ', Max: ' + str(np.round(max_dice_scoreVeins, 4)) + '  <-- Veins')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreArtery, 4)) + ', Std: ' + str(np.round(std_dice_scoreArtery, 4)) + ', Min: ' + str(np.round(min_dice_scoreArtery, 4)) + ', Max: ' + str(np.round(max_dice_scoreArtery, 4)) + '  <-- Artery (incl. lumen)')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreLumen, 4)) + ', Std: ' + str(np.round(std_dice_scoreLumen, 4)) + ', Min: ' + str(np.round(min_dice_scoreLumen, 4)) + ', Max: ' + str(np.round(max_dice_scoreLumen, 4)) + '  <-- Artery lumen')
+
+        if saveNumpyResults:
+            figPath = resultsPath + '/QuantitativeResults'
+            if not os.path.exists(figPath):
+                os.makedirs(figPath)
+
+            np.save(figPath + '/' + 'tubuliDice.npy', np.array(allClassEvaluators[0][0].diceScores))
+            np.save(figPath + '/' + 'glomDice.npy', np.array(allClassEvaluators[0][1].diceScores))
+            np.save(figPath + '/' + 'tuftDice.npy', np.array(allClassEvaluators[0][2].diceScores))
+            np.save(figPath + '/' + 'veinsDice.npy', np.array(allClassEvaluators[0][3].diceScores))
+            np.save(figPath + '/' + 'arteriesDice.npy', np.array(allClassEvaluators[0][4].diceScores))
+            np.save(figPath + '/' + 'lumenDice.npy', np.array(allClassEvaluators[0][5].diceScores))
+
+            np.save(figPath + '/' + '_detectionResults.npy', np.stack((precisionsAPTub, precisionsAPGlom, precisionsAPTuft, precisionsAPVeins, precisionsAPArtery, precisionsAPLumen, fscoresTub, fscoresGlom, fscoresTuft, fscoresVeins, fscoresArtery, fscoresLumen)))
+
+
+    if applyTestTimeAugmentation:
+        precisionsAPTub, avg_precisionTub, precisionsTub, recallsTub, fscoresTub, avg_dice_scoreTub, std_dice_scoreTub, min_dice_scoreTub, max_dice_scoreTub = allClassEvaluators[1][0].score()
+        precisionsAPGlom, avg_precisionGlom, precisionsGlom, recallsGlom, fscoresGlom, avg_dice_scoreGlom, std_dice_scoreGlom, min_dice_scoreGlom, max_dice_scoreGlom = allClassEvaluators[1][1].score()
+        precisionsAPTuft, avg_precisionTuft, precisionsTuft, recallsTuft, fscoresTuft, avg_dice_scoreTuft, std_dice_scoreTuft, min_dice_scoreTuft, max_dice_scoreTuft = allClassEvaluators[1][2].score()
+        precisionsAPVeins, avg_precisionVeins, precisionsVeins, recallsVeins, fscoresVeins, avg_dice_scoreVeins, std_dice_scoreVeins, min_dice_scoreVeins, max_dice_scoreVeins = allClassEvaluators[1][3].score()
+        precisionsAPArtery, avg_precisionArtery, precisionsArtery, recallsArtery, fscoresArtery, avg_dice_scoreArtery, std_dice_scoreArtery, min_dice_scoreArtery, max_dice_scoreArtery = allClassEvaluators[1][4].score()
+        precisionsAPLumen, avg_precisionLumen, precisionsLumen, recallsLumen, fscoresLumen, avg_dice_scoreLumen, std_dice_scoreLumen, min_dice_scoreLumen, max_dice_scoreLumen = allClassEvaluators[1][5].score()
+
+        logger.info('###### -> TTA RESULTS <- ######')
+        logger.info('DETECTION RESULTS MEASURED BY AVERAGE PRECISION:')
+        logger.info('0.5    0.55    0.6    0.65    0.7    0.75    0.8    0.85    0.9 <- Thresholds')
+        logger.info(str(np.round(precisionsAPTub, 4)) + ', Mean: ' + str(np.round(avg_precisionTub, 4)) + '  <-- Tubuli')
+        logger.info(str(np.round(precisionsAPGlom, 4)) + ', Mean: ' + str(np.round(avg_precisionGlom, 4)) + '  <-- Glomeruli (incl. tuft)')
+        logger.info(str(np.round(precisionsAPTuft, 4)) + ', Mean: ' + str(np.round(avg_precisionTuft, 4)) + '  <-- Tuft')
+        logger.info(str(np.round(precisionsAPVeins, 4)) + ', Mean: ' + str(np.round(avg_precisionVeins, 4)) + '  <-- Veins')
+        logger.info(str(np.round(precisionsAPArtery, 4)) + ', Mean: ' + str(np.round(avg_precisionArtery, 4)) + '  <-- Artery (incl. lumen)')
+        logger.info(str(np.round(precisionsAPLumen, 4)) + ', Mean: ' + str(np.round(avg_precisionLumen, 4)) + '  <-- Artery lumen')
+        logger.info('DETECTION RESULTS MEASURED BY F-SCORES, RECALL, PRECISION (MIN. 50% overlay):')
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Tubuli'.format(str(np.round(fscoresTub[0], 4)), str(np.round(recallsTub[0], 4)), str(np.round(precisionsTub[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Glomeruli (incl. tuft)'.format(str(np.round(fscoresGlom[0], 4)), str(np.round(recallsGlom[0], 4)), str(np.round(precisionsGlom[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Tuft'.format(str(np.round(fscoresTuft[0], 4)), str(np.round(recallsTuft[0], 4)), str(np.round(precisionsTuft[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Veins'.format(str(np.round(fscoresVeins[0], 4)), str(np.round(recallsVeins[0], 4)), str(np.round(precisionsVeins[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Artery (incl. lumen)'.format(str(np.round(fscoresArtery[0], 4)), str(np.round(recallsArtery[0], 4)), str(np.round(precisionsArtery[0], 4))))
+        logger.info('F-score: {} (Recall: {}, Precision: {})   <-- Artery lumen'.format(str(np.round(fscoresLumen[0], 4)), str(np.round(recallsLumen[0], 4)), str(np.round(precisionsLumen[0], 4))))
+
+        logger.info('TTA SEGMENTATION RESULTS MEASURED BY OBJECT-LEVEL DICE SCORES:')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreTub, 4)) + ', Std: ' + str(np.round(std_dice_scoreTub, 4)) + ', Min: ' + str(np.round(min_dice_scoreTub, 4)) + ', Max: ' + str(np.round(max_dice_scoreTub, 4)) + '  <-- Tubuli')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreGlom, 4)) + ', Std: ' + str(np.round(std_dice_scoreGlom, 4)) + ', Min: ' + str(np.round(min_dice_scoreGlom, 4)) + ', Max: ' + str(np.round(max_dice_scoreGlom, 4)) + '  <-- Glomeruli (incl. tuft)')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreTuft, 4)) + ', Std: ' + str(np.round(std_dice_scoreTuft, 4)) + ', Min: ' + str(np.round(min_dice_scoreTuft, 4)) + ', Max: ' + str(np.round(max_dice_scoreTuft, 4)) + '  <-- Tuft')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreVeins, 4)) + ', Std: ' + str(np.round(std_dice_scoreVeins, 4)) + ', Min: ' + str(np.round(min_dice_scoreVeins, 4)) + ', Max: ' + str(np.round(max_dice_scoreVeins, 4)) + '  <-- Veins')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreArtery, 4)) + ', Std: ' + str(np.round(std_dice_scoreArtery, 4)) + ', Min: ' + str(np.round(min_dice_scoreArtery, 4)) + ', Max: ' + str(np.round(max_dice_scoreArtery, 4)) + '  <-- Artery (incl. lumen)')
+        logger.info('Mean: ' + str(np.round(avg_dice_scoreLumen, 4)) + ', Std: ' + str(np.round(std_dice_scoreLumen, 4)) + ', Min: ' + str(np.round(min_dice_scoreLumen, 4)) + ', Max: ' + str(np.round(max_dice_scoreLumen, 4)) + '  <-- Artery lumen')
+
+        if saveNumpyResults:
+            figPath = resultsPath + '/QuantitativeResults' + '/Biopsies'
+            if not os.path.exists(figPath):
+                os.makedirs(figPath)
+
+            np.save(figPath + '/' + 'tubuliDice_TTA.npy', np.array(allClassEvaluators[1][0].diceScores))
+            np.save(figPath + '/' + 'glomDice_TTA.npy', np.array(allClassEvaluators[1][1].diceScores))
+            np.save(figPath + '/' + 'tuftDice_TTA.npy', np.array(allClassEvaluators[1][2].diceScores))
+            np.save(figPath + '/' + 'veinsDice_TTA.npy', np.array(allClassEvaluators[1][3].diceScores))
+            np.save(figPath + '/' + 'arteriesDice_TTA.npy', np.array(allClassEvaluators[1][4].diceScores))
+            np.save(figPath + '/' + 'lumenDice_TTA.npy', np.array(allClassEvaluators[1][5].diceScores))
+
+            np.save(figPath + '/' + '_detectionResults_TTA.npy', np.stack((precisionsAPTub, precisionsAPGlom, precisionsAPTuft, precisionsAPVeins, precisionsAPArtery, precisionsAPLumen, fscoresTub, fscoresGlom, fscoresTuft, fscoresVeins, fscoresArtery, fscoresLumen)))
+
+
+
+import numpy as np
+from skimage.util import view_as_windows
+from itertools import product
+from typing import Tuple
+
+def patchify(patches: np.ndarray, patch_size: Tuple[int, int], step: int = 1):
+    return view_as_windows(patches, patch_size, step)
+
+def unpatchify(patches: np.ndarray, imsize: Tuple[int, int]):
+
+    assert len(patches.shape) == 4
+
+    i_h, i_w = imsize
+    image = np.zeros(imsize, dtype=patches.dtype)
+    divisor = np.zeros(imsize, dtype=patches.dtype)
+
+    n_h, n_w, p_h, p_w = patches.shape
+
+    # Calculat the overlap size in each axis
+    o_w = (n_w * p_w - i_w) / (n_w - 1)
+    o_h = (n_h * p_h - i_h) / (n_h - 1)
+
+    # The overlap should be integer, otherwise the patches are unable to reconstruct into a image with given shape
+    assert int(o_w) == o_w
+    assert int(o_h) == o_h
+
+    o_w = int(o_w)
+    o_h = int(o_h)
+
+    s_w = p_w - o_w
+    s_h = p_h - o_h
+
+    for i, j in product(range(n_h), range(n_w)):
+        patch = patches[i,j]
+        image[(i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w] += patch
+        divisor[(i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w] += 1
+
+    return image / divisor
+
+# Examplary use:
+# # # # # # # # #
+# import numpy as np
+# from patchify import patchify, unpatchify
+#
+# image = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
+#
+# patches = patchify(image, (2,2), step=1) # split image into 2*3 small 2*2 patches.
+#
+# assert patches.shape == (2, 3, 2, 2)
+# reconstructed_image = unpatchify(patches, image.shape)
+#
+# assert (reconstructed_image == image).all()
+
+
+
+def getChannelSmootingConvLayer(channels, kernel_size=5, sigma=1.5):
+
+    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+    x_cord = torch.arange(kernel_size)
+    x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
+    y_grid = x_grid.t()
+    xy_grid = torch.stack([x_grid, y_grid], dim=-1)
+
+    mean = (kernel_size - 1) / 2.
+    variance = sigma ** 2.
+
+    # Calculate the 2-dimensional gaussian kernel which is
+    # the product of two gaussian distributions for two different
+    # variables (in this case called x and y)
+    gaussian_kernel = (1. / (2. * math.pi * variance)) * \
+                      torch.exp(
+                          (-torch.sum((xy_grid - mean) ** 2., dim=-1) / \
+                          (2 * variance)).float()
+                      )
+    # Make sure sum of values in gaussian kernel equals 1.
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2d depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
+                                kernel_size=kernel_size, groups=channels, bias=False, padding=2)
+
+    gaussian_filter.weight.data = gaussian_kernel
+    gaussian_filter.weight.requires_grad = False
+
+    return gaussian_filter
+
+
+if __name__ == '__main__':
+    print()
\ No newline at end of file
-- 
GitLab