Skip to content
Snippets Groups Projects
Commit 1357faaa authored by Nassim Bouteldja's avatar Nassim Bouteldja
Browse files

Upload New File

parent 35610bfb
No related branches found
No related tags found
No related merge requests found
# this class represents a learning rate schedular that decreases the learning rate by a given factor after the validation error has not fallen for a specified amount of epochs
import torch
LR_Reduce_No_Train_Improvement = 15
LR_Reduce_No_Val_Improvement = 15
EARLY_STOP_LR_TOO_LOW = 4e-6
# This learning rate scheduler works as follows: After 15 epochs of no improvement on the validation loss, the learning rate gets divided by a specified factor.
# Training terminates if the learning rate has fallen below 4E-6, then the best model on the validation loss (with highest validation accuracy) will be chosen as the final model.
class MyLRScheduler():
def __init__(self, optimizer, model, foldResultsModelPath, setting, initLR, divideLRfactor):
self.optimizer = optimizer
self.model = model
self.foldResultsModelPath = foldResultsModelPath
self.currentLR = initLR
self.divideLRfactor = divideLRfactor
self.noImprovement = 0
if 'val' in setting:
self.bestValue = -1
else:
self.bestValue = 1E4
# either way you train without utilizing a validation data set, then instead of the later, everything will be performed on the training data set!
def stepTrain(self, newTrainLoss, logger):
# Update learning rate
if newTrainLoss >= self.bestValue:
self.noImprovement += 1
if self.noImprovement >= LR_Reduce_No_Train_Improvement:
self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestTrainModel.pt'))
self.update_lr_by_divison(self.divideLRfactor)
logger.info('### After '+str(LR_Reduce_No_Train_Improvement)+' no train loss reduction => Best model loaded and LR reduced to '+str(self.currentLR)+' !')
if self.currentLR < EARLY_STOP_LR_TOO_LOW:
return True
self.noImprovement = 0
else:
self.noImprovement = 0
self.bestValue = newTrainLoss
torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestTrainModel.pt')
return False
# when utilizing a validation data set as recommended/commonly suggested
def stepTrainVal(self, newValScore, logger):
# Update learning rate
if newValScore <= self.bestValue:
self.noImprovement += 1
if self.noImprovement >= LR_Reduce_No_Val_Improvement:
self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
self.update_lr_by_divison(self.divideLRfactor)
logger.info('### After ' + str(LR_Reduce_No_Val_Improvement) + ' no val score improvement => Best model loaded and LR reduced to ' + str(self.currentLR) + ' !')
if self.currentLR < EARLY_STOP_LR_TOO_LOW:
return True
self.noImprovement = 0
else:
self.noImprovement = 0
self.bestValue = newValScore
torch.save(self.model.state_dict(), self.foldResultsModelPath + '/currentBestValModel.pt')
return False
# divides learning rate of network by 'factor'
def update_lr_by_divison(self, factor):
newLR = self.currentLR / factor
for param_group in self.optimizer.param_groups:
param_group['lr'] = newLR
self.currentLR = newLR
# loads current model with highest validation accuracy into current model
def loadBestValIntoModel(self):
self.model.load_state_dict(torch.load(self.foldResultsModelPath + '/currentBestValModel.pt'))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment