# This class represents a learning rate scheduler that divides the learning rate
# by a given factor after the validation error has not improved for a specified
# number of epochs.
import torch

LR_Reduce_No_Train_Improvement = 15
LR_Reduce_No_Val_Improvement = 15
EARLY_STOP_LR_TOO_LOW = 4e-6


# This learning rate scheduler works as follows: after 15 epochs without
# improvement of the monitored metric (validation score, or training loss if no
# validation set is used), the best checkpoint so far is restored and the
# learning rate is divided by a specified factor. Training terminates once the
# learning rate has fallen below 4e-6; the best model according to the
# validation score is then chosen as the final model.
class MyLRScheduler:
    def __init__(self, optimizer, model, foldResultsModelPath, setting, initLR, divideLRfactor):
        self.optimizer = optimizer
        self.model = model
        self.foldResultsModelPath = foldResultsModelPath
        self.currentLR = initLR
        self.divideLRfactor = divideLRfactor
        self.noImprovement = 0
        if 'val' in setting:
            # Validation scores are maximized, so start below any real score.
            self.bestValue = -1
        else:
            # Training losses are minimized, so start above any real loss.
            self.bestValue = 1e4

    # Used when training without a validation data set; instead of the
    # validation score, everything is then based on the training loss.
    def stepTrain(self, newTrainLoss, logger):
        # Update learning rate
        if newTrainLoss >= self.bestValue:
            self.noImprovement += 1
            if self.noImprovement >= LR_Reduce_No_Train_Improvement:
                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestTrainModel.pt'))
                self.update_lr_by_division(self.divideLRfactor)
                logger.info('### After ' + str(LR_Reduce_No_Train_Improvement) + ' epochs without train loss reduction => Best model loaded and LR reduced to ' + str(self.currentLR) + ' !')
                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                    return True
                self.noImprovement = 0
        else:
            self.noImprovement = 0
            self.bestValue = newTrainLoss
            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestTrainModel.pt')
        return False

    # Used when training with a validation data set, as commonly recommended.
    def stepTrainVal(self, newValScore, logger):
        # Update learning rate
        if newValScore <= self.bestValue:
            self.noImprovement += 1
            if self.noImprovement >= LR_Reduce_No_Val_Improvement:
                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
                self.update_lr_by_division(self.divideLRfactor)
                logger.info('### After ' + str(LR_Reduce_No_Val_Improvement) + ' epochs without val score improvement => Best model loaded and LR reduced to ' + str(self.currentLR) + ' !')
                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                    return True
                self.noImprovement = 0
        else:
            self.noImprovement = 0
            self.bestValue = newValScore
            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestValModel.pt')
        return False

    # Divides the learning rate of all optimizer parameter groups by 'factor'.
    def update_lr_by_division(self, factor):
        newLR = self.currentLR / factor
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = newLR
        self.currentLR = newLR

    # Loads the checkpoint with the best validation score into the model.
    def loadBestValIntoModel(self):
        self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
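

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module):
# wires the scheduler into a toy training loop on random data. The model,
# data, and logging setup here are assumptions made purely for demonstration.
if __name__ == '__main__':
    import logging
    import tempfile

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Toy regression setup on random data (placeholder for a real task).
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = torch.randn(64, 10), torch.randn(64, 1)
    lossFn = torch.nn.MSELoss()

    with tempfile.TemporaryDirectory() as resultsPath:
        # setting='train': no validation set, so the training loss is monitored.
        scheduler = MyLRScheduler(optimizer, model, resultsPath,
                                  setting='train', initLR=1e-3,
                                  divideLRfactor=10.0)
        for epoch in range(100):
            optimizer.zero_grad()
            loss = lossFn(model(x), y)
            loss.backward()
            optimizer.step()
            # stepTrain returns True once the LR has dropped below 4e-6.
            if scheduler.stepTrain(loss.item(), logger):
                logger.info('Early stopping: LR fell below threshold.')
                break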