# This script performs the automated segmentation of WSIs (whole-slide images) by applying the tissue segmentation CNN and the structure segmentation CNN to all WSIs (.svs/.ndpi/.scn files whose name contains 'PAS') contained in a specified folder

import numpy as np
import os
import sys
import cv2
import torch
import math
import logging as log
from tqdm import tqdm, trange
import re

from openslide.lowlevel import OpenSlideError
import openslide as osl
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage.measurements import label
from scipy.ndimage import zoom
from scipy.ndimage.morphology import binary_dilation, binary_closing, binary_fill_holes, binary_erosion, binary_opening
import scipy.ndimage
import scipy as sp
from skimage.transform import rescale, resize
from skimage.measure import regionprops
from skimage.segmentation import flood, flood_fill
from skimage.color import rgb2gray
from skimage.segmentation import clear_border
from skimage import filters
from skimage.morphology import remove_small_objects

from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
from model import Custom




# Input/output locations -- replace these placeholders before running the script.
WSIrootFolder = 'SPECIFIED WSI FOLDER'  # folder that is scanned for WSIs
modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'  # state_dict of the structure segmentation CNN
resultsPath = 'RESULTS PATH'  # root folder for all outputs (overlays, .npy files, LOGS.log)

# Create the results folder up front: the log file handler writes into it.
# exist_ok=True replaces the racy "check exists, then create" pair with one call.
os.makedirs(resultsPath, exist_ok=True)

model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'  # whole pickled tissue (foreground) segmentation model

# --- Structure segmentation network geometry ---
patchSegmSize = 516     # side length (px) of the network's output patch
patchImgSize = 640      # side length (px) of the network's input patch (output + context margin)
patchLengthUM = 174.0   # physical side length (um) covered by one output patch

# --- Tissue (foreground) segmentation network geometry ---
patchImgSize_FG = 512     # side length (px) of the tissue network's input tiles
patchLengthUM_FG = 2500   # physical side length (um) covered by one tissue tile
regionMinSizeUM = 3e5     # minimum tissue region area (um^2); smaller regions are discarded

# --- Inference / visualisation settings ---
alpha = 0.3              # overlay transparency handed to the overlay-saving helpers
strideProportion = 0.5   # patch stride as a fraction of patchSegmSize
figHeight = 15           # height of the saved overlay figures
minibatchSize = 2        # batch size of the structure network
minibatchSize_FG = 1     # batch size of the tissue network
useAllGPUs = False       # wrap the structure model in torch.nn.DataParallel
GPUno = 0
device = torch.device(f"cuda:{GPUno}" if torch.cuda.is_available() else "cpu")

tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10

ftChannelsOutput = 8              # number of output channels (classes) of the structure network
applyTestTimeAugmentation = True  # also sum softmax outputs of horizontally/vertically flipped inputs
centerWeighting = False
centerWeight = 3.0

# --- Which additional results are written to disk ---
saveWSICoarseForegroundSegmResults = True
saveCroppedWSIimg = True
saveWSIandPredNumpy = True

# --- Minimum object sizes (pixels at the structure network's spacing) for postprocessing ---
TUBULI_MIN_SIZE = 400
GLOM_MIN_SIZE = 1500
TUFT_MIN_SIZE = 500
VEIN_MIN_SIZE = 3000
ARTERY_MIN_SIZE = 400
LUMEN_MIN_SIZE = 20

labelBG = 8  # label written into background pixels of the saved predictions

# LOAD STRUCTURE SEGMENTATION MODEL
# `modelpath` holds a state_dict for the `Custom` architecture. It is deserialized onto
# the CPU first (map_location), so loading also works on CPU-only machines, and only
# afterwards moved to `device`.
model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
model.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))
model.eval()  # inference mode (eval() is exactly train(False))

if useAllGPUs:
    model = torch.nn.DataParallel(model)  # multi-GPUs
model = model.to(device)

# LOAD TISSUE SEGMENTATION MODEL
# This file holds a whole pickled model object (it is used directly, without
# load_state_dict). map_location='cpu' lets a checkpoint saved on a GPU be loaded even
# when CUDA is unavailable or the saving GPU index does not exist; .to(device) then
# moves it to the configured device, so the end result on GPU hosts is unchanged.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files. On
# recent PyTorch versions that default to weights_only=True, loading a full model
# object may additionally require weights_only=False.
model_FG = torch.load(model_FG_path, map_location='cpu')
model_FG = model_FG.to(device)
model_FG.eval()

