lrScheduler.py
    # This class implements a learning rate scheduler that decreases the learning rate by a given factor after the monitored loss/score has not improved for a specified number of epochs.
    
    import torch
    
    LR_Reduce_No_Train_Improvement = 15
    LR_Reduce_No_Val_Improvement = 15
    EARLY_STOP_LR_TOO_LOW = 4e-6
    
    # This learning rate scheduler works as follows: after 15 epochs without improvement of the validation score, the learning rate is divided by a specified factor.
    # Training terminates once the learning rate has fallen below 4e-6; the model with the best validation score (highest validation accuracy) is then chosen as the final model.
    class MyLRScheduler():
        def __init__(self, optimizer, model, foldResultsModelPath, setting, initLR, divideLRfactor):
            self.optimizer = optimizer
            self.model = model
            self.foldResultsModelPath = foldResultsModelPath
            self.currentLR = initLR
            self.divideLRfactor = divideLRfactor
    
            self.noImprovement = 0
    
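            # with a validation set the tracked value is a score (higher is better), otherwise it is the training loss (lower is better)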
            if 'val' in setting:
                self.bestValue = -1
            else:
                self.bestValue = 1E4
    
        # used when training without a validation data set: everything is then tracked on the training loss instead
        def stepTrain(self, newTrainLoss, logger):
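            # returns True if training should terminate because the learning rate has dropped below EARLY_STOP_LR_TOO_LOW, False otherwise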
            # Update learning rate
            if newTrainLoss >= self.bestValue:
                self.noImprovement += 1
    
                if self.noImprovement >= LR_Reduce_No_Train_Improvement:
                    self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestTrainModel.pt'))
                    self.update_lr_by_divison(self.divideLRfactor)
                    logger.info('### After '+str(LR_Reduce_No_Train_Improvement)+' epochs without train loss improvement => best model loaded and LR reduced to '+str(self.currentLR)+'!')
                    if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                        return True
                    self.noImprovement = 0
            else:
                self.noImprovement = 0
                self.bestValue = newTrainLoss
                torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestTrainModel.pt')
    
            return False
    
        # used when training with a validation data set, as is commonly recommended
        def stepTrainVal(self, newValScore, logger):
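            # returns True once the learning rate has been reduced below EARLY_STOP_LR_TOO_LOW and training should stop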
            # Update learning rate
            if newValScore <= self.bestValue:
                self.noImprovement += 1
    
                if self.noImprovement >= LR_Reduce_No_Val_Improvement:
                    self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
                    self.update_lr_by_divison(self.divideLRfactor)
                    logger.info('### After ' + str(LR_Reduce_No_Val_Improvement) + ' epochs without val score improvement => best model loaded and LR reduced to ' + str(self.currentLR) + '!')
                    if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                        return True
                    self.noImprovement = 0
            else:
                self.noImprovement = 0
                self.bestValue = newValScore
                torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestValModel.pt')
    
            return False
    
        # divides the learning rate of the network by 'factor'
        def update_lr_by_divison(self, factor):
            newLR = self.currentLR / factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = newLR
            self.currentLR = newLR
    
        # loads the model with the best validation score into the current model
        def loadBestValIntoModel(self):
            self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
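
For orientation, below is a minimal usage sketch showing how this scheduler might be driven from a training loop. It is not part of the repository: the dummy model, the optimizer settings, and the run_epoch_and_validate helper are assumed placeholders; only the MyLRScheduler calls mirror the class above. Note that the setting string merely needs to contain 'val' for the validation branch to be used.

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # dummy model/optimizer purely for illustration
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # 'train_val' contains 'val', so bestValue starts at -1 and a validation score is tracked
    scheduler = MyLRScheduler(optimizer, model, foldResultsModelPath='.',
                              setting='train_val', initLR=1e-3, divideLRfactor=5.0)

    for epoch in range(200):
        # run_epoch_and_validate is a hypothetical helper: one training epoch,
        # followed by a validation pass returning a score where higher is better
        val_score = run_epoch_and_validate(model, optimizer)
        if scheduler.stepTrainVal(val_score, logger):
            break  # learning rate fell below EARLY_STOP_LR_TOO_LOW

    # restore the weights that achieved the best validation score
    scheduler.loadBestValIntoModel()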