diff --git a/training.py b/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb6d82b42d0307c5bf8fdf2537ac775e33f3e5e
--- /dev/null
+++ b/training.py
@@ -0,0 +1,426 @@
+# This file performs model training
+
+import os
+import shutil
+import numpy as np
+import logging as log
+import time
+import sys
+
+# from sklearn.metrics import classification_report, confusion_matrix
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR
+from tensorboardX import SummaryWriter
+
+from RAdam import RAdam
+from dataset import CustomDataSetRAM
+from model import Custom
+from utils import getCrossValSplits, parse_nvidia_smi, parse_RAM_info, countParam, getDiceScores, getDiceScoresSinglePair, getMeanDiceScores, convert_labelmap_to_rgb, saveFigureResults, printResults
+from loss import DiceLoss
+from lrScheduler import MyLRScheduler
+from postprocessing import postprocessPredictionAndGT, extractInstanceChannels
+from evaluation import ClassEvaluator
+
+from nnUnet.generic_UNet import Generic_UNet
+
+import warnings
+warnings.filterwarnings("ignore")
+
+#################################### General GPU settings ####################################
+GPUno = 0                   # index of the primary CUDA device
+useAllAvailableGPU = True   # if True and >1 GPU present, wrap the model in nn.DataParallel
+device = torch.device("cuda:" + str(GPUno) if torch.cuda.is_available() else "cpu")
+##################################### Save test results ######################################
+saveFinalTestResults = True           # write result figures / CSVs during the final test pass
+############################### Apply Test Time Augmentation #################################
+applyTestTimeAugmentation = True      # average predictions over 4 flip combinations at test time
+##############################################################################################
+
+# this method trains a network with the given specification
+def train(model, setting, optimizer, scheduler, epochs, batchSize, logger, resultsPath, tbWriter, allClassEvaluators):
+    """Run the full training loop for `model` and return the trained model.
+
+    Per epoch: train on the 'train' split with combined Dice + weighted
+    cross-entropy loss; if `setting` contains 'val'/'test', additionally
+    compute dice scores on those splits. Progress is logged via `logger`
+    and visualized via `tbWriter`. The custom scheduler is stepped once
+    per epoch and may request an early stop. When training ends (early
+    stop or last epoch) and 'test' is in `setting`, a final instance-level
+    evaluation (optionally with test-time augmentation) is performed and
+    results are written to `resultsPath`.
+
+    Parameters:
+        model              - network to train (moved onto global `device`)
+        setting            - split spec string; substrings 'val'/'test' enable those phases
+        optimizer          - optimizer stepped once per training batch
+        scheduler          - MyLRScheduler; stepTrain/stepTrainVal return the early-stop flag
+        epochs             - maximum number of epochs
+        batchSize          - batch size of the training dataloader
+        logger             - logger instance for progress messages
+        resultsPath        - output folder for figures and CSV score files
+        tbWriter           - tensorboardX SummaryWriter
+        allClassEvaluators - [plain, TTA] lists of 6 ClassEvaluator objects that
+                             accumulate instance-level statistics per structure class
+
+    Returns:
+        the trained model (with 'val' in `setting`, the scheduler restores the
+        best-validation weights before the final save).
+    """
+
+    model.to(device)
+    # wrap in DataParallel when several GPUs are available and enabled
+    if torch.cuda.device_count() > 1 and useAllAvailableGPU:
+        logger.info('# {} GPUs utilized! #'.format(torch.cuda.device_count()))
+        model = nn.DataParallel(model)
+
+    # mandatory to produce random numpy numbers during training, otherwise batches will contain equal random numbers (originally: numpy issue)
+    def worker_init_fn(worker_id):
+        np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+    # allocate and separately load train / val / test data sets
+    dataset_Train = CustomDataSetRAM('train', logger)
+    dataloader_Train = DataLoader(dataset=dataset_Train, batch_size = batchSize, shuffle = True, num_workers = 4, worker_init_fn=worker_init_fn)
+
+    if 'val' in setting:
+        dataset_Val = CustomDataSetRAM('val', logger)
+        dataloader_Val = DataLoader(dataset=dataset_Val, batch_size = batchSize, shuffle = False, num_workers = 1, worker_init_fn=worker_init_fn)
+
+    if 'test' in setting:
+        dataset_Test = CustomDataSetRAM('test', logger)
+        dataloader_Test = DataLoader(dataset=dataset_Test, batch_size = batchSize, shuffle = False, num_workers = 1, worker_init_fn=worker_init_fn)
+
+    logger.info('####### DATA LOADED - TRAINING STARTS... #######')
+
+    # Utilize dice loss and weighted cross entropy loss, ignore index 8 as this is area outside the image included by augmentation, e.g. due to image rotation
+    # (last class, index 7, is weighted 10x in the CE loss)
+    Dice_Loss = DiceLoss(ignore_index=8).to(device)
+    CE_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 1., 1., 1., 1., 1., 1., 10.]), ignore_index=8).to(device)
+
+    for epoch in range(epochs):
+        model.train(True)
+
+        # running loss sums accumulated over the epoch
+        epochCELoss = 0
+        epochDiceLoss = 0
+        epochLoss = 0
+
+        np.random.seed()  # reseed so augmentation randomness differs between epochs
+        start = time.time()
+        for batch in dataloader_Train:
+            # get data and put onto device
+            imgBatch, segBatch = batch
+
+            imgBatch = imgBatch.to(device)
+            segBatch = segBatch.to(device)
+
+            optimizer.zero_grad()
+
+            # forward image batch, compute loss and backprop
+            prediction = model(imgBatch)
+
+            CEloss = CE_Loss(prediction, segBatch)
+            diceLoss = Dice_Loss(prediction, segBatch)
+
+            loss = CEloss + diceLoss
+
+            epochCELoss += CEloss.item()
+            epochDiceLoss += diceLoss.item()
+            epochLoss += loss.item()
+
+            loss.backward()
+            optimizer.step()
+
+        epochTrainLoss = epochLoss / dataloader_Train.__len__()
+
+        end = time.time()
+        # print current training loss
+        logger.info('[Epoch '+str(epoch+1)+'] Train-Loss: '+str(round(epochTrainLoss,5))+', DiceLoss: '+
+                    str(round(epochDiceLoss/dataloader_Train.__len__(),5))+', CELoss: '+str(round(epochCELoss/dataloader_Train.__len__(),5))+'  [took '+str(round(end-start,3))+'s]')
+
+        # use tensorboard for visualization of training progress
+        tbWriter.add_scalars('Plot/train', {'loss' : epochTrainLoss,
+                                           'CEloss' : epochCELoss/dataloader_Train.__len__(),
+                                           'DiceLoss' : epochDiceLoss/dataloader_Train.__len__()}, epoch)
+        
+        # each 50th epoch add prediction image to tensorboard
+        # (inputs are rescaled from [-1, 1] to [0, 255] bytes for display)
+        if epoch % 50 == 0:
+            with torch.no_grad():
+                tbWriter.add_image('Train/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte() , epoch)
+                tbWriter.add_image('Train/GT', convert_labelmap_to_rgb(segBatch[0,:,:].cpu()), epoch)
+                tbWriter.add_image('Train/pred', convert_labelmap_to_rgb(prediction[0,:,:,:].argmax(0).cpu()), epoch)
+
+        # each 100th epoch log GPU and RAM utilization
+        if epoch % 100 == 0:
+            logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_nvidia_smi(GPUno))
+            logger.info('[Epoch ' + str(epoch + 1) + '] ' + parse_RAM_info())
+
+
+        # if validation was included, compute dice scores on validation data
+        if 'val' in setting:
+            model.train(False)
+
+            diceScores_Val = []
+
+            start = time.time()
+            for batch in dataloader_Val:
+                imgBatch, segBatch = batch
+                imgBatch = imgBatch.to(device)
+                # segBatch = segBatch.to(device)
+
+                with torch.no_grad():
+                    prediction = model(imgBatch).to('cpu')
+
+                    diceScores_Val.append(getDiceScores(prediction, segBatch))
+
+            diceScores_Val = np.concatenate(diceScores_Val, 0) # <- all dice scores of val data (batchSize x amountClasses-1)
+            diceScores_Val = diceScores_Val[:, :-1]  # ignore last column=border dice scores
+
+            mean_DiceScores_Val, epoch_val_mean_score = getMeanDiceScores(diceScores_Val, logger)
+
+            end = time.time()
+            logger.info('[Epoch '+str(epoch+1)+'] Val-Score (mean label dice scores): '+str(np.round(mean_DiceScores_Val,4))+', Mean: '+str(round(epoch_val_mean_score,4))+'  [took '+str(round(end-start,3))+'s]')
+
+            tbWriter.add_scalar('Plot/val', epoch_val_mean_score, epoch)
+
+            if epoch % 50 == 0:
+                with torch.no_grad():
+                    tbWriter.add_image('Val/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte(), epoch)
+                    tbWriter.add_image('Val/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
+                    tbWriter.add_image('Val/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)
+
+            if epoch % 100 == 0:
+                logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_nvidia_smi(GPUno))
+                logger.info('[Epoch ' + str(epoch + 1) + ' - After Validation] ' + parse_RAM_info())
+
+
+        # scheduler.step()
+        # step the custom LR scheduler; its return value signals early stopping
+        if 'val' in setting:
+            endLoop = scheduler.stepTrainVal(epoch_val_mean_score, logger)
+        else:
+            endLoop = scheduler.stepTrain(epochTrainLoss, logger)
+
+        if epoch == (epochs - 1): #when no early stop is performed, load bestValModel into current model for later save
+            # NOTE(review): this message is logged even without a validation split,
+            # although the best-val model is only restored when 'val' is in `setting`
+            logger.info('### No early stop performed! Best val model loaded... ####')
+            if 'val' in setting:
+                scheduler.loadBestValIntoModel()
+
+        # if test was included, compute global dice scores on test data (without postprocessing) for fast and coarse performance check
+        if 'test' in setting:
+            model.train(False); model.eval()
+
+            diceScores_Test = []
+
+            start = time.time()
+            for batch in dataloader_Test:
+                imgBatch, segBatch = batch
+                imgBatch = imgBatch.to(device)
+                # segBatch = segBatch.to(device)
+
+                with torch.no_grad():
+                    prediction = model(imgBatch).to('cpu')
+
+                    diceScores_Test.append(getDiceScores(prediction, segBatch))
+
+
+            diceScores_Test = np.concatenate(diceScores_Test, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
+            diceScores_Test = diceScores_Test[:,:-1] #ignore last column=border dice scores
+
+            mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
+
+            end = time.time()
+            logger.info('[Epoch ' + str(epoch + 1) + '] Test-Score (mean label dice scores): ' + str(np.round(mean_DiceScores_Test, 4))+
+                        ', Mean: ' + str(round(test_mean_score, 4)) + '  [took ' + str(round(end - start, 3)) + 's]')
+
+            tbWriter.add_scalar('Plot/test', test_mean_score, epoch)
+
+            if epoch % 50 == 0:
+                with torch.no_grad():
+                    tbWriter.add_image('Test/_img', torch.round((imgBatch[0,:,:,:] + 1.) / 2. * 255.0).byte(), epoch)
+                    tbWriter.add_image('Test/GT', convert_labelmap_to_rgb(segBatch[0, :, :].cpu()), epoch)
+                    tbWriter.add_image('Test/pred', convert_labelmap_to_rgb(prediction[0, :, :, :].argmax(0).cpu()), epoch)
+
+            if epoch % 100 == 0:
+                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_nvidia_smi(GPUno))
+                logger.info('[Epoch ' + str(epoch + 1) + ' - After Testing] ' + parse_RAM_info())
+
+            with torch.no_grad():
+                ### if training is over, compute final performances using the instance-level dice score and average precision ###
+                if endLoop or (epoch == epochs - 1):
+
+                    diceScores_Test = []
+                    diceScores_Test_TTA = []
+
+                    # iterate through all test images (one sample at a time, not batched)
+                    test_idx = np.arange(dataset_Test.__len__())
+                    for sampleNo in test_idx:
+                        imgBatch, segBatch = dataset_Test.__getitem__(sampleNo)
+
+                        imgBatch = imgBatch.unsqueeze(0).to(device)
+                        segBatch = segBatch.unsqueeze(0)
+
+                        # get prediction and postprocess it
+                        prediction = model(imgBatch)
+
+                        postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
+
+                        classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=True)
+
+                        # here the evaluation is performed
+                        # evaluate performance (TP, NP, FP counting and instance dice score computation)
+                        for i in range(6): #number classes to evaluate = 6
+                            allClassEvaluators[0][i].add_example(classInstancePredictionList[i],classInstanceGTList[i])
+
+                        # there are regular dice similarity scores
+                        diceScores_Test.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True)) #dilates 'postprocessedPrediction' permanently
+
+                        if saveFinalTestResults:
+                            figFolder = resultsPath
+                            if not os.path.exists(figFolder):
+                                os.makedirs(figFolder)
+
+                            imgBatchCPU = torch.round((imgBatch[0, :, :, :].to("cpu") + 1.) / 2. * 255.0).byte().numpy().transpose(1, 2, 0)
+                            # NOTE(review): saving of the non-TTA result figure is disabled;
+                            # only the TTA figure below is written when enabled
+                            # figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result.png'
+                            # saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
+
+                        if applyTestTimeAugmentation: #perform test-time augmentation
+                            # average softmax outputs over four flip combinations:
+                            # identity, flip(dim 2), flip(dims 2+3), flip(dim 3);
+                            # each flipped prediction is flipped back before accumulation
+                            prediction = torch.softmax(prediction, 1)
+
+                            imgBatch = imgBatch.flip(2)
+                            prediction += torch.softmax(model(imgBatch), 1).flip(2)
+
+                            imgBatch = imgBatch.flip(3)
+                            prediction += torch.softmax(model(imgBatch), 1).flip(3).flip(2)
+
+                            imgBatch = imgBatch.flip(2)
+                            prediction += torch.softmax(model(imgBatch), 1).flip(3)
+
+                            prediction /= 4.
+
+                            postprocessedPrediction, outputPrediction, preprocessedGT = postprocessPredictionAndGT(prediction, segBatch.squeeze(0).numpy(), device=device, predictionsmoothing=False, holefilling=True)
+
+                            classInstancePredictionList, classInstanceGTList, finalPredictionRGB, preprocessedGTrgb = extractInstanceChannels(postprocessedPrediction, preprocessedGT, tubuliDilation=False)
+
+                            for i in range(6):
+                                allClassEvaluators[1][i].add_example(classInstancePredictionList[i], classInstanceGTList[i])
+
+                            diceScores_Test_TTA.append(getDiceScoresSinglePair(postprocessedPrediction, preprocessedGT, tubuliDilation=True)) #dilates 'postprocessedPrediction' permanently
+
+                            # figFolder/imgBatchCPU exist here because they are created above
+                            # under the same saveFinalTestResults guard
+                            if saveFinalTestResults:
+                                figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result_TTA.png'
+                                saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
+
+
+                    logger.info('############################### RESULTS ###############################')
+
+                    # print regular dice similarity coefficients as coarse performance check
+                    diceScores_Test = np.concatenate(diceScores_Test, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
+                    diceScores_Test = diceScores_Test[:, :-1]  # ignore last column=border dice scores
+                    mean_DiceScores_Test, test_mean_score = getMeanDiceScores(diceScores_Test, logger)
+                    logger.info('MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test, 4)) + ', Overall mean: ' + str(round(test_mean_score, 4)))
+                    np.savetxt(resultsPath + '/allTestDiceScores.csv', diceScores_Test, delimiter=',')
+
+                    # print regular dice similarity coefficients as coarse performance check
+                    if applyTestTimeAugmentation:
+                        diceScores_Test_TTA = np.concatenate(diceScores_Test_TTA, 0)  # <- all dice scores of test data (amountTestData x amountClasses-1)
+                        diceScores_Test_TTA = diceScores_Test_TTA[:, :-1]  # ignore last column=border dice scores
+                        mean_DiceScores_Test_TTA, test_mean_score_TTA = getMeanDiceScores(diceScores_Test_TTA, logger)
+                        logger.info('TTA - MEAN DICE SCORES: ' + str(np.round(mean_DiceScores_Test_TTA, 4)) + ', Overall mean: ' + str(round(test_mean_score_TTA, 4)))
+                        np.savetxt(resultsPath + '/allTestDiceScores_TTA.csv', diceScores_Test_TTA, delimiter=',')
+
+                        # NOTE(review): instance-level results are only printed inside this
+                        # TTA branch — with applyTestTimeAugmentation=False they are skipped
+                        printResults(allClassEvaluators=allClassEvaluators, applyTestTimeAugmentation=applyTestTimeAugmentation, printOnlyTTAresults=True, logger=logger, saveNumpyResults=False, resultsPath=resultsPath)
+
+        if endLoop:
+            logger.info('### Early network training stop at epoch '+str(epoch+1)+'! ###')
+            break
+
+
+    logger.info('[Epoch '+str(epoch+1)+'] ### Training done! ###')
+
+    return model
+
+
+
+def set_up_training(modelString, setting, epochs, batchSize, lrate, weightDecay, logger, resultsPath):
+
+    logger.info('### SETTING -> {} <- ###'.format(setting.upper()))
+
+    # class evaluation modules for each structure and with or w/o test-time augmentation
+    classEvaluators = [ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator()]
+    classEvaluatorsTTA = [ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator(), ClassEvaluator()]
+
+    allClassEvaluators = [classEvaluators, classEvaluatorsTTA]
+
+    resultsModelPath = resultsPath +'/Model'
+    if not os.path.exists(resultsModelPath):
+        os.makedirs(resultsModelPath)
+
+    # setting up tensorboard visualization
+    tensorboardPath = resultsPath + '/TB'
+    shutil.rmtree(tensorboardPath, ignore_errors=True) #<- remove existing TB events
+    tbWriter = SummaryWriter(log_dir=tensorboardPath)
+
+    if modelString == 'custom':
+        model = Custom(input_ch=3, output_ch=8, modelDim=2)
+    elif modelString == 'nnunet':
+        model = Generic_UNet(input_channels=3, num_classes=8, base_num_features=30, num_pool=7, final_nonlin = None, deep_supervision=False, dropout_op_kwargs = {'p': 0.0, 'inplace': True})
+    else:
+        raise ValueError('Given model >' + modelString + '< is invalid!')
+
+    logger.info(model)
+    logger.info('Model capacity: {} parameters.'.format(countParam(model)))
+
+    # set up optimizer
+    optimizer = RAdam(model.parameters(), lr=lrate, weight_decay=weightDecay)
+
+    # set up scheduler
+    # scheduler = MultiStepLR(optimizer, milestones=[5, 15, 20, 25], gamma=0.3)
+    # scheduler = ExponentialLR(optimizer, gamma=0.99)
+    scheduler = MyLRScheduler(optimizer, model, resultsModelPath, setting, initLR=lrate, divideLRfactor=3.0)
+
+    trained_model = train(
+        model,
+        setting,
+        optimizer,
+        scheduler,
+        epochs,
+        batchSize,
+        logger,
+        resultsPath,
+        tbWriter,
+        allClassEvaluators
+    )
+
+    # save final model (when validation is included, the model with lowest validation error is saved)
+    torch.save(trained_model.state_dict(), resultsModelPath + '/finalModel.pt')
+
+
+
+if '__main__' == __name__:
+    import argparse
+    parser = argparse.ArgumentParser(description='python training.py -m <model-type> -d <dataset> -s <train_valid_test> -e <epochs> '+
+                                                 '-b <batch-size> -r <learning-rate> -w <weight-decay>')
+    parser.add_argument('-m', '--model', default='custom')
+    parser.add_argument('-s', '--setting', default='train_val_test')
+    parser.add_argument('-e', '--epochs', default=500, type=int)
+    parser.add_argument('-b', '--batchSize', default=6, type=int)
+    parser.add_argument('-r', '--lrate', default=0.001, type=float)
+    parser.add_argument('-w', '--weightDecay', default=0.00001, type=float)
+
+    options = parser.parse_args()
+    assert(options.model in ['custom', 'unet', 'CEnet2D', 'CE_Net_Inception_Variants_2D', 'nnunet'])
+    assert(options.setting in ['train_val_test', 'train_test', 'train_val', 'train'])
+    assert(options.epochs > 0)
+    assert(options.batchSize > 0)
+    assert(options.lrate > 0)
+    assert(options.weightDecay > 0)
+
+    # Results path
+    resultsPath = 'SPECIFY RESULTS PATH'
+
+    if not os.path.exists(resultsPath):
+        os.makedirs(resultsPath)
+
+    # Set up logger
+    log.basicConfig(
+        level=log.INFO,
+        format='%(asctime)s %(message)s',
+        datefmt='%Y/%m/%d %I:%M:%S %p',
+        handlers=[
+            log.FileHandler(resultsPath + '/LOGS.log','w'),
+            log.StreamHandler(sys.stdout)
+        ])
+    logger = log.getLogger()
+
+    logger.info('###### STARTED PROGRAM WITH OPTIONS: {} ######'.format(str(options)))
+
+    torch.backends.cudnn.benchmark = True
+
+    try:
+        # start whole training and evaluation procedure
+        set_up_training(modelString=options.model,
+                                 setting=options.setting,
+                                 epochs=options.epochs,
+                                 batchSize=options.batchSize,
+                                 lrate=options.lrate,
+                                 weightDecay=options.weightDecay,
+                                 logger=logger,
+                                 resultsPath=resultsPath)
+    except:
+        logger.exception('! Exception !')
+        raise
+
+    log.info('%%%% Ended regularly ! %%%%')
+
+