segment_WSI.py
    # This script recursively performs the automated segmentation of WSIs by applying the tissue segmentation and structure segmentation CNNs to all WSIs contained in a specified folder
    
    import numpy as np
    import os
    import sys
    import cv2
    import torch
    import math
    import logging as log
    from tqdm import tqdm, trange
    import re
    
    from openslide.lowlevel import OpenSlideError
    import openslide as osl
    from PIL import Image
    import matplotlib.pyplot as plt
    from scipy.ndimage import label, zoom  # the scipy.ndimage.measurements/.morphology submodule paths are deprecated
    from scipy.ndimage import binary_dilation, binary_closing, binary_fill_holes, binary_erosion, binary_opening
    import scipy.ndimage
    import scipy as sp
    from skimage.transform import rescale, resize
    from skimage.measure import regionprops
    from skimage.segmentation import flood, flood_fill
    from skimage.color import rgb2gray
    from skimage.segmentation import clear_border
    from skimage import filters
    from skimage.morphology import remove_small_objects
    
    from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
    from model import Custom
    
    
    
    
    WSIrootFolder = 'SPECIFIED WSI FOLDER'
    modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'
    resultsPath = 'RESULTS PATH'
    
    if not os.path.exists(resultsPath):
        os.makedirs(resultsPath)
    
    model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'
    
    patchSegmSize = 516
    patchImgSize = 640
    patchLengthUM = 174.
    
    patchImgSize_FG = 512
    patchLengthUM_FG = 2500
    regionMinSizeUM = 3E5
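    # derived spacings (computed below): the structure network runs at 174/516 ≈ 0.34 µm/px on a 640 px input
    # window with a 516 px central output, the tissue network at 2500/512 ≈ 4.9 µm/px; tissue regions smaller
    # than regionMinSizeUM (3E5 µm²) are discarded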
    
    alpha = 0.3
    strideProportion = 0.5
    figHeight = 15
    minibatchSize = 2
    minibatchSize_FG = 1
    useAllGPUs = False
    GPUno = 0
    device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
    
    tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10
    
    ftChannelsOutput = 8
    applyTestTimeAugmentation = True
    centerWeighting = False
    centerWeight = 3.
    
    saveWSICoarseForegroundSegmResults = True
    # saveCroppedForegroundSegmResults = False
    saveCroppedWSIimg = True
    saveWSIandPredNumpy = True
    
    TUBULI_MIN_SIZE = 400
    GLOM_MIN_SIZE = 1500
    TUFT_MIN_SIZE = 500
    VEIN_MIN_SIZE = 3000
    ARTERY_MIN_SIZE = 400
    LUMEN_MIN_SIZE = 20
    
    labelBG = 8
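    # class indices as used in the postprocessing below: 0 background, 1 tubulus, 2 glomerulus, 3 tuft, 4 vein,
    # 5 artery, 6 artery lumen, 7 border (removed); labelBG (8) marks non-tissue background for visualization only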
    
    # LOAD STRUCTURE SEGMENTATION MODEL
    model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
    model.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))
    model.train(False)
    model.eval()
    
    if useAllGPUs:
        model = torch.nn.DataParallel(model)  # multi-GPUs
    model = model.to(device)
    
    # LOAD TISSUE SEGMENTATION MODEL
    model_FG = torch.load(model_FG_path)
    model_FG = model_FG.to(device)
    model_FG.eval()
    
    segmentationPatchStride = int(patchSegmSize * strideProportion)
    targetSpacing = patchLengthUM / patchSegmSize
    
    targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG
    
    shiftMinUM = (patchImgSize - patchSegmSize) // 2 * targetSpacing + 2  # + 2 µm just as a safety margin
    shiftMaxUM = ((patchImgSize - patchSegmSize) // 2 + patchSegmSize * strideProportion) * targetSpacing
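    # context margins for the region bounding box (in µm): shiftMinUM covers the (patchImgSize - patchSegmSize) / 2
    # = 62 px context border of a patch (≈ 23 µm), shiftMaxUM additionally covers one segmentation stride (≈ 108 µm)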
    
    # Set up logger
    log.basicConfig(
        level=log.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y/%m/%d %I:%M:%S %p',
        handlers=[
            log.FileHandler(resultsPath + '/LOGS.log', 'w'),
            log.StreamHandler(sys.stdout)
        ])
    logger = log.getLogger()
    
    struc3 = generate_ball(1)
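    # structuring element from utils (presumably a radius-1 ball/disc) used as the connectivity for labeling below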
    
    try:
        # walk through the directory and its subdirectories recursively
        for dirName, subdirList, fileList in os.walk(WSIrootFolder):
    
            # filter WSIs in current directory
            fileListWSI = sorted([fname for fname in fileList if (fname.endswith('.svs') or fname.endswith('.ndpi') or fname.endswith('.scn')) and 'PAS' in fname])
    
            if len(fileListWSI) != 0:
                logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)
    
                resultsDir = resultsPath + dirName[len(WSIrootFolder):]
                resultsDirNPYfiles = resultsDir + '/npyFiles'
                if not os.path.exists(resultsDirNPYfiles):
                    os.makedirs(resultsDirNPYfiles)
    
                # traverse through all found WSIs
                for no, fname in enumerate(fileListWSI):
    
                    # Extract/print relevant parameters
                    try:
                        slide = osl.OpenSlide(os.path.join(dirName, fname))
                        logger.info(str(no + 1) + ':  WSI:\t' + fname)
                        spacings = np.array([float(slide.properties['openslide.mpp-x']), float(slide.properties['openslide.mpp-y'])])
                    except Exception:
                        logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
                        continue
                    levelDims = np.array(slide.level_dimensions)
                    amountLevels = len(levelDims)
                    levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)
    
                    logger.info('Spacings: ' + str(spacings))
                    logger.info('Level Dimensions: ' + str(levelDims))
                    logger.info('Level Downsamples: '+str(levelDownsamples))
    
                    suffixCut = -5 if fname.split('.')[-1] == 'ndpi' else -4
    
                    # extract the WSI level that is closest to the target spacing of the tissue segmentation network (increasing efficiency instead of simply taking full resolution, finest level 0) 
                    spacingFactorX = spacings[0] / targetSpacing_FG
                    spacingFactorY = spacings[1] / targetSpacing_FG
                    x_scaled = round(levelDims[0][0] * spacingFactorX)
                    y_scaled = round(levelDims[0][1] * spacingFactorY)
    
                    usedLevel = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1)).max()
    
                    # resample to target spacing
                    logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))
                    spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
                    spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
                    downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
                    downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG
                    imgWSI = np.array(slide.read_region(location=np.array([0, 0]), level=usedLevel, size=levelDims[usedLevel]))[:, :, :3]
    
                    d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))
                    d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))
    
                    imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR) #dtype: uint8, size: d2 x d1 x 3
                    imgWSIzeropadded = np.zeros(shape=(d2+patchImgSize_FG-1, d1+patchImgSize_FG-1, 3), dtype=np.float32)
                    imgWSIzeropadded[:d2,:d1,:] = imgWSI
    
                    # tessellate the resampled image
                    smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG) # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
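                    # patchify (from utils) is expected to return tiles indexed as [row, col, 0, y, x, channel];
                    # with step == patch size the tiles do not overlap, and the zero padding above ensures the
                    # resampled image is fully covered by whole tiles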
    
                    tileDataset = []
                    # with tqdm(total=smallOverlappingPatches.shape[0] * smallOverlappingPatches.shape[1]) as pbar:
                    for i in range(smallOverlappingPatches.shape[0]):
                        for j in range(smallOverlappingPatches.shape[1]):
                            tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
                                # pbar.update(1)
    
                    img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0],imgWSIzeropadded.shape[1]), dtype=bool)
    
                    # create dataloader for concurrent tile processing and prediction computation
                    dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
                    with torch.no_grad():
                        for i, data in enumerate(dataloader, 0):
                            imgBatch = data['data'].permute(0, 3, 1, 2).to(device)
    
                            prediction = model_FG(imgBatch)  # prediction should have shape (1,2,512,512)
                            prediction = (prediction[:,1,:,:] > prediction[:,0,:,:]).to("cpu").numpy()
    
                            for n, d in zip(data['name'], prediction):
                                x = int(n.split('-')[0])
                                y = int(n.split('-')[1])
                                img_mask[x * patchImgSize_FG: (x+1) * patchImgSize_FG, y * patchImgSize_FG: (y+1) * patchImgSize_FG] = d
    
                    img_mask = img_mask[:d2,:d1]
    
                    # postprocessing
                    img_mask = binary_fill_holes(img_mask)
    
                    # remove connected regions if too small
                    regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)
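                    # e.g. at targetSpacing_FG ≈ 4.88 µm/px, 3E5 µm² corresponds to roughly 12.6k pixels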
                    img_mask, _ = label(img_mask)
                    labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))
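                    # uint8 keeps the full-resolution label map below small; this only works while region IDs fit into 0..255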
                    if numberFGRegions < 256:
                        labeledRegions = np.asarray(labeledRegions, np.uint8)
    
                    if saveWSICoarseForegroundSegmResults:
                        logger.info('Saving WSI-level coarse FG segmentation results...')
                        savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1]/labeledRegions.shape[0]*figHeight, figHeight), fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_0_fgSeg.png')
    
                    labeledRegions = cv2.resize(labeledRegions, dsize=(levelDims[0][0],levelDims[0][1]), interpolation = cv2.INTER_NEAREST) # FG RESULTS ON WSI-RESOLUTION, UINT8, REGION IDs
    
                    logger.info('In total -> '+str(numberFGRegions)+' <- regions on WSI detected!')
    
                    # process all detected tissue regions separately 
                    for regionID in range(1, numberFGRegions+1):
                        logger.info('#######\n Extract foreground region ' + str(regionID) + '...')
                        detectedRegion = labeledRegions == regionID
    
                        # compute the bounding box and how much to enlarge it to include wider context (especially for patchify)
                        temp = np.where(detectedRegion == 1)
                        bbox = np.array([np.min(temp[0]), np.min(temp[1]), np.max(temp[0]), np.max(temp[1])])
    
                        shiftMin = round(shiftMinUM / spacings[0])
                        shiftMax = round(shiftMaxUM / spacings[0])
    
                        # enlarge bounding box due to wider context consideration
                        bbox[0] = max(bbox[0] - shiftMin, 0)
                        bbox[1] = max(bbox[1] - shiftMin, 0)
                        bbox[2] = min(bbox[2] + shiftMax, detectedRegion.shape[0] - 1) + 1
                        bbox[3] = min(bbox[3] + shiftMax, detectedRegion.shape[1] - 1) + 1
    
                        logger.info('Extract high res patch and segm map...')
                        try:
                            img_WSI = np.asarray(np.array(slide.read_region(location=np.array([bbox[1], bbox[0]]), level=0, size=np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]])))[:, :, :3], np.uint8)
                        except OpenSlideError:
                            logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
                            continue
                        detectedRegion = detectedRegion[bbox[0]:bbox[2], bbox[1]:bbox[3]]
    
                        # extract image and resample into target spacing of the structure segmentation network
                        downsamplingFactor = spacings[0] / targetSpacing  # Rescaling would be very slow using 'rescale' method!
                        logger.info('Utilized spacing of slide: '+str(spacings[0])+', Resample both patches using factor: ' + str(downsamplingFactor))
                        segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)
                        img_WSI = cv2.resize(img_WSI, dsize=tuple(np.flip(segMap.shape)), interpolation=cv2.INTER_LINEAR)
                        # segMap = np.asarray(np.round(rescale(segMap, downsamplingFactor, order=0, preserve_range=True, multichannel=False)), np.bool)
                        assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
                        logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))
    
                        if np.min(segMap.shape) < patchImgSize:
                            logger.info('Detected region smaller than window, thus skipped...')
                            continue
    
                        ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
                        logger.info('Start segmentation process...')
    
                        # preprocess img
                        img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)
    
                        # tessellate the image and tissue prediction results
                        smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride) # CARE: IMAGE DATA AT THE RIGHT AND BOTTOM BORDERS IS LOST !!!
                        smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)
    
                        tileDataset = []
                        for i in range(smallOverlappingPatches.shape[0]):
                            for j in range(smallOverlappingPatches.shape[1]):
                                if smallOverlappingPatches_FG[i,j,:,:].any():
                                    tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
    
                        # calculate the stitched segmentation extent, since patchify cuts off the last patch if it does not fit exactly into the window
                        startX = (patchImgSize - patchSegmSize) // 2; startY = startX
                        endX = segmentationPatchStride * (smallOverlappingPatches.shape[0]-1) + patchSegmSize + startX
                        endY = segmentationPatchStride * (smallOverlappingPatches.shape[1]-1) + patchSegmSize + startY
    
                        bigPatchResults = torch.zeros(device="cpu", size=(ftChannelsOutput, endX - startX, endY - startY))
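                        # buffer for the summed class scores over the valid central output area of all patches; with
                        # stride = patchSegmSize / 2, overlapping patches simply overwrite each other below (no averaging)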
    
                        # create dataloader for concurrent prediction computation
                        dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
                        with torch.no_grad():
                            for i, data in enumerate(dataloader, 0):
                                imgBatch = data['data'].permute(0, 3, 1, 2).to(device)
    
                                prediction = torch.softmax(model(imgBatch), dim=1)  # shape: (minibatchSize, 8, 516, 516)
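                                # test-time augmentation: predictions for the vertically, vertically+horizontally and
                                # horizontally flipped patch are flipped back and summed (an un-normalized average,
                                # which leaves the later argmax unchanged)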
                                if applyTestTimeAugmentation:
                                    imgBatch = imgBatch.flip(2)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(2)
    
                                    imgBatch = imgBatch.flip(3)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)
    
                                    imgBatch = imgBatch.flip(2)
                                    prediction += torch.softmax(model(imgBatch), 1).flip(3)
    
                                if centerWeighting:
                                    prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight
    
                                prediction = prediction.to("cpu")
    
                                for n, d in zip(data['name'], prediction):
                                    x = int(n.split('-')[0])
                                    y = int(n.split('-')[1])
                                    bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] = d
    
                            bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy() # class label map, shape: (endX - startX, endY - startY)
    
                        logger.info('Predictions generated. Final shape: '+str(bigPatchResults.shape))
    
                        # Context margin + border patches not fully inside img removed
                        img_WSI = img_WSI[startX:endX, startY:endY, :]
                        segMap = segMap[startX:endX, startY:endY]
                        bgMap = np.logical_not(segMap)
    
                        # Save cropped foreground segmentation result as overlay
                        if saveCroppedWSIimg:
                            logger.info('Saving cropped segmented WSI image...')
                            saveImage(img_WSI, fullResultPath=resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight))
    
                        # correct the foreground segmentation: background is temporarily assigned the vein class so that
                        # vein predictions touching the background/border are absorbed into the background
                        bigPatchResults[bgMap] = 4  # assign the vein class to background pixels
                        temp = bigPatchResults == 4
                        bgMap = np.logical_xor(clear_border(temp), temp)
                        segMap = np.logical_not(bgMap)
                        segMap = binary_fill_holes(segMap)
                        bgMap = np.logical_not(segMap)
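                        # segMap/bgMap now follow the network's own vein/background boundary (with holes filled) instead
                        # of the coarse tissue mask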
    
                        # remove small fg components
                        temp, numberLabeledRegions = label(segMap, struc3)
                        if numberLabeledRegions > 1:
                            regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)
                            regionIDs = np.where(np.array([region.area for region in regionprops(temp)]) > regionMinPixels)[0] + 1
                            segMap = np.isin(temp, regionIDs)
                            bgMap = np.logical_not(segMap)
    
                        bigPatchResults[bgMap] = labelBG # color of label 'labelBG' => Purple represents BG just for visualization purposes
    
                        logger.info('Saving prediction and background overlay results...')
                        savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlay.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
    
                        if saveWSIandPredNumpy:
                            logger.info('Saving numpy img...')
                            np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_resultWSI.npy', img_WSI)
    
                        logger.info('Start postprocessing...')
    
                        # remove border class
                        bigPatchResults[bigPatchResults == 7] = 0
    
                        # Delete BG to reduce postprocessing overhead
                        bigPatchResults[bgMap] = 0
    
                        ################# HOLE FILLING ################
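                        # order matters: the tuft (resp. artery lumen) mask is saved before the enclosing glomerulus
                        # (resp. artery) is filled and is re-applied afterwards, so inner classes survive the outer fill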
                        bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
                        bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
                        temp = binary_fill_holes(bigPatchResults == 3)  # tuft
                        bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
                        bigPatchResults[temp] = 3  # tuft
                        temp = binary_fill_holes(bigPatchResults == 6)  # artery_lumen
                        bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full_artery
                        bigPatchResults[temp] = 6  # artery_lumen
    
                        ###### REMOVING TOO SMALL CONNECTED REGIONS ######
                        temp, _ = label(bigPatchResults == 1)
                        finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0
    
                        ############ PERFORM TUBULE DILATION ############
                        finalResults_Instance, numberTubuli = label(finalResults_Instance) #dtype: int32
                        finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)
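                        # cv2.dilate apparently does not accept int32 label images: up to ~65.5k tubulus instances fit
                        # into uint16, larger counts are dilated as float64 and cast back to int32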
                        if numberTubuli < 65500:
                            finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1) #RESULT TYPE: UINT16
                        else:
                            finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1), np.int32)
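                        # the remaining classes are painted from outer to inner with nested classes gated on their parent
                        # (tuft only inside glomeruli, lumen only inside arteries), each filtered by a minimum size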
    
                        temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))
                        finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
                        temp, _ = label(bigPatchResults == 3)
                        finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance==2)] = 3
                        temp, _ = label(bigPatchResults == 4)
                        finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
                        temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))
                        finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
                        temp, _ = label(bigPatchResults == 6)
                        finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance==5)] = 6
    
                        finalResults_Instance = finalResults_Instance * segMap
    
                        logger.info('Done - Save final instance overlay results...')
                        savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
    
                        logger.info('Done - Save final non-instance overlay results...')
                        finalResults = finalResults_Instance.copy()
                        finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1
                        savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + fname[:suffixCut] +'_'+str(regionID)+'_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1]/bigPatchResults.shape[0]*figHeight, figHeight), alpha=alpha)
    
                        finalResults_Instance[bgMap] = labelBG
    
                        if saveWSIandPredNumpy:
                            logger.info('Saving numpy final instance prediction results...')
                            np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_'+str(regionID)+'_finalInstancePrediction.npy', finalResults_Instance)
    
                    logger.info('####################')
    
    
    except:
        logger.exception('! Exception !')
        raise
    
    logger.info('%%%% Ended regularly! %%%%')