# This script recursively performs the automated segmentation of WSIs by applying the tissue
# segmentation CNN and the structure segmentation CNN to all WSIs contained in a specified folder
# (including its subfolders).
import numpy as np
import os
import sys
import cv2
import torch
import math
import logging as log
from tqdm import tqdm, trange
import re
from openslide.lowlevel import OpenSlideError
import openslide as osl
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import label, zoom, binary_dilation, binary_closing, binary_fill_holes, binary_erosion, binary_opening
import scipy.ndimage
import scipy as sp
from skimage.transform import rescale, resize
from skimage.measure import regionprops
from skimage.segmentation import flood, flood_fill
from skimage.color import rgb2gray
from skimage.segmentation import clear_border
from skimage import filters
from skimage.morphology import remove_small_objects

from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
from model import Custom

WSIrootFolder = 'SPECIFIED WSI FOLDER'
modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'
resultsPath = 'RESULTS PATH'
if not os.path.exists(resultsPath):
    os.makedirs(resultsPath)

model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'

patchSegmSize = 516
patchImgSize = 640
patchLengthUM = 174.

patchImgSize_FG = 512
patchLengthUM_FG = 2500

regionMinSizeUM = 3E5

alpha = 0.3
strideProportion = 0.5
figHeight = 15
minibatchSize = 2
minibatchSize_FG = 1

useAllGPUs = False
GPUno = 0
device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")

tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10

ftChannelsOutput = 8
applyTestTimeAugmentation = True
centerWeighting = False
centerWeight = 3.
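# Illustrative arithmetic (editorial note; not used by the pipeline): with the constants above,
# the two networks operate at the pixel spacings derived further below as
# targetSpacing(_FG) = patchLengthUM(_FG) / patch size:
#   structure segmentation: 174 µm / 516 px ≈ 0.337 µm/px. Each 640 px input patch thus covers
#                           ≈ 216 µm, of which only the central 516 px (≈ 174 µm) are kept as
#                           output; the (640 - 516) / 2 = 62 px per side serve as context margin.
#   tissue segmentation:    2500 µm / 512 px ≈ 4.883 µm/px, i.e. one 512 px patch covers ≈ 2.5 mm.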
saveWSICoarseForegroundSegmResults = True
# saveCroppedForegroundSegmResults = False
saveCroppedWSIimg = True
saveWSIandPredNumpy = True

TUBULI_MIN_SIZE = 400
GLOM_MIN_SIZE = 1500
TUFT_MIN_SIZE = 500
VEIN_MIN_SIZE = 3000
ARTERY_MIN_SIZE = 400
LUMEN_MIN_SIZE = 20

labelBG = 8

# LOAD STRUCTURE SEGMENTATION MODEL
model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
model.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))
model.train(False)
model.eval()
if useAllGPUs:
    model = torch.nn.DataParallel(model)  # multi-GPU
model = model.to(device)

# LOAD TISSUE SEGMENTATION MODEL
model_FG = torch.load(model_FG_path)
model_FG = model_FG.to(device)
model_FG.eval()

segmentationPatchStride = int(patchSegmSize * strideProportion)
targetSpacing = patchLengthUM / patchSegmSize
targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG

shiftMinUM = (patchImgSize - patchSegmSize) // 2 * targetSpacing + 2  # + 2 just for a sufficient margin
shiftMaxUM = ((patchImgSize - patchSegmSize) // 2 + patchSegmSize * strideProportion) * targetSpacing

# Set up logger
log.basicConfig(
    level=log.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%Y/%m/%d %I:%M:%S %p',
    handlers=[
        log.FileHandler(resultsPath + '/LOGS.log', 'w'),
        log.StreamHandler(sys.stdout)
    ])
logger = log.getLogger()

struc3 = generate_ball(1)

try:
    # walk through the directory and its subdirectories recursively
    for dirName, subdirList, fileList in os.walk(WSIrootFolder):

        # filter WSIs in the current directory
        fileListWSI = sorted([fname for fname in fileList if (fname.endswith('.svs') or fname.endswith('.ndpi') or fname.endswith('.scn')) and 'PAS' in fname])

        if len(fileListWSI) != 0:
            logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)

            resultsDir = resultsPath + dirName[len(WSIrootFolder):]
            resultsDirNPYfiles = resultsDir + '/npyFiles'
            if not os.path.exists(resultsDirNPYfiles):
                os.makedirs(resultsDirNPYfiles)

            # traverse all found WSIs
            for no, fname in enumerate(fileListWSI):

                # extract/print relevant parameters
                try:
                    slide = osl.OpenSlide(os.path.join(dirName, fname))
                    logger.info(str(no + 1) + ': WSI:\t' + fname)
                    spacings = np.array([float(slide.properties['openslide.mpp-x']), float(slide.properties['openslide.mpp-y'])])
                except:
                    logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
                    continue

                levelDims = np.array(slide.level_dimensions)
                amountLevels = len(levelDims)
                levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)

                logger.info('Spacings: ' + str(spacings))
                logger.info('Level Dimensions: ' + str(levelDims))
                logger.info('Level Downsamples: ' + str(levelDownsamples))

                suffixCut = -5 if fname.split('.')[-1] == 'ndpi' else -4

                # extract the WSI level closest to the target spacing of the tissue segmentation network
                # (more efficient than simply taking the full-resolution level 0)
                spacingFactorX = spacings[0] / targetSpacing_FG
                spacingFactorY = spacings[1] / targetSpacing_FG
                x_scaled = round(levelDims[0][0] * spacingFactorX)
                y_scaled = round(levelDims[0][1] * spacingFactorY)
                usedLevel = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1)).max()

                # resample to target spacing
                logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))

                spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
                spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
                downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
                downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG

                imgWSI = np.array(slide.read_region(location=np.array([0, 0]), level=usedLevel, size=levelDims[usedLevel]))[:, :, :3]
                d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))
                d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))
                imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR)  # dtype: uint8, size: d2 x d1 x 3

                imgWSIzeropadded = np.zeros(shape=(d2 + patchImgSize_FG - 1, d1 + patchImgSize_FG - 1, 3), dtype=np.float32)
                imgWSIzeropadded[:d2, :d1, :] = imgWSI

                # tessellate the resampled image
                smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG)  # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!

                tileDataset = []
                # with tqdm(total=smallOverlappingPatches.shape[0] * smallOverlappingPatches.shape[1]) as pbar:
                for i in range(smallOverlappingPatches.shape[0]):
                    for j in range(smallOverlappingPatches.shape[1]):
                        tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
                        # pbar.update(1)

                img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0], imgWSIzeropadded.shape[1]), dtype=bool)

                # create dataloader for concurrent tile processing and prediction computation
                dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
                with torch.no_grad():
                    for i, data in enumerate(dataloader, 0):
                        imgBatch = data['data'].permute(0, 3, 1, 2).to(device)
                        prediction = model_FG(imgBatch)  # prediction has shape (minibatchSize_FG, 2, 512, 512)
                        prediction = (prediction[:, 1, :, :] > prediction[:, 0, :, :]).to("cpu").numpy()
                        for n, d in zip(data['name'], prediction):
                            x = int(n.split('-')[0])
                            y = int(n.split('-')[1])
                            img_mask[x * patchImgSize_FG: (x + 1) * patchImgSize_FG, y * patchImgSize_FG: (y + 1) * patchImgSize_FG] = d

                img_mask = img_mask[:d2, :d1]

                # postprocessing: fill holes and remove connected regions that are too small
                img_mask = binary_fill_holes(img_mask)
                regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)
                img_mask, _ = label(img_mask)
                labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))

                if numberFGRegions < 256:
                    labeledRegions = np.asarray(labeledRegions, np.uint8)

                    if saveWSICoarseForegroundSegmResults:
                        logger.info('Saving WSI-level coarse FG segmentation results...')
                        savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1] / labeledRegions.shape[0] * figHeight, figHeight), fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_0_fgSeg.png')

                    labeledRegions = cv2.resize(labeledRegions, dsize=(levelDims[0][0], levelDims[0][1]), interpolation=cv2.INTER_NEAREST)  # FG RESULTS ON WSI RESOLUTION, UINT8, REGION IDs

                    logger.info('In total -> ' + str(numberFGRegions) + ' <- regions on WSI detected!')

                    # process all detected tissue regions separately
                    for regionID in range(1, numberFGRegions + 1):
                        logger.info('#######\n Extract foreground region ' + str(regionID) + '...')

                        detectedRegion = labeledRegions == regionID

                        # compute the bounding box and how much to enlarge it for wider context utilization (especially for patchify)
                        temp = np.where(detectedRegion == 1)
                        bbox = np.array([np.min(temp[0]), np.min(temp[1]), np.max(temp[0]), np.max(temp[1])])
                        shiftMin = round(shiftMinUM / spacings[0])
                        shiftMax = round(shiftMaxUM / spacings[0])

                        # enlarge bounding box for wider context consideration
                        bbox[0] = max(bbox[0] - shiftMin, 0)
                        bbox[1] = max(bbox[1] - shiftMin, 0)
                        bbox[2] = min(bbox[2] + shiftMax, detectedRegion.shape[0] - 1) + 1
                        bbox[3] = min(bbox[3] + shiftMax, detectedRegion.shape[1] - 1) + 1

                        logger.info('Extract high res patch and segm map...')
                        try:
                            img_WSI = np.asarray(np.array(slide.read_region(location=np.array([bbox[1], bbox[0]]), level=0, size=np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]])))[:, :, :3], np.uint8)
                        except OpenSlideError:
                            logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
                            continue

                        detectedRegion = detectedRegion[bbox[0]:bbox[2], bbox[1]:bbox[3]]

                        # resample image and tissue mask to the target spacing of the structure segmentation network
                        downsamplingFactor = spacings[0] / targetSpacing  # rescaling via the skimage 'rescale' method would be very slow!
                        logger.info('Utilized spacing of slide: ' + str(spacings[0]) + ', Resample both patches using factor: ' + str(downsamplingFactor))
                        segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)
                        img_WSI = cv2.resize(img_WSI, dsize=tuple(np.flip(segMap.shape)), interpolation=cv2.INTER_LINEAR)
                        # segMap = np.asarray(np.round(rescale(segMap, downsamplingFactor, order=0, preserve_range=True, multichannel=False)), np.bool)
                        assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
                        logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))

                        if np.min(segMap.shape) < patchImgSize:
                            logger.info('Detected region smaller than window, thus skipped...')
                            continue

                        ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
                        logger.info('Start segmentation process...')

                        # preprocess img
                        img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)

                        # tessellate image and tissue prediction results
                        smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride)  # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
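                        # Editorial sketch (an assumption based on how the output is indexed here, not the
                        # project's own code): the local utils.patchify used above for the image and below for
                        # the tissue mask appears to behave like skimage.util.view_as_windows, i.e. it returns
                        # overlapping windows indexed as patches[i, j, ...] taken every `step` pixels, silently
                        # dropping windows that do not fully fit at the right/bottom border (see the CARE note
                        # above), which is why img_WSI and segMap are later cropped to the stitched extent:
                        #
                        #   from skimage.util import view_as_windows
                        #   windows = view_as_windows(img_WSI_prep,
                        #                             (patchImgSize, patchImgSize, 3),
                        #                             step=segmentationPatchStride)
                        #   # windows.shape == (nRows, nCols, 1, 640, 640, 3); window (i, j) starts at
                        #   # pixel (i * segmentationPatchStride, j * segmentationPatchStride)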
                        smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)

                        tileDataset = []
                        for i in range(smallOverlappingPatches.shape[0]):
                            for j in range(smallOverlappingPatches.shape[1]):
                                # only predict on tiles that contain foreground tissue
                                if smallOverlappingPatches_FG[i, j, :, :].any():
                                    tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})

                        # calculate the stitched segmentation extent since patchify cuts off the last patch if it does not exactly fit into the window
                        startX = (patchImgSize - patchSegmSize) // 2
                        startY = startX
                        endX = segmentationPatchStride * (smallOverlappingPatches.shape[0] - 1) + patchSegmSize + startX
                        endY = segmentationPatchStride * (smallOverlappingPatches.shape[1] - 1) + patchSegmSize + startY

                        bigPatchResults = torch.zeros(size=(ftChannelsOutput, endX - startX, endY - startY), device="cpu")

                        # create dataloader for concurrent prediction computation
                        dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
                        with torch.no_grad():
                            for i, data in enumerate(dataloader, 0):
                                imgBatch = data['data'].permute(0, 3, 1, 2).to(device)
                                prediction = torch.softmax(model(imgBatch), dim=1)  # shape: (minibatchSize, 8, 516, 516)

                                if applyTestTimeAugmentation:
                                    imgBatch = imgBatch.flip(2)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(2)

                                    imgBatch = imgBatch.flip(3)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)

                                    imgBatch = imgBatch.flip(2)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(3)

                                if centerWeighting:
                                    prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight

                                prediction = prediction.to("cpu")
                                for n, d in zip(data['name'], prediction):
                                    x = int(n.split('-')[0])
                                    y = int(n.split('-')[1])
                                    bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] = d

                        bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy()  # 2D class map of the stitched region, e.g. shape (1536, 2048)

                        logger.info('Predictions generated. Final shape: ' + str(bigPatchResults.shape))

                        # context margin + border patches not fully inside the image are removed
                        img_WSI = img_WSI[startX:endX, startY:endY, :]
                        segMap = segMap[startX:endX, startY:endY]
                        bgMap = np.logical_not(segMap)

                        # save the cropped WSI image of the extracted region
                        if saveCroppedWSIimg:
                            logger.info('Saving cropped segmented WSI image...')
                            saveImage(img_WSI, fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight))

                        # correct foreground segmentation including all touching vein prediction instances
                        bigPatchResults[bgMap] = 4  # vein class assignment of bg
                        temp = bigPatchResults == 4
                        bgMap = np.logical_xor(clear_border(temp), temp)
                        segMap = np.logical_not(bgMap)
                        segMap = binary_fill_holes(segMap)
                        bgMap = np.logical_not(segMap)

                        # remove small fg components
                        temp, numberLabeledRegions = label(segMap, struc3)
                        if numberLabeledRegions > 1:
                            regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)
                            regionIDs = np.where(np.array([region.area for region in regionprops(temp)]) > regionMinPixels)[0] + 1  # keep only components above the minimum area (in pixels)
                            segMap = np.isin(temp, regionIDs)
                            bgMap = np.logical_not(segMap)

                        bigPatchResults[bgMap] = labelBG  # label 'labelBG' (purple) represents BG just for visualization purposes

                        logger.info('Saving prediction and background overlay results...')
                        savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlay.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)

                        if saveWSIandPredNumpy:
                            logger.info('Saving numpy img...')
                            np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultWSI.npy', img_WSI)

                        logger.info('Start postprocessing...')
                        # remove border class
                        bigPatchResults[bigPatchResults == 7] = 0
                        # delete BG to reduce postprocessing overhead
                        bigPatchResults[bgMap] = 0

                        ################# HOLE FILLING ################
                        bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
                        bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
                        temp = binary_fill_holes(bigPatchResults == 3)  # tuft
                        bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
                        bigPatchResults[temp] = 3  # tuft
                        temp = binary_fill_holes(bigPatchResults == 6)  # artery lumen
                        bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full artery
                        bigPatchResults[temp] = 6  # artery lumen

                        ###### REMOVING TOO SMALL CONNECTED REGIONS ######
                        temp, _ = label(bigPatchResults == 1)
                        finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0

                        ############ PERFORM TUBULE DILATION ############
                        finalResults_Instance, numberTubuli = label(finalResults_Instance)  # dtype: int32
                        finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)
                        if numberTubuli < 65500:
                            finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1)  # RESULT TYPE: UINT16
                        else:
                            finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1), np.int32)

                        temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))
                        finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
                        temp, _ = label(bigPatchResults == 3)
                        finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance == 2)] = 3
                        temp, _ = label(bigPatchResults == 4)
                        finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
                        temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))
                        finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
                        temp, _ = label(bigPatchResults == 6)
                        finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance == 5)] = 6

                        finalResults_Instance = finalResults_Instance * segMap

                        logger.info('Done - Save final instance overlay results...')
                        savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)

                        logger.info('Done - Save final non-instance overlay results...')
                        finalResults = finalResults_Instance.copy()
                        finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1  # map all tubule instance IDs back to the tubule class
                        savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)

                        finalResults_Instance[bgMap] = labelBG

                        if saveWSIandPredNumpy:
                            logger.info('Saving numpy final instance prediction results...')
                            np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_' + str(regionID) + '_finalInstancePrediction.npy', finalResults_Instance)

                        logger.info('####################')
                        break

except:
    logger.exception('! Exception !')
    raise

log.info('%%%% Ended regularly ! %%%%')
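# Label convention of the saved prediction maps, as used by the post-processing above:
#   1 = tubule (instance IDs >= tubuliInstanceID_StartsWith, i.e. >= 10, in the instance map),
#   2 = glomerulus, 3 = glomerular tuft, 4 = vein, 5 = artery, 6 = arterial lumen,
#   7 = border class (discarded during post-processing), 8 (labelBG) = background (visualization only).
#
# Typical usage (illustrative): set WSIrootFolder, modelpath, model_FG_path and resultsPath at the
# top of this file, then run the script in an environment providing openslide, torch, opencv,
# scipy, scikit-image, matplotlib and tqdm.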