Skip to content
Snippets Groups Projects
Select Git revision
  • 56cb3d69ef128fb4cca83817fcd52e538316470b
  • 4.26 default protected
  • 5.3
  • dev/5.5
  • dev/5.3
  • dev/5.1
  • deprecated/4.22
7 results

VAAbstractSourceComponent.cpp

Blame
  • training.py 20.26 KiB
    # This file performs model training
    
    import os
    import shutil
    import numpy as np
    import logging as log
    import time
    import sys
    
    # from sklearn.metrics import classification_report, confusion_matrix
    
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SubsetRandomSampler
    from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR
    from tensorboardX import SummaryWriter
    
    from RAdam import RAdam
    from dataset import CustomDataSetRAM
    from model import Custom
    from utils import getCrossValSplits, parse_nvidia_smi, parse_RAM_info, countParam, getDiceScores, getDiceScoresSinglePair, getMeanDiceScores, convert_labelmap_to_rgb, saveFigureResults, printResults
    from loss import DiceLoss
    from lrScheduler import MyLRScheduler
    from postprocessing import postprocessPredictionAndGT, extractInstanceChannels
    from evaluation import ClassEvaluator
    
    from nnUnet.generic_UNet import Generic_UNet
    
    import warnings
    warnings.filterwarnings("ignore")
    
    #################################### General GPU settings ####################################
    GPUno = 0
    useAllAvailableGPU = True
    device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
    ##################################### Save test results ######################################
    saveFinalTestResults = True
    ############################### Apply Test Time Augmentation #################################
    applyTestTimeAugmentation = True
    ##############################################################################################
    
    # this method trains a network with the given specification
    def train(model, setting, optimizer, scheduler, epochs, batchSize, logger, resultsPath, tbWriter, allClassEvaluators):
    
        model.to(device)
        if torch.cuda.device_count() > 1 and useAllAvailableGPU:
            logger.info('# {} GPUs utilized! #'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
    
        # mandatory to produce random numpy numbers during training, otherwise batches will contain equal random numbers (originally: numpy issue)
        def worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)
    
        # allocate and separately load train / val / test data sets
        dataset_Train = CustomDataSetRAM('train', logger)
        dataloader_Train = DataLoader(dataset=dataset_Train, batch_size = batchSize, shuffle = True, num_workers = 4, worker_init_fn=worker_init_fn)
    
        if 'val' in setting:
            dataset_Val = CustomDataSetRAM('val', logger)
            dataloader_Val = DataLoader(dataset=dataset_Val, batch_size = batchSize, shuffle = False, num_workers = 1, worker_init_fn=worker_init_fn)
    
        if 'test' in setting:
            dataset_Test = CustomDataSetRAM('test', logger)
            dataloader_Test = DataLoader(dataset=dataset_Test, batch_size = batchSize, shuffle = False, num_workers = 1, worker_init_fn=worker_init_fn)
    
        logger.info('####### DATA LOADED - TRAINING STARTS... #######')
    
        # Utilize dice loss and weighted cross entropy loss, ignore index 8 as this is area outside the image included by augmentation, e.g. due to image rotation
        Dice_Loss = DiceLoss(ignore_index=8).to(device)
        CE_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 1., 1., 1., 1., 1., 1., 10.]), ignore_index=8).to(device)
    
        for epoch in range(epochs):
            model.train(True)
    
            epochCELoss = 0
            epochDiceLoss = 0
            epochLoss = 0
    
            np.random.seed()
            start = time.time()
            for batch in dataloader_Train:
                # get data and put onto device
                imgBatch, segBatch = batch
    
                imgBatch = imgBatch.to(device)
                segBatch = segBatch.to(device)
    
                optimizer.zero_grad()
    
                # forward image batch, compute loss and backprop
                prediction = model(imgBatch)
    
                CEloss = CE_Loss(prediction, segBatch)
                diceLoss = Dice_Loss(prediction, segBatch)
    
                loss = CEloss + diceLoss
    
                epochCELoss += CEloss.item()
                epochDiceLoss += diceLoss.item()
                epochLoss += loss.item()
    
                loss.backward()
                optimizer.step()
    
            epochTrainLoss = epochLoss / dataloader_Train.__len__()
    
            end = time.time()
            # print current training loss
            logger.info('[Epoch '+str(epoch+1)+'] Train-Loss: '+str(round(epochTrainLoss,5))+', DiceLoss: '+
                        str(round(epochDiceLoss/dataloader_Train.__len__(),5))+', CELoss: '+str(round(epochCELoss/dataloader_Train.__len__(),5))+'  [took '+str(round(end-start,3))+'s]')
    
            # use tensorboard for visualization of training progress
            tbWriter.add_scalars('Plot/train', {'loss' : epochTrainLoss,
                                               'CEloss' : epochCELoss/dataloader_Train.__len__(),
                                               'DiceLoss' : epochDiceLoss/dataloader_Train.__len__()}, epoch)
            
            # each 50th epoch add prediction image to tensorboard
            if epoch % 50 == 0:
                with torch.no_grad():
                    tbWriter.add_image('Train/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte() , epoch)
                    tbWriter.add_image('Train/GT', convert_labelmap_to_rgb(segBatch[0,:,:].cpu()), epoch)
                    tbWriter.add_image('Train/pred', convert_labelmap_to_rgb(prediction[0,:,:,:].argmax(0).cpu()), epoch)
    
            if epoch % 100 == 0:
                logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_nvidia_smi(GPUno))
                logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_RAM_info())
    
    
            # if validation was included, compute dice scores on validation data
            if 'val' in setting:
                model.train(False)
    
                diceScores_Val = []
    
                start = time.time()
                for batch in dataloader_Val:
                    imgBatch, segBatch = batch
                    imgBatch = imgBatch.to(device)
                    # segBatch = segBatch.to(device)
    
                    with torch.no_grad():
                        prediction = model(imgBatch).to('cpu')
    
                        diceScores_Val.append(getDiceScores(prediction, segBatch))
    
                diceScores_Val = np.concatenate(diceScores_Val, 0) # <- all dice scores of val data (batchSize x amountClasses-1)
                diceScores_Val = diceScores_Val[:, :-1]  # ignore last coloum=border dice scores
    
                mean_DiceScores_Val, epoch_val_mean_score = getMeanDiceScores(diceScores_Val, logger)
    
                end = time.time()
                logger.info('[Epoch '+str(epoch+1)+'] Val-Score (mean label dice scores): '+str(np.round(mean_DiceScores_Val,4))+', Mean: '+str(round(epoch_val_mean_score,4))+'  [took '+str(round(end-start,3))+'s]')
    
                tbWriter.add_scalar('Plot/val', epoch_val_mean_score, epoch)
    
                if epoch % 50 == 0:
                    with torch.no_grad():
                        tbWriter.add_image('Val/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte(), epoch)
                        tbWriter.add_image('Val/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                        tbWriter.add_image('Val/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)
    
                if epoch % 100 == 0:
                    logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_nvidia_smi(GPUno))
                    logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_RAM_info())
    
    
            # scheduler.step()
            if 'val' in setting:
                endLoop = scheduler.stepTrainVal(epoch_val_mean_score, logger)
            else:
                endLoop = scheduler.stepTrain(epochTrainLoss, logger)
    
            if epoch == (epochs - 1): #when no early stop is performed, load bestValModel into current model for later save
                logger.info('### No early stop performed! Best val model loaded... ####')
                if 'val' in setting:
                    scheduler.loadBestValIntoModel()
    
            # if test was included, compute global dice scores on test data (without postprocessing) for fast and coarse performance check
            if 'test' in setting:
                model.train(False); model.eval()
    
                diceScores_Test = []
    
                start = time.time()
                for batch in dataloader_Test:
                    imgBatch, segBatch = batch
                    imgBatch = imgBatch.to(device)
                    # segBatch = segBatch.to(device)
    
                    with torch.no_grad():
                        prediction = model(imgBatch).to('cpu')
    
                        diceScores_Test.append(getDiceScores(prediction, segBatch))
    
    
                diceScores_Test = np.concatenate(diceScores_Test, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
                diceScores_Test = diceScores_Test[:,:-1] #ignore last coloum=border dice scores
    
                mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
    
                end = time.time()
                logger.info('[Epoch ' + str(epoch + 1) + '] Test-Score (mean label dice scores): ' + str(np.round(mean_DiceScores_Test, 4))+
                            ', Mean: ' + str(round(test_mean_score, 4)) + '  [took ' + str(round(end - start, 3)) + 's]')
    
                tbWriter.add_scalar('Plot/test', test_mean_score, epoch)
    
                if epoch % 50 == 0:
                    with torch.no_grad():
                        tbWriter.add_image('Test/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte(), epoch)
                        tbWriter.add_image('Test/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
                        tbWriter.add_image('Test/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)
    
                if epoch % 100 == 0:
                    logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_nvidia_smi(GPUno))
                    logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_RAM_info())
    
                with torch.no_grad():
                    ### if training is over, compute final performances using the instance-level dice score and average precision ###
                    if endLoop or (epoch == epochs - 1):
    
                        diceScores_Test = []
                        diceScores_Test_TTA = []
    
                        # iterate through all test images
                        test_idx = np.arange(dataset_Test.__len__())
                        for sampleNo in test_idx:
                            imgBatch, segBatch = dataset_Test.__getitem__(sampleNo)
    
                            imgBatch = imgBatch.unsqueeze(0).to(device)
                            segBatch = segBatch.unsqueeze(0)
    
                            # get prediction and postprocess it
                            prediction = model(imgBatch)
    
                            postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
    
                            classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=True)
    
                            # here the evaluation is performed
                            # evaluate performance (TP, NP, FP counting and instance dice score computation)
                            for i in range(6): #number classes to evaluate = 6
                                allClassEvaluators[0][i].add_example(classInstancePredictionList[i],classInstanceGTList[i])
    
                            # there are regular dice similarity scores
                            diceScores_Test.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True)) #dilates 'postprocessedPrediction' permanently
    
                            if saveFinalTestResults:
                                figFolder = resultsPath
                                if not os.path.exists(figFolder):
                                    os.makedirs(figFolder)
    
                                imgBatchCPU = torch.round((imgBatch[0, :, :, :].to("cpu") + 1.) / 2. * 255.0).byte().numpy().transpose(1, 2, 0)
                                # figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result.png'
                                # saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
    
                            if applyTestTimeAugmentation: #perform test-time augmentation
                                prediction = torch.softmax(prediction, 1)
    
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(2)
    
                                imgBatch = imgBatch.flip(3)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)
    
                                imgBatch = imgBatch.flip(2)
                                prediction += torch.softmax(model(imgBatch), 1).flip(3)
    
                                prediction /= 4.
    
                                postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
    
                                classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=False)
    
                                for i in range(6):
                                    allClassEvaluators[1][i].add_example(classInstancePredictionList[i], classInstanceGTList[i])
    
                                diceScores_Test_TTA.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True)) #dilates 'postprocessedPrediction' permanently
    
                                if saveFinalTestResults:
                                    figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result_TTA.png'
                                    saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
    
    
                        logger.info('############################### RESULTS ###############################')
    
                        # print regular dice similarity coefficients as coarse performance check
                        diceScores_Test = np.concatenate(diceScores_Test, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
                        diceScores_Test = diceScores_Test[:, :-1]  # ignore last coloum=border dice scores
                        mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
                        logger.info('MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test, 4)) + ', Overall mean: ' + str(round(test_mean_score, 4)))
                        np.savetxt(resultsPath + '/allTestDiceScores.csv', diceScores_Test, delimiter=',')
    
                        # print regular dice similarity coefficients as coarse performance check
                        if applyTestTimeAugmentation:
                            diceScores_Test_TTA = np.concatenate(diceScores_Test_TTA, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
                            diceScores_Test_TTA = diceScores_Test_TTA[:, :-1]  # ignore last coloum=border dice scores
                            mean_DiceScores_Test_TTA, test_mean_score_TTA = getMeanDiceScores(diceScores_Test_TTA, logger)
                            logger.info('TTA - MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test_TTA, 4)) + ', Overall mean: ' + str(round(test_mean_score_TTA, 4)))
                            np.savetxt(resultsPath + '/allTestDiceScores_TTA.csv', diceScores_Test_TTA, delimiter=',')
    
                            printResults(allClassEvaluators=allClassEvaluators, applyTestTimeAugmentation=applyTestTimeAugmentation, printOnlyTTAresults=True, logger=logger, saveNumpyResults=False, resultsPath=resultsPath)
    
            if endLoop:
                logger.info('### Early network training stop at epoch '+str(epoch+1)+'! ###')
                break
    
    
        logger.info('[Epoch '+str(epoch+1)+'] ### Training done! ###')
    
        return model
    
    
    
    def set_up_training(modelString, setting, epochs, batchSize, lrate, weightDecay, logger, resultsPath):
    
        logger.info('### SETTING -> {} <- ###'.format(setting.upper()))
    
        # class evaluation modules for each structure and with or w/o test-time augmentation
        classEvaluators = [ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator()]
        classEvaluatorsTTA = [ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator()]
    
        allClassEvaluators = [classEvaluators, classEvaluatorsTTA]
    
        resultsModelPath = resultsPath +'/Model'
        if not os.path.exists(resultsModelPath):
            os.makedirs(resultsModelPath)
    
        # setting up tensorboard visualization
        tensorboardPath = resultsPath + '/TB'
        shutil.rmtree(tensorboardPath, ignore_errors=True) #<- remove existing TB events
        tbWriter = SummaryWriter(log_dir=tensorboardPath)
    
        if modelString == 'custom':
            model = Custom(input_ch=3, output_ch=8, modelDim=2)
        elif modelString == 'nnunet':
            model = Generic_UNet(input_channels=3, num_classes=8, base_num_features=30, num_pool=7, final_nonlin = None, deep_supervision=False, dropout_op_kwargs = {'p': 0.0, 'inplace': True})
        else:
            raise ValueError('Given model >' + modelString + '< is invalid!')
    
        logger.info(model)
        logger.info('Model capacity: {} parameters.'.format(countParam(model)))
    
        # set up optimizer
        optimizer = RAdam(model.parameters(), lr=lrate, weight_decay=weightDecay)
    
        # set up scheduler
        # scheduler = MultiStepLR(optimizer, milestones=[5, 15, 20, 25], gamma=0.3)
        # scheduler = ExponentialLR(optimizer, gamma=0.99)
        scheduler = MyLRScheduler(optimizer, model, resultsModelPath, setting, initLR=lrate, divideLRfactor=3.0)
    
        trained_model = train(
            model,
            setting,
            optimizer,
            scheduler,
            epochs,
            batchSize,
            logger,
            resultsPath,
            tbWriter,
            allClassEvaluators
        )
    
        # save final model (when validation is included, the model with lowest validation error is saved)
        torch.save(trained_model.state_dict(), resultsModelPath + '/finalModel.pt')
    
    
    
    if '__main__' == __name__:
        import argparse
        parser = argparse.ArgumentParser(description='python training.py -m <model-type> -d <dataset> -s <train_valid_test> -e <epochs> '+
                                                     '-b <batch-size> -r <learning-rate> -w <weight-decay>')
        parser.add_argument('-m', '--model', default='custom')
        parser.add_argument('-s', '--setting', default='train_val_test')
        parser.add_argument('-e', '--epochs', default=500, type=int)
        parser.add_argument('-b', '--batchSize', default=6, type=int)
        parser.add_argument('-r', '--lrate', default=0.001, type=float)
        parser.add_argument('-w', '--weightDecay', default=0.00001, type=float)
    
        options = parser.parse_args()
        assert(options.model in ['custom', 'unet', 'CEnet2D', 'CE_Net_Inception_Variants_2D', 'nnunet'])
        assert(options.setting in ['train_val_test', 'train_test', 'train_val', 'train'])
        assert(options.epochs > 0)
        assert(options.batchSize > 0)
        assert(options.lrate > 0)
        assert(options.weightDecay > 0)
    
        # Results path
        resultsPath = 'SPECIFY RESULTS PATH'
    
        if not os.path.exists(resultsPath):
            os.makedirs(resultsPath)
    
        # Set up logger
        log.basicConfig(
            level=log.INFO,
            format='%(asctime)s %(message)s',
            datefmt='%Y/%m/%d %I:%M:%S %p',
            handlers=[
                log.FileHandler(resultsPath + '/LOGS.log','w'),
                log.StreamHandler(sys.stdout)
            ])
        logger = log.getLogger()
    
        logger.info('###### STARTED PROGRAM WITH OPTIONS: {} ######'.format(str(options)))
    
        torch.backends.cudnn.benchmark = True
    
        try:
            # start whole training and evaluation procedure
            set_up_training(modelString=options.model,
                                     setting=options.setting,
                                     epochs=options.epochs,
                                     batchSize=options.batchSize,
                                     lrate=options.lrate,
                                     weightDecay=options.weightDecay,
                                     logger=logger,
                                     resultsPath=resultsPath)
        except:
            logger.exception('! Exception !')
            raise
    
        log.info('%%%% Ended regularly ! %%%%')