# Select Git revision
# DemoRoomTest.umap
# segment_WSI.py 22.34 KiB
# This file performs the automated segmentation of WSIs by applying the tissue segmentation CNN and the structure segmentation CNN to all WSIs contained in a specified folder and, recursively, in its subfolders.
import numpy as np
import os
import sys
import cv2
import torch
import math
import logging as log
from tqdm import tqdm, trange
import re
from openslide.lowlevel import OpenSlideError
import openslide as osl
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage.measurements import label
from scipy.ndimage import zoom
from scipy.ndimage.morphology import binary_dilation, binary_closing, binary_fill_holes, binary_erosion, binary_opening
import scipy.ndimage
import scipy as sp
from skimage.transform import rescale, resize
from skimage.measure import regionprops
from skimage.segmentation import flood, flood_fill
from skimage.color import rgb2gray
from skimage.segmentation import clear_border
from skimage import filters
from skimage.morphology import remove_small_objects
from utils import generate_ball, patchify, unpatchify, saveOverlayResults, savePredictionOverlayResults, savePredictionOverlayResults_Fast, saveImage
from model import Custom
# ============================ USER CONFIGURATION ============================
# Input/output locations (placeholders - replace before running).
WSIrootFolder = 'SPECIFIED WSI FOLDER'  # root folder searched recursively for WSIs
modelpath = 'STRUCTURE SEGMENTATION MODEL PATH'  # state dict of the structure segmentation network
resultsPath = 'RESULTS PATH'  # all outputs (images, .npy files, LOGS.log) are written below here
# exist_ok=True avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(resultsPath, exist_ok=True)
model_FG_path = 'TISSUE SEGMENTATION MODEL PATH'  # whole saved tissue (foreground) segmentation network

# Structure segmentation tiling: a patchImgSize x patchImgSize input tile yields a
# patchSegmSize x patchSegmSize prediction that spans patchLengthUM micrometers.
patchSegmSize = 516
patchImgSize = 640
patchLengthUM = 174.
# Tissue segmentation tiling: patchImgSize_FG px tiles spanning patchLengthUM_FG micrometers.
patchImgSize_FG = 512
patchLengthUM_FG = 2500
regionMinSizeUM = 3E5  # minimal tissue region area in square micrometers
alpha = 0.3  # alpha value handed to the overlay-saving helpers
strideProportion = 0.5  # tile stride as a fraction of patchSegmSize
figHeight = 15  # height of saved figures; width follows the image aspect ratio
minibatchSize = 2  # DataLoader batch size for the structure network
minibatchSize_FG = 1  # DataLoader batch size for the tissue network
useAllGPUs = False  # wrap the structure network in torch.nn.DataParallel
GPUno = 0
device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
tubuliInstanceID_StartsWith = 10  # => tubuli instances start with id 10
ftChannelsOutput = 8  # number of output channels (classes) of the structure network
applyTestTimeAugmentation = True  # sum predictions over flipped copies of each tile
centerWeighting = False  # up-weight the central part of each predicted tile
centerWeight = 3.
saveWSICoarseForegroundSegmResults = True
saveCroppedWSIimg = True
saveWSIandPredNumpy = True
# Minimal connected-component sizes, in pixels at the structure network's spacing.
TUBULI_MIN_SIZE = 400
GLOM_MIN_SIZE = 1500
TUFT_MIN_SIZE = 500
VEIN_MIN_SIZE = 3000
ARTERY_MIN_SIZE = 400
LUMEN_MIN_SIZE = 20
# Class labels used during postprocessing: 1 tubuli, 2 glomerulus, 3 glomerular tuft,
# 4 vein, 5 full artery, 6 artery lumen, 7 border; labelBG marks background in saved results.
labelBG = 8
# LOAD STRUCTURE SEGMENTATION MODEL
# The state dict is mapped onto the CPU first; the model is moved to `device` afterwards.
model = Custom(input_ch=3, output_ch=ftChannelsOutput, modelDim=2)
model.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))
model.eval()  # inference mode (same effect as model.train(False))
if useAllGPUs:
    model = torch.nn.DataParallel(model)  # multi-GPUs
