From 7ddbfbc1b02ac4ac6bf98099b975b79b5764dff1 Mon Sep 17 00:00:00 2001
From: Nassim Bouteldja <nbouteldja@ukaachen.de>
Date: Wed, 13 Apr 2022 21:37:38 +0200
Subject: [PATCH] Upload New File

---
 segment_WSI.py | 417 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 417 insertions(+)
 create mode 100644 segment_WSI.py

diff --git a/segment_WSI.py b/segment_WSI.py
new file mode 100644
index 0000000..b27507e
--- /dev/null
+++ b/segment_WSI.py
@@ -0,0 +1,417 @@
+# this file recursively performs the automated segmentation of WSIs by applying the tissue segmentation and structure segmentation CNN to all WSIs contained in a specified folder
+
+import numpy as np
+import os
+import sys
+import cv2
+import torch
+import math
+import logging as log
+from tqdm import tqdm, trange
+import re
+
+from openslide.lowlevel import OpenSlideError
+import openslide as osl
+from PIL import Image
+import matplotlib.pyplot as plt
+from scipy.ndimage.measurements import label
+from scipy.ndimage import zoom
+from scipy.ndimage.morphology import binary_dilation, binary_closing, binary_fill_holes, binary_erosion, binary_opening
+import scipy.ndimage
+import scipy as sp
+from skimage.transform import rescale, resize
+from skimage.measure import regionprops
+from skimage.segmentation import flood, flood_fill
+from skimage.color import rgb2gray
+from skimage.segmentation import clear_border
+from skimage import filters
+from skimage.morphology import remove_small_objects
+
+from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
+from model import Custom
+
+
+
+
+WSIrootFolder = 'SPECIFIED WSI FOLDER'  # folder handed to os.walk to look for WSIs
+modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'  # state_dict file of the structure segmentation network
+resultsPath = 'RESULTS PATH'  # output root: overlays, npy files and LOGS.log are written below it
+
+if not os.path.exists(resultsPath):
+    os.makedirs(resultsPath)
+
+model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'  # model object (loaded with torch.load) for tissue/foreground segmentation
+
+patchSegmSize = 516  # side length (px) of the prediction area stitched per patch
+patchImgSize = 640  # side length (px) of the image patches fed to the structure network
+patchLengthUM = 174.  # physical length covered by patchSegmSize pixels -> targetSpacing = patchLengthUM / patchSegmSize
+
+patchImgSize_FG = 512  # side length (px) of the tiles fed to the tissue network
+patchLengthUM_FG = 2500  # physical tile length -> targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG
+regionMinSizeUM = 3E5  # minimum tissue region size; divided by the squared spacing to obtain a pixel count
+
+alpha = 0.3  # forwarded as `alpha` to savePredictionOverlayResults
+strideProportion = 0.5  # patch stride as a fraction of patchSegmSize
+figHeight = 15  # height used for every figSize passed to the save helpers
+minibatchSize = 2  # DataLoader batch size for the structure network
+minibatchSize_FG = 1  # DataLoader batch size for the tissue network
+useAllGPUs = False  # wrap the structure model in torch.nn.DataParallel when True
+GPUno = 0  # CUDA device index used when CUDA is available
+device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
+
+tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10
+
+ftChannelsOutput = 8  # number of output channels (classes) of the structure network
+applyTestTimeAugmentation = True  # additionally sum predictions of flipped inputs
+centerWeighting = False  # multiply the central part of each patch prediction by centerWeight
+centerWeight = 3.
+
+saveWSICoarseForegroundSegmResults = True  # save the WSI-level tissue segmentation overlay
+# saveCroppedForegroundSegmResults = False
+saveCroppedWSIimg = True  # save the cropped image of every tissue region
+saveWSIandPredNumpy = True  # save region image and final instance prediction as .npy
+
+TUBULI_MIN_SIZE = 400  # the *_MIN_SIZE values are passed as min_size to remove_small_objects per class
+GLOM_MIN_SIZE = 1500
+TUFT_MIN_SIZE = 500
+VEIN_MIN_SIZE = 3000
+ARTERY_MIN_SIZE = 400
+LUMEN_MIN_SIZE = 20
+
+labelBG = 8  # label value written into background pixels of the saved predictions
+
+# LOAD STRUCTURE SEGMENTATION MODEL
+model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
+model.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))  # weights are deserialized to CPU first
+model.train(False)
+model.eval()  # same effect as train(False) above: inference mode
+
+if useAllGPUs:
+    model = torch.nn.DataParallel(model)  # multi-GPUs
+model = model.to(device)
+
+# LOAD TISSUE SEGMENTATION MODEL (a whole pickled model object, not a state_dict)
+model_FG = torch.load(model_FG_path, map_location=device)  # map_location: a GPU-saved model must also load when only the CPU is available
+model_FG = model_FG.to(device)
+model_FG.eval()
+
+segmentationPatchStride = int(patchSegmSize * strideProportion)  # stride (px) between neighbouring patches
+targetSpacing = patchLengthUM / patchSegmSize  # spacing the structure network works at (compared against the openslide mpp values)
+
+targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG  # spacing the tissue network works at
+
+shiftMinUM = (patchImgSize - patchSegmSize) // 2 * targetSpacing + 2 # + 2 just for sufficient margin reasons
+shiftMaxUM = ((patchImgSize - patchSegmSize) // 2 + patchSegmSize * strideProportion) * targetSpacing  # bbox enlargement towards bottom/right
+
+# Set up logger: every message goes to <resultsPath>/LOGS.log (overwritten per run) and to stdout
+log.basicConfig(
+    level=log.INFO,
+    format='%(asctime)s %(message)s',
+    datefmt='%Y/%m/%d %I:%M:%S %p',
+    handlers=[
+        log.FileHandler(resultsPath + '/LOGS.log', 'w'),
+        log.StreamHandler(sys.stdout)
+    ])
+logger = log.getLogger()  # root logger
+
+struc3 = generate_ball(1)  # structuring element passed to label() when small foreground components are removed
+
+try:
+    # walk through the directory tree (NOTE: the `break` at the end of this loop body stops the walk after WSIrootFolder itself)
+    for dirName, subdirList, fileList in os.walk(WSIrootFolder):
+
+        # keep only WSIs (.svs / .ndpi / .scn) of the current directory whose file name contains 'PAS'
+        fileListWSI = sorted([fname for fname in fileList if fname.endswith(('.svs', '.ndpi', '.scn')) and 'PAS' in fname])
+
+        if len(fileListWSI) != 0:
+            logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)
+
+            # mirror the sub-folder structure of WSIrootFolder below resultsPath
+            resultsDir = resultsPath + dirName[len(WSIrootFolder):]
+            resultsDirNPYfiles = resultsDir + '/npyFiles'
+            if not os.path.exists(resultsDirNPYfiles):
+                os.makedirs(resultsDirNPYfiles)
+
+            # traverse through all found WSIs
+            for no, fname in enumerate(fileListWSI):
+
+                # Extract/print relevant parameters; unreadable slides or slides without mpp metadata are skipped
+                try:
+                    slide = osl.OpenSlide(os.path.join(dirName, fname))
+                    logger.info(str(no + 1) + ':  WSI:\t' + fname)
+                    spacings = np.array([float(slide.properties['openslide.mpp-x']), float(slide.properties['openslide.mpp-y'])])
+                except Exception:  # best effort per slide, but do not swallow KeyboardInterrupt/SystemExit
+                    logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
+                    continue
+                levelDims = np.array(slide.level_dimensions)
+                amountLevels = len(levelDims)
+                levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)  # np.int alias was removed in NumPy 1.24
+
+                logger.info('Spacings: ' + str(spacings))
+                logger.info('Level Dimensions: ' + str(levelDims))
+                logger.info('Level Downsamples: '+str(levelDownsamples))
+
+                suffixCut = -5 if fname.split('.')[-1] == 'ndpi' else -4  # number of trailing characters ('.ndpi' vs. '.svs'/'.scn') cut from fname
+
+                # choose the coarsest WSI level that is still larger than the image at tissue-network spacing (more efficient than always reading level 0)
+                spacingFactorX = spacings[0] / targetSpacing_FG
+                spacingFactorY = spacings[1] / targetSpacing_FG
+                x_scaled = round(levelDims[0][0] * spacingFactorX)
+                y_scaled = round(levelDims[0][1] * spacingFactorY)
+
+                usedLevel = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1) == True).max()  # NOTE(review): .max() raises ValueError if no level exceeds the scaled size
+
+                # resample to target spacing
+                logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))
+                spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
+                spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
+                downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
+                downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG
+                imgWSI = np.array(slide.read_region(location=np.array([0, 0]), level=usedLevel, size=levelDims[usedLevel]))[:, :, :3]  # whole level, first three channels only
+
+                d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))  # target width
+                d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))  # target height
+
+                imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR) #dtype: uint8, size: d2 x d1 x 3
+                imgWSIzeropadded = np.zeros(shape=(d2+patchImgSize_FG-1, d1+patchImgSize_FG-1, 3), dtype=np.float32)  # zero margin at bottom/right so tiles can cover the whole image
+                imgWSIzeropadded[:d2,:d1,:] = imgWSI
+
+                # tessellate the resampled image into non-overlapping tiles (step == tile size)
+                smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG) # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
+
+                tileDataset = []
+                # with tqdm(total=smallOverlappingPatches.shape[0] * smallOverlappingPatches.shape[1]) as pbar:
+                for i in range(smallOverlappingPatches.shape[0]):
+                    for j in range(smallOverlappingPatches.shape[1]):
+                        tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})  # name encodes the tile's grid position
+                            # pbar.update(1)
+
+                img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0],imgWSIzeropadded.shape[1]), dtype=bool)  # np.bool alias was removed in NumPy 1.24
+
+                # create dataloader for concurrent tile processing and prediction computation
+                dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
+                with torch.no_grad():
+                    for i, data in enumerate(dataloader, 0):
+                        imgBatch = data['data'].permute(0, 3, 1, 2).to(device)  # channels-last tiles -> channels-first batch
+
+                        prediction = model_FG(imgBatch)  # prediction should have shape (1,2,512,512)
+                        prediction = (prediction[:,1,:,:] > prediction[:,0,:,:]).to("cpu").numpy()  # True where channel 1 beats channel 0
+
+                        for n, d in zip(data['name'], prediction):
+                            x = int(n.split('-')[0])
+                            y = int(n.split('-')[1])
+                            img_mask[x * patchImgSize_FG: (x+1) * patchImgSize_FG, y * patchImgSize_FG: (y+1) * patchImgSize_FG] = d
+
+                img_mask = img_mask[:d2,:d1]  # drop the zero-padded margin again
+
+                # postprocessing: close holes inside the predicted tissue mask
+                img_mask = binary_fill_holes(img_mask)
+
+                # remove connected regions if too small
+                regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)  # minimum region size converted to a pixel count at tissue-network spacing
+                img_mask, _ = label(img_mask)
+                labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))  # relabel so the remaining region ids run 1..numberFGRegions
+                if numberFGRegions < 256:
+                    labeledRegions = np.asarray(labeledRegions, np.uint8)  # ids fit into uint8 only below 256 regions
+
+                if saveWSICoarseForegroundSegmResults:
+                    logger.info('Saving WSI-level coarse FG segmentation results...')
+                    savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1]/labeledRegions.shape[0]*figHeight, figHeight), fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_0_fgSeg.png')
+
+                labeledRegions = cv2.resize(labeledRegions, dsize=(levelDims[0][0],levelDims[0][1]), interpolation = cv2.INTER_NEAREST) # FG RESULTS ON WSI-RESOLUTION, UINT8, REGION IDs
+
+                logger.info('In total -> '+str(numberFGRegions)+' <- regions on WSI detected!')
+
+                # process all detected tissue regions separately
+                for regionID in range(1, numberFGRegions+1):
+                    logger.info('#######\n Extract foreground region ' + str(regionID) + '...')
+                    detectedRegion = labeledRegions == regionID
+
+                    # compute bounding box and how much to enlarge bbox to consider wider context utilization (especially for patchify)
+                    temp = np.where(detectedRegion == 1)
+                    bbox = np.array([np.min(temp[0]), np.min(temp[1]), np.max(temp[0]), np.max(temp[1])])  # (row_min, col_min, row_max, col_max)
+
+                    shiftMin = round(shiftMinUM / spacings[0])  # margins converted to level-0 pixels
+                    shiftMax = round(shiftMaxUM / spacings[0])
+
+                    # enlarge bounding box due to wider context consideration (clamped to the slide; lower/right bounds become exclusive)
+                    bbox[0] = max(bbox[0] - shiftMin, 0)
+                    bbox[1] = max(bbox[1] - shiftMin, 0)
+                    bbox[2] = min(bbox[2] + shiftMax, detectedRegion.shape[0] - 1) + 1
+                    bbox[3] = min(bbox[3] + shiftMax, detectedRegion.shape[1] - 1) + 1
+
+                    logger.info('Extract high res patch and segm map...')
+                    try:
+                        img_WSI = np.asarray(np.array(slide.read_region(location=np.array([bbox[1], bbox[0]]), level=0, size=np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]])))[:, :, :3], np.uint8)
+                    except OpenSlideError:
+                        logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
+                        continue
+                    detectedRegion = detectedRegion[bbox[0]:bbox[2], bbox[1]:bbox[3]]
+
+                    # extract image and resample into target spacing of the structure segmentation network
+                    downsamplingFactor = spacings[0] / targetSpacing  # Rescaling would be very slow using 'rescale' method!
+                    logger.info('Utilized spacing of slide: '+str(spacings[0])+', Resample both patches using factor: ' + str(downsamplingFactor))
+                    segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)  # np.bool alias was removed in NumPy 1.24
+                    img_WSI = cv2.resize(img_WSI, dsize=tuple(np.flip(segMap.shape)), interpolation=cv2.INTER_LINEAR)  # cv2 expects (width, height)
+                    # segMap = np.asarray(np.round(rescale(segMap, downsamplingFactor, order=0, preserve_range=True, multichannel=False)), np.bool)
+                    assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
+                    logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))
+
+                    if np.min(segMap.shape) < patchImgSize:
+                        logger.info('Detected region smaller than window, thus skipped...')
+                        continue
+
+                    ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
+                    logger.info('Start segmentation process...')
+
+                    # preprocess img: map the 0..255 value range to -1..1
+                    img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)
+
+                    # tessellate image and tissue prediction results with the same window/stride so patch indices match
+                    smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride) # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
+                    smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)
+
+                    tileDataset = []
+                    for i in range(smallOverlappingPatches.shape[0]):
+                        for j in range(smallOverlappingPatches.shape[1]):
+                            if smallOverlappingPatches_FG[i,j,:,:].any():  # only patches containing tissue are predicted
+                                tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
+
+                    # calculate segmentation patch size since patchify cuts off the last patch if it does not exactly fit in the window
+                    startX = (patchImgSize - patchSegmSize) // 2; startY = startX
+                    endX = segmentationPatchStride * (smallOverlappingPatches.shape[0]-1) + patchSegmSize + startX
+                    endY = segmentationPatchStride * (smallOverlappingPatches.shape[1]-1) + patchSegmSize + startY
+
+                    bigPatchResults = torch.zeros(device="cpu", size=(ftChannelsOutput, endX - startX, endY - startY))  # stitched per-class scores, kept on CPU
+
+                    # create dataloader for concurrent prediction computation
+                    dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
+                    with torch.no_grad():
+                        for i, data in enumerate(dataloader, 0):
+                            imgBatch = data['data'].permute(0, 3, 1, 2).to(device)
+
+                            prediction = torch.softmax(model(imgBatch), dim=1)  # shape: (minibatchSize, 8, 516, 516)
+                            if applyTestTimeAugmentation:
+                                imgBatch = imgBatch.flip(2)  # flipped along dim 2
+                                prediction += torch.softmax(model(imgBatch), 1).flip(2)
+
+                                imgBatch = imgBatch.flip(3)  # now flipped along dims 2 and 3
+                                prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)
+
+                                imgBatch = imgBatch.flip(2)  # now flipped along dim 3 only
+                                prediction += torch.softmax(model(imgBatch), 1).flip(3)
+
+                            if centerWeighting:
+                                prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight
+
+                            prediction = prediction.to("cpu")
+                            # NOTE(review): overlapping areas are overwritten ('=') by the later patch, not accumulated, so centerWeighting cannot change the argmax - confirm intent
+                            for n, d in zip(data['name'], prediction):
+                                x = int(n.split('-')[0])
+                                y = int(n.split('-')[1])
+                                bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] = d
+
+                        bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy() # 2D uint8 map of class indices, e.g. shape (1536, 2048)
+
+                    logger.info('Predictions generated. Final shape: '+str(bigPatchResults.shape))
+
+                    # Context margin + border patches not fully inside img removed (crop image and tissue map to the stitched prediction area)
+                    img_WSI = img_WSI[startX:endX, startY:endY, :]
+                    segMap = segMap[startX:endX, startY:endY]
+                    bgMap = np.logical_not(segMap)
+
+                    # Save the cropped WSI image of this region (file name carries bbox origin and spacing)
+                    if saveCroppedWSIimg:
+                        logger.info('Saving cropped segmented WSI image...')
+                        saveImage(img_WSI, fullResultPath=resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight))
+
+                    # correct foreground segmentation including all touching vein prediction instances
+                    bigPatchResults[bgMap] = 4 #vein class assignment of bg
+                    temp = bigPatchResults == 4
+                    bgMap = np.logical_xor(clear_border(temp), temp)  # keeps exactly the class-4 components that touch the image border
+                    segMap = np.logical_not(bgMap)
+                    segMap = binary_fill_holes(segMap)
+                    bgMap = np.logical_not(segMap)
+
+                    # remove small fg components
+                    temp, numberLabeledRegions = label(segMap, struc3)
+                    if numberLabeledRegions > 1:
+                        regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)  # NOTE(review): computed but never used - the line below compares pixel areas against regionMinSizeUM; confirm intended threshold
+                        regionIDs = np.where(np.array([region.area for region in regionprops(temp)]) > regionMinSizeUM)[0] + 1  # +1: region labels start at 1
+                        segMap = np.isin(temp, regionIDs)
+                        bgMap = np.logical_not(segMap)
+
+                    bigPatchResults[bgMap] = labelBG # color of label 'labelBG' => Purple represents BG just for visualization purposes
+
+                    logger.info('Saving prediction and background overlay results...')
+                    savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlay.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
+
+                    if saveWSIandPredNumpy:
+                        logger.info('Saving numpy img...')
+                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_resultWSI.npy', img_WSI)
+
+                    logger.info('Start postprocessing...')
+
+                    # remove border class (class 7 is merged into class 0)
+                    bigPatchResults[bigPatchResults == 7] = 0
+
+                    # Delete BG to reduce postprocessing overhead
+                    bigPatchResults[bgMap] = 0
+
+                    ################# HOLE FILLING ################
+                    bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
+                    bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
+                    temp = binary_fill_holes(bigPatchResults == 3)  # tuft
+                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
+                    bigPatchResults[temp] = 3  # tuft (re-inserted on top of the filled glom area)
+                    temp = binary_fill_holes(bigPatchResults == 6)  # artery_lumen
+                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full_artery
+                    bigPatchResults[temp] = 6  # artery_lumen (re-inserted on top of the filled artery area)
+
+                    ###### REMOVING TOO SMALL CONNECTED REGIONS ######
+                    temp, _ = label(bigPatchResults == 1)
+                    finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0
+
+                    ############ PERFORM TUBULE DILATION ############
+                    finalResults_Instance, numberTubuli = label(finalResults_Instance) #dtype: int32
+                    finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)  # first tubule instance gets id tubuliInstanceID_StartsWith
+                    if numberTubuli < 65500:  # ids still fit into uint16
+                        finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1) #RESULT TYPE: UINT16
+                    else:
+                        finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1), np.int32)
+
+                    temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))  # glom incl. tuft
+                    finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
+                    temp, _ = label(bigPatchResults == 3)  # tuft, only kept inside a retained glom
+                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance==2)] = 3
+                    temp, _ = label(bigPatchResults == 4)  # veins
+                    finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
+                    temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))  # artery incl. lumen
+                    finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
+                    temp, _ = label(bigPatchResults == 6)  # lumen, only kept inside a retained artery
+                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance==5)] = 6
+
+                    finalResults_Instance = finalResults_Instance * segMap  # zero everything outside the tissue map (e.g. tubules dilated across its border)
+
+                    logger.info('Done - Save final instance overlay results...')
+                    savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
+
+                    logger.info('Done - Save final non-instance overlay results...')
+                    finalResults = finalResults_Instance.copy()
+                    finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1  # '>=': the first tubule carries exactly id tubuliInstanceID_StartsWith and must collapse to class 1 as well
+                    savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
+
+                    finalResults_Instance[bgMap] = labelBG  # mark background in the saved instance map
+
+                    if saveWSIandPredNumpy:
+                        logger.info('Saving numpy final instance prediction results...')
+                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_finalInstancePrediction.npy', finalResults_Instance)
+
+                logger.info('####################')
+
+        break  # NOTE(review): ends os.walk after the first (root) directory, so sub-directories are never processed despite the 'recursively' comments - remove to recurse
+
+except:  # top-level boundary: log the traceback to file/stdout, then re-raise unchanged
+    logger.exception('! Exception !')
+    raise
+
+logger.info('%%%% Ended regularly ! %%%%')
-- 
GitLab