# Derived geometry (spacings in um per pixel).
# The structure network reads patchImgSize-px input patches and predicts
# patchSegmSize-px output patches; consecutive patches are offset by the stride below.
segmentationPatchStride = int(strideProportion * patchSegmSize)
targetSpacing = patchLengthUM / patchSegmSize          # spacing the structure network expects
targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG  # spacing the tissue network expects

# Context margin (px) the structure network consumes on each side of its output patch.
_contextMarginPx = (patchImgSize - patchSegmSize) // 2
# Physical margins (um) used to enlarge each tissue bounding box before segmentation.
shiftMinUM = _contextMarginPx * targetSpacing + 2  # + 2 just for sufficient margin reasons
shiftMaxUM = (_contextMarginPx + patchSegmSize * strideProportion) * targetSpacing

# Set up logger: records go both to <resultsPath>/LOGS.log (file opened in mode 'w',
# i.e. overwritten on every run) and to stdout.
_logHandlers = [
    log.FileHandler(resultsPath + '/LOGS.log', 'w'),
    log.StreamHandler(sys.stdout),
]
log.basicConfig(level=log.INFO,
                format='%(asctime)s %(message)s',
                datefmt='%Y/%m/%d %I:%M:%S %p',
                handlers=_logHandlers)
logger = log.getLogger()  # the root logger configured above

# Structuring element (utils.generate_ball) passed to label() when cleaning up the
# per-region foreground mask.
struc3 = generate_ball(1)

