# This module implements a learning rate scheduler that divides the learning rate by a given factor once the validation error has not improved for a specified number of epochs.
import torch
LR_Reduce_No_Train_Improvement = 15  # epochs without train loss improvement before the LR is reduced (train-only setting)
LR_Reduce_No_Val_Improvement = 15  # epochs without val score improvement before the LR is reduced (validation setting)
EARLY_STOP_LR_TOO_LOW = 4e-6  # training stops once the LR falls below this threshold

# This learning rate scheduler works as follows: after 15 epochs without improvement of the validation score, the best checkpoint so far is restored and the learning rate is divided by a specified factor.
# Training terminates once the learning rate has fallen below 4e-6; the checkpoint with the best validation score is then chosen as the final model.
class MyLRScheduler:
    def __init__(self, optimizer, model, foldResultsModelPath, setting, initLR, divideLRfactor):
        self.optimizer = optimizer
        self.model = model
        self.foldResultsModelPath = foldResultsModelPath
        self.currentLR = initLR
        self.divideLRfactor = divideLRfactor
        self.noImprovement = 0
        if 'val' in setting:
            # The validation score is maximized, so start below any real score.
            self.bestValue = -1
        else:
            # The training loss is minimized, so start above any real loss.
            self.bestValue = 1E4

    # If no validation data set is used, everything below is performed on the training loss instead.
    def stepTrain(self, newTrainLoss, logger):
        # Update the learning rate; returns True if training should stop early.
        if newTrainLoss >= self.bestValue:
            self.noImprovement += 1
            if self.noImprovement >= LR_Reduce_No_Train_Improvement:
                # Restore the best checkpoint so far before continuing with a lower LR.
                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestTrainModel.pt'))
                self.update_lr_by_divison(self.divideLRfactor)
                logger.info('### After ' + str(LR_Reduce_No_Train_Improvement) + ' epochs without train loss reduction => best model loaded and LR reduced to ' + str(self.currentLR) + ' !')
                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                    return True
                self.noImprovement = 0
        else:
            # New best training loss: reset the counter and save a checkpoint.
            self.noImprovement = 0
            self.bestValue = newTrainLoss
            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestTrainModel.pt')
        return False

    # Step function when a validation data set is used, as is commonly recommended.
    def stepTrainVal(self, newValScore, logger):
        # Update the learning rate; returns True if training should stop early.
        if newValScore <= self.bestValue:
            self.noImprovement += 1
            if self.noImprovement >= LR_Reduce_No_Val_Improvement:
                # Restore the best checkpoint so far before continuing with a lower LR.
                self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
                self.update_lr_by_divison(self.divideLRfactor)
                logger.info('### After ' + str(LR_Reduce_No_Val_Improvement) + ' epochs without val score improvement => best model loaded and LR reduced to ' + str(self.currentLR) + ' !')
                if self.currentLR < EARLY_STOP_LR_TOO_LOW:
                    return True
                self.noImprovement = 0
        else:
            # New best validation score: reset the counter and save a checkpoint.
            self.noImprovement = 0
            self.bestValue = newValScore
            torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestValModel.pt')
        return False

    # Divides the learning rate of the network by 'factor' and applies it to all parameter groups.
    def update_lr_by_divison(self, factor):
        newLR = self.currentLR / factor
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = newLR
        self.currentLR = newLR

    # Loads the checkpoint with the best validation score back into the current model.
    def loadBestValIntoModel(self):
        self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
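

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original class):
# it shows how the scheduler is meant to drive a training loop. The tiny
# model, the placeholder validation score, and the output path '.' are
# hypothetical stand-ins; only the scheduler calls reflect the API above.
# For comparison, PyTorch's built-in ReduceLROnPlateau reduces the LR on a
# plateau in the same spirit but does not reload the best checkpoint.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = torch.nn.Linear(10, 2)  # hypothetical placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = MyLRScheduler(optimizer, model, foldResultsModelPath='.',
                              setting='val', initLR=1e-3, divideLRfactor=5.0)

    for epoch in range(200):
        # ... one training epoch would run here ...
        valScore = 0.0  # hypothetical: replace with the real validation score
        if scheduler.stepTrainVal(valScore, logger):
            logger.info('LR fell below the early-stop threshold, stopping.')
            break

    # Restore the checkpoint with the best validation score as the final model.
    scheduler.loadBestValIntoModel()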