model = model.to(device)
# LOAD TISSUE SEGMENTATION MODEL
# NOTE(review): this loads a whole pickled model object; torch.load unpickles arbitrary
# objects, so only use trusted model files.
# map_location=device lets a GPU-saved model load when `device` fell back to the CPU,
# consistent with how the structure model above is loaded.
model_FG = torch.load(model_FG_path, map_location=device)
model_FG = model_FG.to(device)
model_FG.eval()
# Derived tiling geometry.
# Stride (px) between neighbouring structure-segmentation tiles.
segmentationPatchStride = int(strideProportion * patchSegmSize)
# Pixel spacings (micrometers per pixel) at which the two networks operate.
targetSpacing = patchLengthUM / patchSegmSize
targetSpacing_FG = patchLengthUM_FG / patchImgSize_FG
# Context margins (micrometers) used to enlarge each tissue bounding box:
# the input/output tile border (+ 2 just for sufficient margin reasons) on the low side,
# and that border plus one tile stride on the high side.
shiftMinUM = targetSpacing * ((patchImgSize - patchSegmSize) // 2) + 2
shiftMaxUM = targetSpacing * (patchSegmSize * strideProportion + (patchImgSize - patchSegmSize) // 2)
# Set up logger: INFO-and-above records, prefixed with a timestamp, go both to
# <resultsPath>/LOGS.log (truncated on every run) and to stdout.
log.basicConfig(
    handlers=[
        log.FileHandler(filename=resultsPath + '/LOGS.log', mode='w'),
        log.StreamHandler(stream=sys.stdout),
    ],
    format='%(asctime)s %(message)s',
    datefmt='%Y/%m/%d %I:%M:%S %p',
    level=log.INFO,
)
logger = log.getLogger()  # the root logger configured above
# Structuring element passed to label() when splitting the tissue mask into components.
struc3 = generate_ball(1)
try:
    # Walk through the WSI root folder and all of its subdirectories recursively.
    for dirName, subdirList, fileList in os.walk(WSIrootFolder):
        # Keep only PAS-stained WSIs in one of the supported slide formats.
        fileListWSI = sorted(fname for fname in fileList
                             if fname.endswith(('.svs', '.ndpi', '.scn')) and 'PAS' in fname)
        if len(fileListWSI) == 0:
            continue
        logger.info(str(len(fileListWSI)) + ' WSIs to be analyzed in directory: ' + dirName)
        # Results mirror the input folder structure below resultsPath.
        resultsDir = resultsPath + dirName[len(WSIrootFolder):]
        resultsDirNPYfiles = resultsDir + '/npyFiles'
        os.makedirs(resultsDirNPYfiles, exist_ok=True)

        # Traverse all WSIs found in the current directory.
        for no, fname in enumerate(fileListWSI):
            # Open the slide and read its pixel spacing (OpenSlide mpp = microns per pixel).
            slide = None
            try:
                slide = osl.OpenSlide(os.path.join(dirName, fname))
                logger.info(str(no + 1) + ': WSI:\t' + fname)
                spacings = np.array([float(slide.properties['openslide.mpp-x']),
                                     float(slide.properties['openslide.mpp-y'])])
            except Exception:
                # Unreadable slides / missing spacings are skipped. Unlike a bare `except:`,
                # KeyboardInterrupt and SystemExit still stop the run.
                if slide is not None:
                    slide.close()
                logger.info('Slide {} or its spacings not readable, slide skipped!'.format(fname))
                continue

            # try/finally guarantees the slide handle is released, even if processing fails.
            try:
                levelDims = np.array(slide.level_dimensions)
                # np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
                levelDownsamples = np.asarray(np.round(np.array(slide.level_downsamples)), int)
                logger.info('Spacings: ' + str(spacings))
                logger.info('Level Dimensions: ' + str(levelDims))
                logger.info('Level Downsamples: ' + str(levelDownsamples))
                suffixCut = -5 if fname.split('.')[-1] == 'ndpi' else -4

                ############ COARSE TISSUE (FOREGROUND) SEGMENTATION ############
                # Use the coarsest pyramid level that is still larger than the image resampled to
                # the tissue network's spacing (cheaper than always reading full-resolution level 0).
                spacingFactorX = spacings[0] / targetSpacing_FG
                spacingFactorY = spacings[1] / targetSpacing_FG
                x_scaled = round(levelDims[0][0] * spacingFactorX)
                y_scaled = round(levelDims[0][1] * spacingFactorY)
                candidateLevels = np.argwhere(np.all(levelDims > [x_scaled, y_scaled], 1))
                # Fall back to level 0 when no level is large enough, instead of crashing on
                # max() of an empty array.
                usedLevel = int(candidateLevels.max()) if candidateLevels.size > 0 else 0
                # resample to target spacing
                logger.info('Image size resampled to FG spacing would be {}, {}, thus level {} with resolution {} chosen as resampling point!'.format(x_scaled, y_scaled, usedLevel, levelDims[usedLevel]))
                spacingOnUsedLevelX = spacings[0] * levelDownsamples[usedLevel]
                spacingOnUsedLevelY = spacings[1] * levelDownsamples[usedLevel]
                downsamplingFactorX = spacingOnUsedLevelX / targetSpacing_FG
                downsamplingFactorY = spacingOnUsedLevelY / targetSpacing_FG
                imgWSI = np.array(slide.read_region(location=np.array([0, 0]), level=usedLevel, size=levelDims[usedLevel]))[:, :, :3]
                d1 = int(round(imgWSI.shape[1] * downsamplingFactorX))
                d2 = int(round(imgWSI.shape[0] * downsamplingFactorY))
                imgWSI = cv2.resize(imgWSI, dsize=(d1, d2), interpolation=cv2.INTER_LINEAR)  # dtype: uint8, size: d2 x d1 x 3
                # Zero-pad right/bottom by one tile minus one pixel so the tiling covers the whole image.
                imgWSIzeropadded = np.zeros(shape=(d2 + patchImgSize_FG - 1, d1 + patchImgSize_FG - 1, 3), dtype=np.float32)
                imgWSIzeropadded[:d2, :d1, :] = imgWSI
                # Tesselate the padded image into non-overlapping tiles (patchify is a project
                # helper; its result is indexed [row, col, 0, ...] below).
                smallOverlappingPatches = patchify(imgWSIzeropadded, patch_size=(patchImgSize_FG, patchImgSize_FG, 3), step=patchImgSize_FG)
                tileDataset = []
                for i in range(smallOverlappingPatches.shape[0]):
                    for j in range(smallOverlappingPatches.shape[1]):
                        tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
                # np.bool was removed in NumPy 1.24; use the builtin bool.
                img_mask = np.zeros(shape=(imgWSIzeropadded.shape[0], imgWSIzeropadded.shape[1]), dtype=bool)
                # Create a dataloader for batched tile prediction.
                dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize_FG, shuffle=False)
                with torch.no_grad():
                    for data in dataloader:
                        imgBatch = data['data'].permute(0, 3, 1, 2).to(device)  # NHWC -> NCHW
                        prediction = model_FG(imgBatch)  # prediction should have shape (1,2,512,512)
                        # Foreground wherever channel 1 outscores channel 0.
                        prediction = (prediction[:, 1, :, :] > prediction[:, 0, :, :]).to("cpu").numpy()
                        for n, d in zip(data['name'], prediction):
                            x = int(n.split('-')[0])
                            y = int(n.split('-')[1])
                            img_mask[x * patchImgSize_FG: (x + 1) * patchImgSize_FG, y * patchImgSize_FG: (y + 1) * patchImgSize_FG] = d
                img_mask = img_mask[:d2, :d1]
                # Postprocessing: fill holes, then drop connected regions below regionMinSizeUM.
                img_mask = binary_fill_holes(img_mask)
                regionMinPixels = regionMinSizeUM / (targetSpacing_FG * targetSpacing_FG)
                img_mask, _ = label(img_mask)
                labeledRegions, numberFGRegions = label(remove_small_objects(img_mask, min_size=regionMinPixels))
                if numberFGRegions < 256:
                    labeledRegions = np.asarray(labeledRegions, np.uint8)
                if saveWSICoarseForegroundSegmResults:
                    logger.info('Saving WSI-level coarse FG segmentation results...')
                    savePredictionOverlayResults(imgWSI, labeledRegions, alpha=alpha, figSize=(labeledRegions.shape[1] / labeledRegions.shape[0] * figHeight, figHeight), fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_0_fgSeg.png')
                # FG results on WSI (level 0) resolution: one region ID per pixel.
                labeledRegions = cv2.resize(labeledRegions, dsize=(int(levelDims[0][0]), int(levelDims[0][1])), interpolation=cv2.INTER_NEAREST)
                logger.info('In total -> ' + str(numberFGRegions) + ' <- regions on WSI detected!')
                # Bounding slices of all regions in a single pass (entry regionID-1), instead of
                # scanning the full-resolution label map once per region.
                regionSlices = scipy.ndimage.find_objects(labeledRegions, max_label=numberFGRegions)

                # Process all detected tissue regions separately.
                for regionID in range(1, numberFGRegions + 1):
                    logger.info('#######\n Extract foreground region ' + str(regionID) + '...')
                    regionSlice = regionSlices[regionID - 1]
                    if regionSlice is None:
                        logger.info('Region ' + str(regionID) + ' not present at full resolution, thus skipped...')
                        continue
                    # Tight bounding box [rowMin, colMin, rowMax, colMax] (inclusive) of the region.
                    bbox = np.array([regionSlice[0].start, regionSlice[1].start,
                                     regionSlice[0].stop - 1, regionSlice[1].stop - 1])
                    shiftMin = round(shiftMinUM / spacings[0])
                    shiftMax = round(shiftMaxUM / spacings[0])
                    # Enlarge the bbox for wider network context; bbox[2]/bbox[3] become exclusive ends.
                    bbox[0] = max(bbox[0] - shiftMin, 0)
                    bbox[1] = max(bbox[1] - shiftMin, 0)
                    bbox[2] = min(bbox[2] + shiftMax, labeledRegions.shape[0] - 1) + 1
                    bbox[3] = min(bbox[3] + shiftMax, labeledRegions.shape[1] - 1) + 1
                    logger.info('Extract high res patch and segm map...')
                    try:
                        img_WSI = np.asarray(np.array(slide.read_region(location=np.array([bbox[1], bbox[0]]), level=0, size=np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]])))[:, :, :3], np.uint8)
                    except OpenSlideError:
                        logger.info('#################################### FILE CORRUPTED - IGNORED ####################################')
                        continue
                    detectedRegion = labeledRegions[bbox[0]:bbox[2], bbox[1]:bbox[3]] == regionID
                    # Resample image and mask into the structure network's target spacing.
                    downsamplingFactor = spacings[0] / targetSpacing  # Rescaling would be very slow using 'rescale' method!
                    logger.info('Utilized spacing of slide: ' + str(spacings[0]) + ', Resample both patches using factor: ' + str(downsamplingFactor))
                    segMap = np.asarray(zoom(detectedRegion, downsamplingFactor, order=0), bool)
                    img_WSI = cv2.resize(img_WSI, dsize=(segMap.shape[1], segMap.shape[0]), interpolation=cv2.INTER_LINEAR)
                    assert img_WSI.shape[:2] == segMap.shape, "Error: Upsampling via zoom/resize led to unequal resolutions..."
                    logger.info('Done - size of extracted resampled high res patch: ' + str(img_WSI.shape))
                    if np.min(segMap.shape) < patchImgSize:
                        logger.info('Detected region smaller than window, thus skipped...')
                        continue

                    ##### PREPROCESSING DONE - NOW: NETWORK SEGMENTATION PART #####
                    logger.info('Start segmentation process...')
                    # Scale intensities from [0, 255] to [-1, 1].
                    img_WSI_prep = np.array((img_WSI / 255. - 0.5) / 0.5, np.float32)
                    # Tesselate image and tissue mask into overlapping tiles; only tiles touching
                    # tissue are predicted. Data at the right/bottom border that does not fill a
                    # whole tile is dropped (cropped away below).
                    smallOverlappingPatches = patchify(img_WSI_prep.copy(), patch_size=(patchImgSize, patchImgSize, 3), step=segmentationPatchStride)
                    smallOverlappingPatches_FG = patchify(segMap.copy(), patch_size=(patchImgSize, patchImgSize), step=segmentationPatchStride)
                    tileDataset = []
                    for i in range(smallOverlappingPatches.shape[0]):
                        for j in range(smallOverlappingPatches.shape[1]):
                            if smallOverlappingPatches_FG[i, j, :, :].any():
                                tileDataset.append({'name': '{}-{}'.format(i, j), 'data': torch.from_numpy(smallOverlappingPatches[i, j, 0, :, :, :])})
                    # Extent of the stitched prediction: each tile predicts only its central
                    # patchSegmSize window, offset by (patchImgSize - patchSegmSize) // 2.
                    startX = (patchImgSize - patchSegmSize) // 2
                    startY = startX
                    endX = segmentationPatchStride * (smallOverlappingPatches.shape[0] - 1) + patchSegmSize + startX
                    endY = segmentationPatchStride * (smallOverlappingPatches.shape[1] - 1) + patchSegmSize + startY
                    bigPatchResults = torch.zeros(device="cpu", size=(ftChannelsOutput, endX - startX, endY - startY))
                    dataloader = torch.utils.data.DataLoader(tileDataset, batch_size=minibatchSize, shuffle=False)
                    with torch.no_grad():
                        for data in dataloader:
                            imgBatch = data['data'].permute(0, 3, 1, 2).to(device)  # NHWC -> NCHW
                            prediction = torch.softmax(model(imgBatch), dim=1)  # shape: (minibatchSize, 8, 516, 516)
                            if applyTestTimeAugmentation:
                                # Flip along dim 2, then both dims, then dim 3; undo each flip on
                                # the output and sum the softmax maps.
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(2)
                                imgBatch = imgBatch.flip(3)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3)
                            if centerWeighting:
                                prediction[:, :, patchSegmSize // 4: patchSegmSize // 4 * 3, patchSegmSize // 4: patchSegmSize // 4 * 3] *= centerWeight
                            prediction = prediction.to("cpu")
                            for n, d in zip(data['name'], prediction):
                                x = int(n.split('-')[0])
                                y = int(n.split('-')[1])
                                # Accumulate (not overwrite) so overlapping tiles are blended; otherwise
                                # the last tile silently wins and centerWeighting has no effect.
                                bigPatchResults[:, x * segmentationPatchStride: x * segmentationPatchStride + patchSegmSize, y * segmentationPatchStride: y * segmentationPatchStride + patchSegmSize] += d
                    bigPatchResults = torch.argmax(bigPatchResults, 0).byte().numpy()  # per-pixel class label map
                    logger.info('Predictions generated. Final shape: ' + str(bigPatchResults.shape))
                    # Context margin + border patches not fully inside img removed
                    img_WSI = img_WSI[startX:endX, startY:endY, :]
                    segMap = segMap[startX:endX, startY:endY]
                    bgMap = np.logical_not(segMap)
                    if saveCroppedWSIimg:
                        logger.info('Saving cropped segmented WSI image...')
                        saveImage(img_WSI, fullResultPath=resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_fgWSI_({}_{}_{}).png'.format(bbox[0], bbox[1], spacings[0]), figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight))
                    # Correct the foreground segmentation: background is assigned the vein class, and
                    # every vein-class component touching the image border is treated as background.
                    bigPatchResults[bgMap] = 4
                    temp = bigPatchResults == 4
                    bgMap = np.logical_xor(clear_border(temp), temp)
                    segMap = np.logical_not(bgMap)
                    segMap = binary_fill_holes(segMap)
                    bgMap = np.logical_not(segMap)
                    # Remove foreground components smaller than regionMinSizeUM.
                    temp, numberLabeledRegions = label(segMap, struc3)
                    if numberLabeledRegions > 1:
                        # regionprops areas are in pixels, so compare against the pixel threshold
                        # (previously compared with the square-micrometer value by mistake).
                        regionMinPixels = regionMinSizeUM / (targetSpacing * targetSpacing)
                        regionIDs = np.where(np.array([region.area for region in regionprops(temp)]) > regionMinPixels)[0] + 1
                        segMap = np.isin(temp, regionIDs)
                        bgMap = np.logical_not(segMap)
                    bigPatchResults[bgMap] = labelBG  # color of label 'labelBG' => Purple represents BG just for visualization purposes
                    logger.info('Saving prediction and background overlay results...')
                    savePredictionOverlayResults(img_WSI, bigPatchResults, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlay.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)
                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy img...')
                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultWSI.npy', img_WSI)

                    logger.info('Start postprocessing...')
                    # remove border class
                    bigPatchResults[bigPatchResults == 7] = 0
                    # Delete BG to reduce postprocessing overhead
                    bigPatchResults[bgMap] = 0
                    ################# HOLE FILLING ################
                    bigPatchResults[binary_fill_holes(bigPatchResults == 1)] = 1  # tubuli
                    bigPatchResults[binary_fill_holes(bigPatchResults == 4)] = 4  # veins
                    temp = binary_fill_holes(bigPatchResults == 3)  # tuft
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 3, bigPatchResults == 2))] = 2  # glom
                    bigPatchResults[temp] = 3  # tuft
                    temp = binary_fill_holes(bigPatchResults == 6)  # artery_lumen
                    bigPatchResults[binary_fill_holes(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))] = 5  # full_artery
                    bigPatchResults[temp] = 6  # artery_lumen
                    ###### REMOVING TOO SMALL CONNECTED REGIONS ######
                    temp, _ = label(bigPatchResults == 1)
                    finalResults_Instance = remove_small_objects(temp, min_size=TUBULI_MIN_SIZE) > 0
                    ############ PERFORM TUBULE DILATION ############
                    finalResults_Instance, numberTubuli = label(finalResults_Instance)  # dtype: int32
                    # Shift tubule instance IDs so they start at tubuliInstanceID_StartsWith.
                    finalResults_Instance[finalResults_Instance > 0] += (tubuliInstanceID_StartsWith - 1)
                    if numberTubuli < 65500:
                        finalResults_Instance = cv2.dilate(np.asarray(finalResults_Instance, np.uint16), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1)  # RESULT TYPE: UINT16
                    else:
                        finalResults_Instance = np.asarray(cv2.dilate(np.asarray(finalResults_Instance, np.float64), kernel=np.asarray(generate_ball(2), np.uint8), iterations=1), np.int32)
                    temp, _ = label(np.logical_or(bigPatchResults == 2, bigPatchResults == 3))
                    finalResults_Instance[remove_small_objects(temp, min_size=GLOM_MIN_SIZE) > 0] = 2
                    temp, _ = label(bigPatchResults == 3)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=TUFT_MIN_SIZE) > 0, finalResults_Instance == 2)] = 3
                    temp, _ = label(bigPatchResults == 4)
                    finalResults_Instance[remove_small_objects(temp, min_size=VEIN_MIN_SIZE) > 0] = 4
                    temp, _ = label(np.logical_or(bigPatchResults == 5, bigPatchResults == 6))
                    finalResults_Instance[remove_small_objects(temp, min_size=ARTERY_MIN_SIZE) > 0] = 5
                    temp, _ = label(bigPatchResults == 6)
                    finalResults_Instance[np.logical_and(remove_small_objects(temp, min_size=LUMEN_MIN_SIZE) > 0, finalResults_Instance == 5)] = 6
                    finalResults_Instance = finalResults_Instance * segMap
                    logger.info('Done - Save final instance overlay results...')
                    savePredictionOverlayResults(img_WSI, finalResults_Instance, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlayFINALInstance.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)
                    logger.info('Done - Save final non-instance overlay results...')
                    finalResults = finalResults_Instance.copy()
                    # '>=' so the first tubule instance (ID == tubuliInstanceID_StartsWith) also maps to class 1.
                    finalResults[finalResults >= tubuliInstanceID_StartsWith] = 1
                    savePredictionOverlayResults(img_WSI, finalResults, resultsDir + '/' + fname[:suffixCut] + '_' + str(regionID) + '_resultOverlayFINAL.png', figSize=(bigPatchResults.shape[1] / bigPatchResults.shape[0] * figHeight, figHeight), alpha=alpha)
                    finalResults_Instance[bgMap] = labelBG
                    if saveWSIandPredNumpy:
                        logger.info('Saving numpy final instance prediction results...')
                        np.save(resultsDirNPYfiles + '/' + fname[:suffixCut] + '_' + str(regionID) + '_finalInstancePrediction.npy', finalResults_Instance)
                    logger.info('####################')
                # NOTE: the former stray `break` here stopped processing after the first item,
                # contradicting the purpose of segmenting all WSIs recursively; it was removed.
            finally:
                slide.close()
except Exception:
    logger.exception('! Exception !')
    raise
# Reached only if the processing loop above finished without re-raising an exception.
logger.info('%%%% Ended regularly ! %%%%')