# This file performs model training
import os
import shutil
import numpy as np
import logging as log
import time
import sys
# from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR
from tensorboardX import SummaryWriter
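# RAdam is the Rectified Adam optimizer (Liu et al., 2019, "On the Variance of the Adaptive
# Learning Rate and Beyond"), imported here from a local module of the same name.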
from RAdam import RAdam
from dataset import CustomDataSetRAM
from model import Custom
from utils import getCrossValSplits, parse_nvidia_smi, parse_RAM_info, countParam, getDiceScores, getDiceScoresSinglePair, getMeanDiceScores, convert_labelmap_to_rgb, saveFigureResults, printResults
from loss import DiceLoss
from lrScheduler import MyLRScheduler
from postprocessing import postprocessPredictionAndGT, extractInstanceChannels
from evaluation import ClassEvaluator
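# Generic_UNet is taken from the nnU-Net framework (Isensee et al.), vendored under nnUnet/.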
from nnUnet.generic_UNet import Generic_UNet
import warnings
warnings.filterwarnings("ignore")
#################################### General GPU settings ####################################
GPUno = 0
useAllAvailableGPU = True
device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
##################################### Save test results ######################################
saveFinalTestResults = True
############################### Apply Test Time Augmentation #################################
applyTestTimeAugmentation = True
##############################################################################################
# this method trains a network with the given specification
def train(model, setting, optimizer, scheduler, epochs, batchSize, logger, resultsPath, tbWriter, allClassEvaluators):
    model.to(device)
    if torch.cuda.device_count() > 1 and useAllAvailableGPU:
        logger.info('# {} GPUs utilized! #'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Reseed numpy in every DataLoader worker: workers are forked with the parent's RNG state,
    # so without this all workers would produce identical "random" numbers (known numpy issue).
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    # allocate and separately load train / val / test data sets
    dataset_Train = CustomDataSetRAM('train', logger)
    dataloader_Train = DataLoader(dataset=dataset_Train, batch_size=batchSize, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)

    if 'val' in setting:
        dataset_Val = CustomDataSetRAM('val', logger)
        dataloader_Val = DataLoader(dataset=dataset_Val, batch_size=batchSize, shuffle=False, num_workers=1, worker_init_fn=worker_init_fn)

    if 'test' in setting:
        dataset_Test = CustomDataSetRAM('test', logger)
        dataloader_Test = DataLoader(dataset=dataset_Test, batch_size=batchSize, shuffle=False, num_workers=1, worker_init_fn=worker_init_fn)

    logger.info('####### DATA LOADED - TRAINING STARTS... #######')

    # Use dice loss plus weighted cross-entropy loss; index 8 is ignored since it marks the area
    # outside the image that augmentation introduces (e.g. through image rotation).
    Dice_Loss = DiceLoss(ignore_index=8).to(device)
    CE_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 1., 1., 1., 1., 1., 1., 10.]), ignore_index=8).to(device)
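    # The per-batch objective below is the unweighted sum L = L_CE + L_Dice. The cross-entropy
    # weight vector up-weights the last of the 8 predicted classes by a factor of 10 (presumably
    # the border class, matching the border column dropped from the dice-score tables further down).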
    for epoch in range(epochs):
        model.train(True)

        epochCELoss = 0
        epochDiceLoss = 0
        epochLoss = 0

        np.random.seed()
        start = time.time()
        for batch in dataloader_Train:
            # get data and put onto device
            imgBatch, segBatch = batch
            imgBatch = imgBatch.to(device)
            segBatch = segBatch.to(device)

            optimizer.zero_grad()

            # forward image batch, compute loss and backprop
            prediction = model(imgBatch)

            CEloss = CE_Loss(prediction, segBatch)
            diceLoss = Dice_Loss(prediction, segBatch)
            loss = CEloss + diceLoss

            epochCELoss += CEloss.item()
            epochDiceLoss += diceLoss.item()
            epochLoss += loss.item()

            loss.backward()
            optimizer.step()

        epochTrainLoss = epochLoss / len(dataloader_Train)
        end = time.time()

        # log current training loss
        logger.info('[Epoch ' + str(epoch + 1) + '] Train-Loss: ' + str(round(epochTrainLoss, 5)) + ', DiceLoss: ' +
                    str(round(epochDiceLoss / len(dataloader_Train), 5)) + ', CELoss: ' + str(round(epochCELoss / len(dataloader_Train), 5)) + ' [took ' + str(round(end - start, 3)) + 's]')

        # use tensorboard for visualization of training progress
        tbWriter.add_scalars('Plot/train', {'loss': epochTrainLoss,
                                            'CEloss': epochCELoss / len(dataloader_Train),
                                            'DiceLoss': epochDiceLoss / len(dataloader_Train)}, epoch)

        # every 50th epoch, add a sample image, ground truth, and prediction to tensorboard
        # (images are mapped back from the [-1, 1] input range to [0, 255] for display)
        if epoch % 50 == 0:
            with torch.no_grad():
                tbWriter.add_image('Train/_img', torch.round((imgBatch[0, :, :, :] + 1.) / 2. * 255.0).byte(), epoch)
                tbWriter.add_image('Train/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                tbWriter.add_image('Train/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)

        # every 100th epoch, log GPU and RAM utilization
        if epoch % 100 == 0:
            logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_nvidia_smi(GPUno))
            logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_RAM_info())
        # if validation is included, compute dice scores on validation data
        if 'val' in setting:
            model.train(False)

            diceScores_Val = []

            start = time.time()
            for batch in dataloader_Val:
                imgBatch, segBatch = batch
                imgBatch = imgBatch.to(device)
                # segBatch = segBatch.to(device)

                with torch.no_grad():
                    prediction = model(imgBatch).to('cpu')
                diceScores_Val.append(getDiceScores(prediction, segBatch))

            diceScores_Val = np.concatenate(diceScores_Val, 0)  # all dice scores of val data (amountValData x amountClasses-1)
            diceScores_Val = diceScores_Val[:, :-1]  # ignore last column (border dice scores)

            mean_DiceScores_Val, epoch_val_mean_score = getMeanDiceScores(diceScores_Val, logger)
            end = time.time()

            logger.info('[Epoch ' + str(epoch + 1) + '] Val-Score (mean label dice scores): ' + str(np.round(mean_DiceScores_Val, 4)) + ', Mean: ' + str(round(epoch_val_mean_score, 4)) + ' [took ' + str(round(end - start, 3)) + 's]')

            tbWriter.add_scalar('Plot/val', epoch_val_mean_score, epoch)

            if epoch % 50 == 0:
                with torch.no_grad():
                    tbWriter.add_image('Val/_img', torch.round((imgBatch[0, :, :, :] + 1.) / 2. * 255.0).byte(), epoch)
                    tbWriter.add_image('Val/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                    tbWriter.add_image('Val/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)

            if epoch % 100 == 0:
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_nvidia_smi(GPUno))
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_RAM_info())

        # step the LR scheduler on the validation score when available, otherwise on the training loss
        # scheduler.step()
        if 'val' in setting:
            endLoop = scheduler.stepTrainVal(epoch_val_mean_score, logger)
        else:
            endLoop = scheduler.stepTrain(epochTrainLoss, logger)

        if epoch == (epochs - 1):  # no early stop was triggered: load the best val model into the current model for the later save
            logger.info('### No early stop performed! Best val model loaded... ####')
            if 'val' in setting:
                scheduler.loadBestValIntoModel()
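        # Note on the scheduler contract (see lrScheduler.py): stepTrainVal()/stepTrain() are assumed
        # to reduce the learning rate when the monitored metric plateaus, checkpoint the best model,
        # and return a boolean that requests an early stop once the final LR stage is exhausted.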
        # if testing is included, compute global dice scores on test data (without postprocessing) as a fast, coarse performance check
        if 'test' in setting:
            model.eval()  # equivalent to model.train(False)

            diceScores_Test = []

            start = time.time()
            for batch in dataloader_Test:
                imgBatch, segBatch = batch
                imgBatch = imgBatch.to(device)
                # segBatch = segBatch.to(device)

                with torch.no_grad():
                    prediction = model(imgBatch).to('cpu')
                diceScores_Test.append(getDiceScores(prediction, segBatch))

            diceScores_Test = np.concatenate(diceScores_Test, 0)  # all dice scores of test data (amountTestData x amountClasses-1)
            diceScores_Test = diceScores_Test[:, :-1]  # ignore last column (border dice scores)

            mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
            end = time.time()

            logger.info('[Epoch ' + str(epoch + 1) + '] Test-Score (mean label dice scores): ' + str(np.round(mean_DiceScores_Test, 4)) +
                        ', Mean: ' + str(round(test_mean_score, 4)) + ' [took ' + str(round(end - start, 3)) + 's]')

            tbWriter.add_scalar('Plot/test', test_mean_score, epoch)

            if epoch % 50 == 0:
                with torch.no_grad():
                    tbWriter.add_image('Test/_img', torch.round((imgBatch[0, :, :, :] + 1.) / 2. * 255.0).byte(), epoch)
                    tbWriter.add_image('Test/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                    tbWriter.add_image('Test/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)

            if epoch % 100 == 0:
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_nvidia_smi(GPUno))
                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_RAM_info())

            with torch.no_grad():
                # when training is over, compute the final performance using the instance-level dice score and average precision
                if endLoop or (epoch == epochs - 1):
                    diceScores_Test = []
                    diceScores_Test_TTA = []

                    # iterate through all test images one by one
                    test_idx = np.arange(len(dataset_Test))
                    for sampleNo in test_idx:
                        imgBatch, segBatch = dataset_Test[sampleNo]
                        imgBatch = imgBatch.unsqueeze(0).to(device)
                        segBatch = segBatch.unsqueeze(0)

                        # get the prediction and postprocess it
                        prediction = model(imgBatch)
                        postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
                        classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=True)

                        # evaluate instance-level performance (TP/FP/FN counting and instance dice score computation)
                        for i in range(6):  # number of classes to evaluate = 6
                            allClassEvaluators[0][i].add_example(classInstancePredictionList[i], classInstanceGTList[i])

                        # also compute regular (pixel-level) dice similarity scores
                        diceScores_Test.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True))  # dilates 'postprocessedPrediction' permanently

                        if saveFinalTestResults:
                            figFolder = resultsPath
                            if not os.path.exists(figFolder):
                                os.makedirs(figFolder)
                            imgBatchCPU = torch.round((imgBatch[0, :, :, :].to("cpu") + 1.) / 2. * 255.0).byte().numpy().transpose(1, 2, 0)
                            # figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result.png'
                            # saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
                        if applyTestTimeAugmentation:
                            # perform test-time augmentation: average softmax predictions over the
                            # four flip orientations, flipping each prediction back before accumulation
                            prediction = torch.softmax(prediction, 1)

                            imgBatch = imgBatch.flip(2)  # flip along height
                            prediction += torch.softmax(model(imgBatch), 1).flip(2)

                            imgBatch = imgBatch.flip(3)  # now flipped along height and width
                            prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)

                            imgBatch = imgBatch.flip(2)  # flip along width only
                            prediction += torch.softmax(model(imgBatch), 1).flip(3)

                            prediction /= 4.
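                            # At this point 'prediction' holds the mean of the four softmax outputs
                            # (identity, height-flip, height+width-flip, width-flip), all mapped back to
                            # the original orientation, so the postprocessing below operates on the
                            # ensembled class probabilities.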
                            postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
                            classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=False)

                            for i in range(6):
                                allClassEvaluators[1][i].add_example(classInstancePredictionList[i], classInstanceGTList[i])

                            diceScores_Test_TTA.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True))  # dilates 'postprocessedPrediction' permanently

                            if saveFinalTestResults:
                                figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result_TTA.png'
                                saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)

                    logger.info('############################### RESULTS ###############################')

                    # print regular dice similarity coefficients as a coarse performance check
                    diceScores_Test = np.concatenate(diceScores_Test, 0)  # all dice scores of test data (amountTestData x amountClasses-1)
                    diceScores_Test = diceScores_Test[:, :-1]  # ignore last column (border dice scores)
                    mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
                    logger.info('MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test, 4)) + ', Overall mean: ' + str(round(test_mean_score, 4)))
                    np.savetxt(resultsPath + '/allTestDiceScores.csv', diceScores_Test, delimiter=',')  # one row per test image, one column per evaluated structure

                    # the same again for the test-time-augmented predictions
                    if applyTestTimeAugmentation:
                        diceScores_Test_TTA = np.concatenate(diceScores_Test_TTA, 0)  # all dice scores of test data (amountTestData x amountClasses-1)
                        diceScores_Test_TTA = diceScores_Test_TTA[:, :-1]  # ignore last column (border dice scores)
                        mean_DiceScores_Test_TTA, test_mean_score_TTA = getMeanDiceScores(diceScores_Test_TTA, logger)
                        logger.info('TTA - MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test_TTA, 4)) + ', Overall mean: ' + str(round(test_mean_score_TTA, 4)))
                        np.savetxt(resultsPath + '/allTestDiceScores_TTA.csv', diceScores_Test_TTA, delimiter=',')

                    # print the instance-level results gathered by the class evaluators
                    printResults(allClassEvaluators=allClassEvaluators, applyTestTimeAugmentation=applyTestTimeAugmentation, printOnlyTTAresults=True, logger=logger, saveNumpyResults=False, resultsPath=resultsPath)

        if endLoop:
            logger.info('### Early network training stop at epoch ' + str(epoch + 1) + '! ###')
            break

    logger.info('[Epoch ' + str(epoch + 1) + '] ### Training done! ###')

    return model
def set_up_training(modelString, setting, epochs, batchSize, lrate, weightDecay, logger, resultsPath):
    logger.info('### SETTING -> {} <- ###'.format(setting.upper()))

    # class evaluation modules for each of the 6 structures, once without and once with test-time augmentation
    classEvaluators = [ClassEvaluator() for _ in range(6)]
    classEvaluatorsTTA = [ClassEvaluator() for _ in range(6)]
    allClassEvaluators = [classEvaluators, classEvaluatorsTTA]

    resultsModelPath = resultsPath + '/Model'
    if not os.path.exists(resultsModelPath):
        os.makedirs(resultsModelPath)

    # set up tensorboard visualization
    tensorboardPath = resultsPath + '/TB'
    shutil.rmtree(tensorboardPath, ignore_errors=True)  # remove existing TB events
    tbWriter = SummaryWriter(log_dir=tensorboardPath)

    if modelString == 'custom':
        model = Custom(input_ch=3, output_ch=8, modelDim=2)
    elif modelString == 'nnunet':
        model = Generic_UNet(input_channels=3, num_classes=8, base_num_features=30, num_pool=7, final_nonlin=None, deep_supervision=False, dropout_op_kwargs={'p': 0.0, 'inplace': True})
    else:
        raise ValueError('Given model >' + modelString + '< is invalid!')

    logger.info(model)
    logger.info('Model capacity: {} parameters.'.format(countParam(model)))

    # set up optimizer
    optimizer = RAdam(model.parameters(), lr=lrate, weight_decay=weightDecay)

    # set up the learning-rate scheduler
    # scheduler = MultiStepLR(optimizer, milestones=[5, 15, 20, 25], gamma=0.3)
    # scheduler = ExponentialLR(optimizer, gamma=0.99)
    scheduler = MyLRScheduler(optimizer, model, resultsModelPath, setting, initLR=lrate, divideLRfactor=3.0)

    trained_model = train(model, setting, optimizer, scheduler, epochs, batchSize, logger, resultsPath, tbWriter, allClassEvaluators)

    # save the final model (when validation is included, this is the model with the lowest validation error)
    torch.save(trained_model.state_dict(), resultsModelPath + '/finalModel.pt')
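    # Caveat (general PyTorch behavior, not project-specific): if the model was wrapped in
    # nn.DataParallel inside train(), the saved state_dict keys carry a 'module.' prefix, so a
    # plain single-GPU model needs the prefix stripped (or the wrapper reapplied) when loading
    # finalModel.pt.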
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='python training.py -m <model-type> -s <setting> -e <epochs> ' +
                                                 '-b <batch-size> -r <learning-rate> -w <weight-decay>')
    parser.add_argument('-m', '--model', default='custom')
    parser.add_argument('-s', '--setting', default='train_val_test')
    parser.add_argument('-e', '--epochs', default=500, type=int)
    parser.add_argument('-b', '--batchSize', default=6, type=int)
    parser.add_argument('-r', '--lrate', default=0.001, type=float)
    parser.add_argument('-w', '--weightDecay', default=0.00001, type=float)
    options = parser.parse_args()

    # only 'custom' and 'nnunet' are actually constructed in set_up_training
    assert options.model in ['custom', 'nnunet']
    assert options.setting in ['train_val_test', 'train_test', 'train_val', 'train']
    assert options.epochs > 0
    assert options.batchSize > 0
    assert options.lrate > 0
    assert options.weightDecay > 0

    # results path (placeholder: point this at the desired output directory)
    resultsPath = 'SPECIFY RESULTS PATH'
    if not os.path.exists(resultsPath):
        os.makedirs(resultsPath)

    # set up a logger that writes both to a log file and to stdout
    log.basicConfig(
        level=log.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y/%m/%d %I:%M:%S %p',
        handlers=[
            log.FileHandler(resultsPath + '/LOGS.log', 'w'),
            log.StreamHandler(sys.stdout)
        ])
    logger = log.getLogger()

    logger.info('###### STARTED PROGRAM WITH OPTIONS: {} ######'.format(str(options)))

    torch.backends.cudnn.benchmark = True
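    # cudnn.benchmark lets cuDNN autotune the fastest convolution algorithms for the observed
    # input sizes; this speeds up training when input shapes are fixed, at the cost of a slower
    # first iteration and non-deterministic algorithm selection.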
    try:
        # start the whole training and evaluation procedure
        set_up_training(modelString=options.model,
                        setting=options.setting,
                        epochs=options.epochs,
                        batchSize=options.batchSize,
                        lrate=options.lrate,
                        weightDecay=options.weightDecay,
                        logger=logger,
                        resultsPath=resultsPath)
    except:
        logger.exception('! Exception !')
        raise

    logger.info('%%%% Ended regularly ! %%%%')