# segment_WSI.py
#
# Provenance: reconstructed from patch 7ddbfbc1b02ac4ac6bf98099b975b79b5764dff1
# ("Upload New File", Nassim Bouteldja, 2022-04-13), which created this file.
#
# This script performs the automated segmentation of WSIs by applying the tissue
# segmentation CNN and the structure segmentation CNN to the WSIs found below a
# specified folder. For every slide it
#   1. predicts a coarse tissue (foreground) mask at a low resolution,
#   2. splits that mask into connected tissue regions,
#   3. segments each region at high resolution with the structure network,
#   4. post-processes the prediction (hole filling, small-object removal,
#      tubule instance labeling + dilation) and saves overlays / numpy results.
#
# Class ids used below (taken from the post-processing code):
#   1 tubuli, 2 glomerulus, 3 tuft, 4 vein, 5 full artery, 6 artery lumen,
#   7 border, labelBG (8) background (visualization only).

import logging as log
import math
import os
import re
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import openslide as osl
import scipy as sp
import scipy.ndimage
import torch
from openslide.lowlevel import OpenSlideError
from PIL import Image
# The 'scipy.ndimage.measurements' / 'scipy.ndimage.morphology' namespaces are
# deprecated; the very same functions are exported by 'scipy.ndimage' itself.
from scipy.ndimage import (binary_closing, binary_dilation, binary_erosion,
                           binary_fill_holes, binary_opening, label, zoom)
from skimage import filters
from skimage.color import rgb2gray
from skimage.measure import regionprops
from skimage.morphology import remove_small_objects
from skimage.segmentation import clear_border, flood, flood_fill
from skimage.transform import rescale, resize
from tqdm import tqdm, trange

from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
from model import Custom


# ----------------------------- CONFIGURATION ----------------------------- #
WSIrootFolder = 'SPECIFIED WSI FOLDER'
modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'
resultsPath = 'RESULTS PATH'

os.makedirs(resultsPath, exist_ok=True)

model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'

# structure segmentation network: input window, predicted window, physical window length
patchSegmSize = 516
patchImgSize = 640
patchLengthUM = 174.

# tissue (foreground) segmentation network
patchImgSize_FG = 512
patchLengthUM_FG = 2500
regionMinSizeUM = 3E5  # minimal size of a tissue region; divided by spacing^2 to obtain pixels

alpha = 0.3
strideProportion = 0.5
figHeight = 15
minibatchSize = 2
minibatchSize_FG = 1
useAllGPUs = False
GPUno = 0
device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")

tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10

ftChannelsOutput = 8
applyTestTimeAugmentation = True
centerWeighting = False
centerWeight = 3.

saveWSICoarseForegroundSegmResults = True
# saveCroppedForegroundSegmResults = False
saveCroppedWSIimg = True
saveWSIandPredNumpy = True

TUBULI_MIN_SIZE = 400
GLOM_MIN_SIZE = 1500
TUFT_MIN_SIZE = 500
VEIN_MIN_SIZE = 3000
ARTERY_MIN_SIZE = 400
LUMEN_MIN_SIZE = 20

labelBG = 8

# ------------------- LOAD STRUCTURE SEGMENTATION MODEL ------------------- #
model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
# Deserialize on CPU first; the model is moved to 'device' below.
model.load_state_dict(torch.load(modelpath, map_location='cpu'))
model.eval()

if useAllGPUs:
    model = torch.nn.DataParallel(model)  # multi-GPUs
model = model.to(device)

# --------------------- LOAD TISSUE SEGMENTATION MODEL -------------------- #
# NOTE(review): this unpickles a complete model object, so only load trusted files.
# map_location='cpu' keeps a GPU-saved model loadable on CPU-only machines.
model_FG = torch.load(model_FG_path, map_location='cpu')
model_FG = model_FG.to(device)
model_FG.eval()

# --------------------------- DERIVED QUANTITIES -------------------------- #
segmentationPatchStride = int(patchSegmSize * strideProportion)
targetSpacing = patchLengthUM / patchSegmSize

targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG

# margins (in micrometers) by which a region's bounding box is enlarged
shiftMinUM = (patchImgSize - patchSegmSize) // 2 * targetSpacing + 2  # + 2 just for sufficient margin reasons
shiftMaxUM = ((patchImgSize - patchSegmSize) // 2 + patchSegmSize * strideProportion) * targetSpacing

# Set up logger (log file inside the results folder + stdout)
log.basicConfig(
    level=log.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%Y/%m/%d %I:%M:%S %p',
    handlers=[
        log.FileHandler(resultsPath + '/LOGS.log', 'w'),
        log.StreamHandler(sys.stdout)
    ])
logger = log.getLogger()

# Structuring elements are loop-invariant, hence created once.
struc3 = generate_ball(1)
tubuliDilationKernel = np.asarray(generate_ball(2), np.uint8)

try:
    # walk through directory and its subdirectories
    for dirName, subdirList, fileList in os.walk(WSIrootFolder):

        # filter WSIs in current directory (supported formats, PAS stain in file name)
        fileListWSI = sorted(fname for fname in fileList if fname.endswith(('.svs', '.ndpi', '.scn')) and 'PAS' in fname)

        if len(fileListWSI) != 0:
            logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)

            # mirror the directory structure below the results path
            resultsDir = resultsPath + dirName[len(WSIrootFolder):]
            resultsDirNPYfiles = resultsDir + '/npyFiles'
            os.makedirs(resultsDirNPYfiles, exist_ok=True)

            # traverse through all found WSIs
            for no, fname in enumerate(fileListWSI):

                # Extract/print relevant parameters
                slide = None
                try:
                    slide = osl.OpenSlide(os.path.join(dirName, fname))
                    logger.info(str(no + 1) + ': WSI:\t' + fname)
                    spacings = np.array([float(slide.properties['openslide.mpp-x']), float(slide.properties['openslide.mpp-y'])])
                except Exception:
                    # 'Exception' (not a bare except) so that Ctrl-C / SystemExit still abort the run.
                    logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
                    if slide is not None:
                        slide.close()
                    continue
                levelDims = np.array(slide.level_dimensions)
                # builtin 'int' instead of the alias 'np.int', which was removed in NumPy 1.24
                levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)

                logger.info('Spacings: ' + str(spacings))
                logger.info('Level Dimensions: ' + str(levelDims))
                logger.info('Level Downsamples: '+str(levelDownsamples))

                # file name without extension ('.svs', '.ndpi' or '.scn'), used for all result files
                baseName = os.path.splitext(fname)[0]

                # extract the WSI level that is closest to the target spacing of the tissue segmentation network (increasing efficiency instead of simply taking full resolution, finest level 0)
                spacingFactorX = spacings[0] / targetSpacing_FG
                spacingFactorY = spacings[1] / targetSpacing_FG
                x_scaled = int(round(levelDims[0][0] * spacingFactorX))
                y_scaled = int(round(levelDims[0][1] * spacingFactorY))

                # coarsest level that is still larger than the resampled size; fall back to
                # level 0 if no level qualifies instead of crashing on an empty selection
                candidateLevels = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1))
                usedLevel = int(candidateLevels.max()) if candidateLevels.size > 0 else 0

                # resample to target spacing
                logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))
                spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
                spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
                downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
                downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG
                usedLevelSize = tuple(int(v) for v in levelDims[usedLevel])
                imgWSI = np.array(slide.read_region(location=(0, 0), level=usedLevel, size=usedLevelSize))[:, :, :3]

                d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))
                d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))

                imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR)  # dtype: uint8, size: d2 x d1 x 3
                # zero-pad so that the non-overlapping tiling below covers the complete image
                imgWSIzeropadded = np.zeros(shape=(d2+patchImgSize_FG-1, d1+patchImgSize_FG-1, 3), dtype=np.float32)
                imgWSIzeropadded[:d2, :d1, :] = imgWSI

                # tessellate resampled image
                smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG)  # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!

                # one entry per tile; the name encodes the tile's grid position
                tileDataset = [
                    {'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])}
                    for i in range(smallOverlappingPatches.shape[0])
                    for j in range(smallOverlappingPatches.shape[1])
                ]

                # builtin 'bool' instead of the alias 'np.bool', which was removed in NumPy 1.24
                img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0], imgWSIzeropadded.shape[1]), dtype=bool)

                # create dataloader for concurrent tile processing and prediction computation
                dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
                with torch.no_grad():
                    for data in dataloader:
                        imgBatch = data['data'].permute(0, 3, 1, 2).to(device)

                        prediction = model_FG(imgBatch)  # two output channels are compared: 0 vs. 1 (foreground)
                        prediction = (prediction[:, 1, :, :] > prediction[:, 0, :, :]).to("cpu").numpy()

                        for n, d in zip(data['name'], prediction):
                            x, y = (int(v) for v in n.split('-'))
                            img_mask[x * patchImgSize_FG: (x+1) * patchImgSize_FG, y * patchImgSize_FG: (y+1) * patchImgSize_FG] = d

                # remove the zero padding again
                img_mask = img_mask[:d2, :d1]

                # postprocessing
                img_mask = binary_fill_holes(img_mask)

                # remove connected regions if too small
                regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)
                img_mask, _ = label(img_mask)
                labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))
                if numberFGRegions < 256:
                    labeledRegions = np.asarray(labeledRegions, np.uint8)

                if saveWSICoarseForegroundSegmResults:
                    logger.info('Saving WSI-level coarse FG segmentation results...')
                    savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1]/labeledRegions.shape[0]*figHeight, figHeight), fullResultPath=resultsDir + '/' + baseName + '_0_fgSeg.png')

                # FG RESULTS ON WSI-RESOLUTION, REGION IDs (uint8 if fewer than 256 regions)
                labeledRegions = cv2.resize(labeledRegions, dsize=(int(levelDims[0][0]), int(levelDims[0][1])), interpolation=cv2.INTER_NEAREST)

                logger.info('In total -> '+str(numberFGRegions)+' <- regions on WSI detected!')

                # quantities that are identical for all regions of this slide
                shiftMin = int(round(shiftMinUM / spacings[0]))
                shiftMax = int(round(shiftMaxUM / spacings[0]))
                downsamplingFactor = spacings[0] / targetSpacing  # Rescaling would be very slow using 'rescale' method!

                # process all detected tissue regions separately
                for regionID in range(1, numberFGRegions+1):
                    logger.info('#######\n Extract foreground region ' + str(regionID) + '...')
                    detectedRegion = labeledRegions == regionID

                    # compute bounding box (row_min, col_min, row_max, col_max) of the region
                    temp = np.where(detectedRegion == 1)
                    bbox = np.array([np.min(temp[0]), np.min(temp[1]), np.max(temp[0]), np.max(temp[1])])

                    # enlarge bounding box due to wider context consideration (especially for patchify)
                    bbox[0] = max(bbox[0] - shiftMin, 0)
                    bbox[1] = max(bbox[1] - shiftMin, 0)
                    bbox[2] = min(bbox[2] + shiftMax, detectedRegion.shape[0] - 1) + 1
                    bbox[3] = min(bbox[3] + shiftMax, detectedRegion.shape[1] - 1) + 1

                    logger.info('Extract high res patch and segm map...')
                    try:
                        # openslide expects (x, y) = (column, row) and size = (width, height)
                        img_WSI = np.asarray(np.array(slide.read_region(location=(int(bbox[1]), int(bbox[0])), level=0, size=(int(bbox[3] - bbox[1]), int(bbox[2] - bbox[0]))))[:, :, :3], np.uint8)
                    except OpenSlideError:
                        logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
                        continue
                    detectedRegion = detectedRegion[bbox[0]:bbox[2], bbox[1]:bbox[3]]

                    # extract image and resample into target spacing of the structure segmentation network
                    logger.info('Utilized spacing of slide: '+str(spacings[0])+', Resample both patches using factor: ' + str(downsamplingFactor))
                    segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)
                    img_WSI = cv2.resize(img_WSI, dsize=(segMap.shape[1], segMap.shape[0]), interpolation=cv2.INTER_LINEAR)
                    assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
                    logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))

                    if np.min(segMap.shape) < patchImgSize:
                        logger.info('Detected region smaller than window, thus skipped...')
                        continue

                    ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
                    logger.info('Start segmentation process...')

                    # preprocess img: map [0, 255] to [-1, 1]
                    img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)

                    # tessellate image and tissue prediction results
                    smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride)  # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
                    smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)

                    # only tiles that contain tissue are passed to the network
                    tileDataset = [
                        {'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])}
                        for i in range(smallOverlappingPatches.shape[0])
                        for j in range(smallOverlappingPatches.shape[1])
                        if smallOverlappingPatches_FG[i, j, :, :].any()
                    ]

                    # calculate segmentation patch size since patchify cuts off last patch if not exactly fitting in window
                    startX = (patchImgSize - patchSegmSize) // 2
                    startY = startX
                    endX = segmentationPatchStride * (smallOverlappingPatches.shape[0]-1) + patchSegmSize + startX
                    endY = segmentationPatchStride * (smallOverlappingPatches.shape[1]-1) + patchSegmSize + startY

                    bigPatchResults = torch.zeros(device="cpu", size=(ftChannelsOutput, endX - startX, endY - startY))

                    # create dataloader for concurrent prediction computation
                    dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
                    with torch.no_grad():
                        for data in dataloader:
                            imgBatch = data['data'].permute(0, 3, 1, 2).to(device)

                            prediction = torch.softmax(model(imgBatch), dim=1)
                            if applyTestTimeAugmentation:
                                # sum the predictions of all four flip variants (each flipped back)
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(2)

                                imgBatch = imgBatch.flip(3)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)

                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3)

                            if centerWeighting:
                                prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight

                            prediction = prediction.to("cpu")

                            for n, d in zip(data['name'], prediction):
                                x, y = (int(v) for v in n.split('-'))
                                # NOTE(review): overlapping tiles are overwritten ('='), not accumulated, so
                                # 'centerWeighting' cannot influence the argmax below - confirm whether '+='
                                # was intended before changing it.
                                bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] = d

                    # class with highest score per pixel
                    bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy()

                    logger.info('Predictions generated. Final shape: '+str(bigPatchResults.shape))

                    # Context margin + border patches not fully inside img removed
                    img_WSI = img_WSI[startX:endX, startY:endY, :]
                    segMap = segMap[startX:endX, startY:endY]
                    bgMap = np.logical_not(segMap)

                    # Save cropped foreground segmentation result as overlay
                    if saveCroppedWSIimg:
                        logger.info('Saving cropped segmented WSI image...')
                        saveImage(img_WSI, fullResultPath=resultsDir + '/' + baseName + '_'+str(regionID)+'_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight))

                    # correct foreground segmentation including all touching vein prediction instances
                    bigPatchResults[bgMap] = 4  # vein class assignment of bg
                    temp = bigPatchResults == 4
                    # vein components touching the image border are treated as background
                    bgMap = np.logical_xor(clear_border(temp), temp)
                    segMap = np.logical_not(bgMap)
                    segMap = binary_fill_holes(segMap)
                    bgMap = np.logical_not(segMap)

                    # remove small fg components
                    temp, numberLabeledRegions = label(segMap, struc3)
                    if numberLabeledRegions > 1:
                        # 'region.area' is measured in pixels, so the threshold has to be the pixel
                        # count (the previous code compared against the value in square micrometers)
                        regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)
                        regionIDs = np.where(np.array([region.area for region in regionprops(temp)]) > regionMinPixels)[0] + 1
                        segMap = np.isin(temp, regionIDs)
                        bgMap = np.logical_not(segMap)

                    bigPatchResults[bgMap] = labelBG  # color of label 'labelBG' => Purple represents BG just for visualization purposes

                    logger.info('Saving prediction and background overlay results...')
                    savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + baseName + '_'+str(regionID)+'_resultOverlay.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy img...')
                        np.save(resultsDirNPYfiles + '/' + baseName + '_'+str(regionID)+'_resultWSI.npy', img_WSI)

                    logger.info('Start postprocessing...')

                    # remove border class
                    bigPatchResults[bigPatchResults == 7] = 0

                    # Delete BG to reduce postprocessing overhead
                    bigPatchResults[bgMap] = 0

                    ################# HOLE FILLING ################
                    # (statement order matters: nested classes are restored after filling their parent)
                    bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
                    bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
                    temp = binary_fill_holes(bigPatchResults == 3)  # tuft
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
                    bigPatchResults[temp] = 3  # tuft
                    temp = binary_fill_holes(bigPatchResults == 6)  # artery_lumen
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full_artery
                    bigPatchResults[temp] = 6  # artery_lumen

                    ###### REMOVING TOO SMALL CONNECTED REGIONS ######
                    temp, _ = label(bigPatchResults == 1)
                    finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0

                    ############ PERFORM TUBULE DILATION ############
                    finalResults_Instance, numberTubuli = label(finalResults_Instance)  # dtype: int32
                    # shift instance ids so that the first tubule gets id 'tubuliInstanceID_StartsWith'
                    finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)
                    if numberTubuli < 65500:
                        finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=tubuliDilationKernel, iterations=1)  # RESULT TYPE: UINT16
                    else:
                        # too many instances for uint16, thus dilate in float64 and convert back
                        finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=tubuliDilationKernel, iterations=1), np.int32)

                    temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))
                    finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
                    temp, _ = label(bigPatchResults == 3)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance == 2)] = 3
                    temp, _ = label(bigPatchResults == 4)
                    finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
                    temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))
                    finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
                    temp, _ = label(bigPatchResults == 6)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance == 5)] = 6

                    # dilation may have grown tubuli into the background, thus mask again
                    finalResults_Instance = finalResults_Instance * segMap

                    logger.info('Done - Save final instance overlay results...')
                    savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + baseName + '_'+str(regionID)+'_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    logger.info('Done - Save final non-instance overlay results...')
                    finalResults = finalResults_Instance.copy()
                    # '>=': the first tubule instance carries exactly the id 'tubuliInstanceID_StartsWith'
                    finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1
                    savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + baseName + '_'+str(regionID)+'_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)

                    finalResults_Instance[bgMap] = labelBG

                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy final instance prediction results...')
                        np.save(resultsDirNPYfiles + '/' + baseName + '_'+str(regionID)+'_finalInstancePrediction.npy', finalResults_Instance)

                    logger.info('####################')

                # release the file handle of the processed slide
                slide.close()

        # NOTE(review): the original indentation of this 'break' is not recoverable from the
        # patch. Placed here it stops after the first directory yielded by os.walk (the root
        # folder); remove it to really descend into subdirectories - TODO confirm intent.
        break

except BaseException:
    # log the traceback into the log file as well, then abort
    logger.exception('! Exception !')
    raise

logger.info('%%%% Ended regularly ! %%%%')