# MAIN PROCESSING LOOP
# For every WSI:
#   (1) coarse tissue (foreground) segmentation at targetSpacing_FG,
#   (2) structure segmentation of every tissue region on overlapping patches at targetSpacing,
#   (3) postprocessing (hole filling, small-object removal, tubule instances) and saving
#       of overlays / numpy arrays.
# Any exception escaping the loop is logged with its traceback and re-raised.
try:
    # Walk the WSI root folder.
    # NOTE(review): the unconditional `break` at the end of this loop body stops os.walk
    # after the root folder itself, so subfolders are NOT processed, although resultsDir
    # below is built to mirror them. Remove that `break` if subfolders should be included.
    for dirName, subdirList, fileList in os.walk(WSIrootFolder):

        # filter WSIs in current directory (supported formats whose filename contains 'PAS')
        fileListWSI = sorted(fname for fname in fileList
                             if fname.endswith(('.svs', '.ndpi', '.scn')) and 'PAS' in fname)

        if fileListWSI:
            logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)

            # results of this directory go to the matching location below resultsPath
            resultsDir = resultsPath + dirName[len(WSIrootFolder):]
            resultsDirNPYfiles = resultsDir + '/npyFiles'
            if not os.path.exists(resultsDirNPYfiles):
                os.makedirs(resultsDirNPYfiles)

            # traverse through all found WSIs
            for no, fname in enumerate(fileListWSI):

                # Open the slide and read its pixel spacing (um per pixel in x and y).
                # `except Exception` rather than a bare except, so Ctrl-C / SystemExit are not
                # silently turned into "slide skipped".
                slide = None
                try:
                    slide = osl.OpenSlide(os.path.join(dirName, fname))
                    logger.info(str(no + 1) + ':  WSI:\t' + fname)
                    spacings = np.array([float(slide.properties['openslide.mpp-x']), float(slide.properties['openslide.mpp-y'])])
                except Exception:
                    logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
                    if slide is not None:
                        slide.close()  # slide opened, but its spacing metadata was unusable
                    continue
                levelDims = np.array(slide.level_dimensions)  # one (width, height) row per pyramid level
                # builtin int: the np.int alias was removed in NumPy 1.24
                levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)

                logger.info('Spacings: ' + str(spacings))
                logger.info('Level Dimensions: ' + str(levelDims))
                logger.info('Level Downsamples: '+str(levelDownsamples))

                # length of the file extension incl. the dot ('.ndpi' vs '.svs' / '.scn')
                suffixCut = -5 if fname.split('.')[-1] == 'ndpi' else -4

                ##### STEP 1: COARSE TISSUE (FOREGROUND) SEGMENTATION #####
                # extract the WSI level that is closest to the target spacing of the tissue segmentation network (increasing efficiency instead of simply taking full resolution, finest level 0)
                spacingFactorX = spacings[0] / targetSpacing_FG
                spacingFactorY = spacings[1] / targetSpacing_FG
                x_scaled = round(levelDims[0][0] * spacingFactorX)
                y_scaled = round(levelDims[0][1] * spacingFactorY)

                # coarsest pyramid level whose width AND height still exceed the target size
                usedLevel = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1)).max()

                # resample to target spacing
                logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))
                spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
                spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
                downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
                downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG
                imgWSI = np.array(slide.read_region(location=np.array([0, 0]), level=usedLevel, size=levelDims[usedLevel]))[:, :, :3]  # drop alpha channel

                d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))
                d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))

                imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR) #dtype: uint8, size: d2 x d1 x 3
                # zero-pad by (tile size - 1) on the bottom/right so the tiles below reach the image border
                imgWSIzeropadded = np.zeros(shape=(d2+patchImgSize_FG-1, d1+patchImgSize_FG-1, 3), dtype=np.float32)
                imgWSIzeropadded[:d2,:d1,:] = imgWSI

                # tesselate resampled image into non-overlapping tiles (step == tile size).
                # Assuming patchify yields (size - patch) // step + 1 tiles per axis, the zero-padding
                # above means only padding (not image data) is cut off at the right/bottom borders.
                smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG)

                tileDataset = []
                for i in range(smallOverlappingPatches.shape[0]):
                    for j in range(smallOverlappingPatches.shape[1]):
                        tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})

                # builtin bool: the np.bool alias was removed in NumPy 1.24
                img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0],imgWSIzeropadded.shape[1]), dtype=bool)

                # create dataloader for concurrent tile processing and prediction computation
                dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
                with torch.no_grad():
                    for i, data in enumerate(dataloader, 0):
                        imgBatch = data['data'].permute(0, 3, 1, 2).to(device)  # (N, H, W, C) -> (N, C, H, W)

                        prediction = model_FG(imgBatch)  # prediction should have shape (1,2,512,512)
                        # foreground wherever channel 1 scores higher than channel 0
                        prediction = (prediction[:,1,:,:] > prediction[:,0,:,:]).to("cpu").numpy()

                        # write each tile's mask back to its position (tile name encodes 'row-col')
                        for n, d in zip(data['name'], prediction):
                            x = int(n.split('-')[0])
                            y = int(n.split('-')[1])
                            img_mask[x * patchImgSize_FG: (x+1) * patchImgSize_FG, y * patchImgSize_FG: (y+1) * patchImgSize_FG] = d

                img_mask = img_mask[:d2,:d1]  # strip the zero-padding again

                # postprocessing
                img_mask = binary_fill_holes(img_mask)

                # remove connected regions if too small (area threshold converted to pixels at FG spacing)
                regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)
                img_mask, _ = label(img_mask)
                labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))
                if numberFGRegions < 256:
                    labeledRegions = np.asarray(labeledRegions, np.uint8)

                if saveWSICoarseForegroundSegmResults:
                    logger.info('Saving WSI-level coarse FG segmentation results...')
                    savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1]/labeledRegions.shape[0]*figHeight, figHeight), fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_0_fgSeg.png')

                labeledRegions = cv2.resize(labeledRegions, dsize=(levelDims[0][0],levelDims[0][1]), interpolation = cv2.INTER_NEAREST) # FG RESULTS ON WSI-RESOLUTION, UINT8, REGION IDs

                logger.info('In total -> '+str(numberFGRegions)+' <- regions on WSI detected!')

                ##### STEP 2: STRUCTURE SEGMENTATION OF EVERY TISSUE REGION #####
                # process all detected tissue regions separately
                for regionID in range(1, numberFGRegions+1):
                    logger.info('#######\n Extract foreground region ' + str(regionID) + '...')
                    detectedRegion = labeledRegions == regionID

                    # compute bounding box and how much to enlarge bbox to consider wider context utilization (especially for patchify)
                    temp = np.where(detectedRegion == 1)
                    bbox = np.array([np.min(temp[0]), np.min(temp[1]), np.max(temp[0]), np.max(temp[1])])  # [row_min, col_min, row_max, col_max]

                    # context margins in level-0 pixels (the x spacing is used for both axes)
                    shiftMin = round(shiftMinUM / spacings[0])
                    shiftMax = round(shiftMaxUM / spacings[0])

                    # enlarge bounding box due to wider context consideration, clipped to the image;
                    # afterwards bbox[2] and bbox[3] are exclusive end indices
                    bbox[0] = max(bbox[0] - shiftMin, 0)
                    bbox[1] = max(bbox[1] - shiftMin, 0)
                    bbox[2] = min(bbox[2] + shiftMax, detectedRegion.shape[0] - 1) + 1
                    bbox[3] = min(bbox[3] + shiftMax, detectedRegion.shape[1] - 1) + 1

                    logger.info('Extract high res patch and segm map...')
                    try:
                        img_WSI = np.asarray(np.array(slide.read_region(location=np.array([bbox[1], bbox[0]]), level=0, size=np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]])))[:, :, :3], np.uint8)
                    except OpenSlideError:
                        logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
                        continue
                    detectedRegion = detectedRegion[bbox[0]:bbox[2], bbox[1]:bbox[3]]

                    # extract image and resample into target spacing of the structure segmentation network
                    downsamplingFactor = spacings[0] / targetSpacing  # Rescaling would be very slow using 'rescale' method!
                    logger.info('Utilized spacing of slide: '+str(spacings[0])+', Resample both patches using factor: ' + str(downsamplingFactor))
                    segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)
                    img_WSI = cv2.resize(img_WSI, dsize=tuple(np.flip(segMap.shape)), interpolation=cv2.INTER_LINEAR)
                    assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
                    logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))

                    if np.min(segMap.shape) < patchImgSize:
                        logger.info('Detected region smaller than window, thus skipped...')
                        continue

                    ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
                    logger.info('Start segmentation process...')

                    # preprocess img: map uint8 [0, 255] to float32 [-1, 1]
                    img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)

                    # tesselate image and tissue prediction results
                    smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride) # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
                    smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)

                    # only patches that contain any tissue are sent through the network
                    tileDataset = []
                    for i in range(smallOverlappingPatches.shape[0]):
                        for j in range(smallOverlappingPatches.shape[1]):
                            if smallOverlappingPatches_FG[i,j,:,:].any():
                                tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})

                    # calculate segmentation patch size since patchify cuts of last patch if not exactly fitting in window
                    startX = (patchImgSize - patchSegmSize) // 2; startY = startX
                    endX = segmentationPatchStride * (smallOverlappingPatches.shape[0]-1) + patchSegmSize + startX
                    endY = segmentationPatchStride * (smallOverlappingPatches.shape[1]-1) + patchSegmSize + startY

                    bigPatchResults = torch.zeros(device="cpu", size=(ftChannelsOutput, endX - startX, endY - startY))

                    # create dataloader for concurrent prediction computation
                    dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
                    with torch.no_grad():
                        for i, data in enumerate(dataloader, 0):
                            imgBatch = data['data'].permute(0, 3, 1, 2).to(device)

                            prediction = torch.softmax(model(imgBatch), dim=1)  # shape: (minibatchSize, 8, 516, 516)
                            if applyTestTimeAugmentation:
                                # add predictions of the flipped inputs, flipped back into place
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(2)

                                imgBatch = imgBatch.flip(3)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)

                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3)

                            # NOTE(review): patches are stitched below by plain assignment (the patch
                            # written last wins in overlap areas). Scaling all channels of a pixel by the
                            # same factor cannot change its argmax, so centerWeighting has no effect
                            # unless predictions are accumulated (+=) instead. Confirm intended behaviour.
                            if centerWeighting:
                                prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight

                            prediction = prediction.to("cpu")

                            for n, d in zip(data['name'], prediction):
                                x = int(n.split('-')[0])
                                y = int(n.split('-')[1])
                                bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] = d

                        bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy() # shape: (1536, 2048)

                    logger.info('Predictions generated. Final shape: '+str(bigPatchResults.shape))

                    # Context margin + border patches not fully inside img removed
                    img_WSI = img_WSI[startX:endX, startY:endY, :]
                    segMap = segMap[startX:endX, startY:endY]
                    bgMap = np.logical_not(segMap)

                    # Save cropped foreground segmentation result as overlay
                    if saveCroppedWSIimg:
                        logger.info('Saving cropped segmented WSI image...')
                        saveImage(img_WSI, fullResultPath=resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight))

                    # correct foreground segmentation including all touching vein prediction instances
                    bigPatchResults[bgMap] = 4 #vein class assignment of bg
                    temp = bigPatchResults == 4
                    bgMap = np.logical_xor(clear_border(temp), temp)
                    segMap = np.logical_not(bgMap)
                    segMap = binary_fill_holes(segMap)
                    bgMap = np.logical_not(segMap)

                    # remove small fg components (only if there is more than one)
                    temp, numberLabeledRegions = label(segMap, struc3)
                    if numberLabeledRegions > 1:
                        regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)
                        # BUGFIX: region.area is a pixel count, so it is compared with regionMinPixels
                        # (it was previously compared with regionMinSizeUM, an area in um^2)
                        regionIDs = [region.label for region in regionprops(temp) if region.area > regionMinPixels]
                        segMap = np.isin(temp, regionIDs)
                        bgMap = np.logical_not(segMap)

                    bigPatchResults[bgMap] = labelBG # color of label 'labelBG' => Purple represents BG just for visualization purposes

                    logger.info('Saving prediction and background overlay results...')
                    savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlay.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy img...')
                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_resultWSI.npy', img_WSI)

                    ##### STEP 3: POSTPROCESSING #####
                    logger.info('Start postprocessing...')

                    # remove border class
                    bigPatchResults[bigPatchResults == 7] = 0

                    # Delete BG to reduce postprocessing overhead
                    bigPatchResults[bgMap] = 0

                    ################# HOLE FILLING ################
                    bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
                    bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
                    temp = binary_fill_holes(bigPatchResults == 3)  # tuft
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
                    bigPatchResults[temp] = 3  # tuft
                    temp = binary_fill_holes(bigPatchResults == 6)  # artery_lumen
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full_artery
                    bigPatchResults[temp] = 6  # artery_lumen

                    ###### REMOVING TOO SMALL CONNECTED REGIONS ######
                    temp, _ = label(bigPatchResults == 1)
                    finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0

                    ############ PERFORM TUBULE DILATION ############
                    finalResults_Instance, numberTubuli = label(finalResults_Instance) #dtype: int32
                    # shift tubule labels 1..N to tubuliInstanceID_StartsWith..(tubuliInstanceID_StartsWith+N-1)
                    finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)
                    if numberTubuli < 65500:
                        finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1) #RESULT TYPE: UINT16
                    else:
                        finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1), np.int32)

                    temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))
                    finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
                    temp, _ = label(bigPatchResults == 3)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance==2)] = 3
                    temp, _ = label(bigPatchResults == 4)
                    finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
                    temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))
                    finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
                    temp, _ = label(bigPatchResults == 6)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance==5)] = 6

                    finalResults_Instance = finalResults_Instance * segMap

                    logger.info('Done - Save final instance overlay results...')
                    savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    logger.info('Done - Save final non-instance overlay results...')
                    finalResults = finalResults_Instance.copy()
                    # BUGFIX: '>=' -- the first tubule instance id equals tubuliInstanceID_StartsWith
                    # and must also be merged into the tubuli class 1
                    finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1
                    savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    finalResults_Instance[bgMap] = labelBG

                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy final instance prediction results...')
                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_finalInstancePrediction.npy', finalResults_Instance)

                slide.close()  # release the OpenSlide handle before moving on to the next WSI
                logger.info('####################')

        break

except BaseException:  # explicit form of the former bare except: log everything, then re-raise
    logger.exception('! Exception !')
    raise

logger.info('%%%% Ended regularly ! %%%%')