diff --git a/lrScheduler.py b/lrScheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..003fa972b28d401fa91c24467294b5acad8ded4a
--- /dev/null
+++ b/lrScheduler.py
@@ -0,0 +1,103 @@
+# This class implements a learning rate scheduler that divides the learning rate by a given
+# factor after the validation score has not improved for a specified number of epochs.
+
+import torch
+
+LR_REDUCE_NO_TRAIN_IMPROVEMENT = 15
+LR_REDUCE_NO_VAL_IMPROVEMENT = 15
+EARLY_STOP_LR_TOO_LOW = 4e-6
+
+# The scheduler works as follows: after 15 epochs without improvement of the validation score,
+# the best model so far is restored and the learning rate is divided by a specified factor.
+# Training terminates once the learning rate has fallen below 4e-6; the model with the best
+# validation score is then used as the final model.
+class MyLRScheduler:
+    def __init__(self, optimizer, model, foldResultsModelPath, setting, initLR, divideLRfactor):
+        self.optimizer = optimizer
+        self.model = model
+        self.foldResultsModelPath = foldResultsModelPath
+        self.currentLR = initLR
+        self.divideLRfactor = divideLRfactor
+
+        self.noImprovement = 0
+
+        # With a validation set the tracked value is a score (higher is better),
+        # otherwise it is the training loss (lower is better).
+        if 'val' in setting:
+            self.bestValue = -1
+        else:
+            self.bestValue = 1E4
+
+    # Use this step function when training without a validation data set; everything is then
+    # based on the training loss instead. Returns True if training should terminate.
+    def stepTrain(self, newTrainLoss, logger):
+        if newTrainLoss >= self.bestValue:
+            self.noImprovement += 1
+
+            # On a plateau: restore the best checkpoint and reduce the learning rate.
+            if self.noImprovement >= LR_REDUCE_NO_TRAIN_IMPROVEMENT:
+                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestTrainModel.pt'))
+                self.update_lr_by_division(self.divideLRfactor)
+                logger.info('### After ' + str(LR_REDUCE_NO_TRAIN_IMPROVEMENT) + ' epochs without train loss reduction => Best model loaded and LR reduced to ' + str(self.currentLR) + '!')
+                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
+                    return True
+                self.noImprovement = 0
+        else:
+            self.noImprovement = 0
+            self.bestValue = newTrainLoss
+            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestTrainModel.pt')
+
+        return False
+
+    # Use this step function when training with a validation data set, as commonly
+    # recommended. Returns True if training should terminate.
+    def stepTrainVal(self, newValScore, logger):
+        if newValScore <= self.bestValue:
+            self.noImprovement += 1
+
+            # On a plateau: restore the best checkpoint and reduce the learning rate.
+            if self.noImprovement >= LR_REDUCE_NO_VAL_IMPROVEMENT:
+                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
+                self.update_lr_by_division(self.divideLRfactor)
+                logger.info('### After ' + str(LR_REDUCE_NO_VAL_IMPROVEMENT) + ' epochs without val score improvement => Best model loaded and LR reduced to ' + str(self.currentLR) + '!')
+                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
+                    return True
+                self.noImprovement = 0
+        else:
+            self.noImprovement = 0
+            self.bestValue = newValScore
+            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestValModel.pt')
+
+        return False
+
+    # Divides the learning rate of all parameter groups by 'factor'.
+    def update_lr_by_division(self, factor):
+        newLR = self.currentLR / factor
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = newLR
+        self.currentLR = newLR
+
+    # Loads the weights of the best validation model into the current model.
+    def loadBestValIntoModel(self):
+        self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
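+
+# A minimal usage sketch, added for illustration only: the model, optimizer, logger, and the
+# constant validation score below are hypothetical placeholders for whatever the real training
+# loop provides; note that any setting string containing 'val' selects validation-score mode.
+if __name__ == '__main__':
+    import logging
+
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    model = torch.nn.Linear(10, 2)  # placeholder model
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    scheduler = MyLRScheduler(optimizer, model, foldResultsModelPath='.',
+                              setting='val', initLR=0.1, divideLRfactor=2.0)
+
+    for epoch in range(100):
+        newValScore = 0.5  # placeholder: compute the real validation score here
+        # stepTrainVal returns True once the LR falls below the early-stopping threshold
+        if scheduler.stepTrainVal(newValScore, logger):
+            break
+    scheduler.loadBestValIntoModel()  # final model = best checkpoint on the validation set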