Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
FLASH
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Labooratory AI
FLASH
Commits
5da73bb4
Commit
5da73bb4
authored
3 years ago
by
Nassim Bouteldja
Browse files
Options
Downloads
Patches
Plain Diff
Upload New File
parent
7ddbfbc1
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
training.py
+426
-0
426 additions, 0 deletions
training.py
with
426 additions
and
0 deletions
training.py
0 → 100644
+
426
−
0
View file @
5da73bb4
# This file performs model training
import
os
import
shutil
import
numpy
as
np
import
logging
as
log
import
time
import
sys
# from sklearn.metrics import classification_report, confusion_matrix
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
from
torch.utils.data.sampler
import
SubsetRandomSampler
from
torch.optim.lr_scheduler
import
MultiStepLR
,
ExponentialLR
from
tensorboardX
import
SummaryWriter
from
RAdam
import
RAdam
from
dataset
import
CustomDataSetRAM
from
model
import
Custom
from
utils
import
getCrossValSplits
,
parse_nvidia_smi
,
parse_RAM_info
,
countParam
,
getDiceScores
,
getDiceScoresSinglePair
,
getMeanDiceScores
,
convert_labelmap_to_rgb
,
saveFigureResults
,
printResults
from
loss
import
DiceLoss
from
lrScheduler
import
MyLRScheduler
from
postprocessing
import
postprocessPredictionAndGT
,
extractInstanceChannels
from
evaluation
import
ClassEvaluator
from
nnUnet.generic_UNet
import
Generic_UNet
import
warnings
warnings
.
filterwarnings
(
"
ignore
"
)
#################################### General GPU settings ####################################
GPUno
=
0
useAllAvailableGPU
=
True
device
=
torch
.
device
(
"
cuda:
"
+
str
(
GPUno
)
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
##################################### Save test results ######################################
saveFinalTestResults
=
True
############################### Apply Test Time Augmentation #################################
applyTestTimeAugmentation
=
True
##############################################################################################
# this method trains a network with the given specification
def
train
(
model
,
setting
,
optimizer
,
scheduler
,
epochs
,
batchSize
,
logger
,
resultsPath
,
tbWriter
,
allClassEvaluators
):
model
.
to
(
device
)
if
torch
.
cuda
.
device_count
()
>
1
and
useAllAvailableGPU
:
logger
.
info
(
'
# {} GPUs utilized! #
'
.
format
(
torch
.
cuda
.
device_count
()))
model
=
nn
.
DataParallel
(
model
)
# mandatory to produce random numpy numbers during training, otherwise batches will contain equal random numbers (originally: numpy issue)
def
worker_init_fn
(
worker_id
):
np
.
random
.
seed
(
np
.
random
.
get_state
()[
1
][
0
]
+
worker_id
)
# allocate and separately load train / val / test data sets
dataset_Train
=
CustomDataSetRAM
(
'
train
'
,
logger
)
dataloader_Train
=
DataLoader
(
dataset
=
dataset_Train
,
batch_size
=
batchSize
,
shuffle
=
True
,
num_workers
=
4
,
worker_init_fn
=
worker_init_fn
)
if
'
val
'
in
setting
:
dataset_Val
=
CustomDataSetRAM
(
'
val
'
,
logger
)
dataloader_Val
=
DataLoader
(
dataset
=
dataset_Val
,
batch_size
=
batchSize
,
shuffle
=
False
,
num_workers
=
1
,
worker_init_fn
=
worker_init_fn
)
if
'
test
'
in
setting
:
dataset_Test
=
CustomDataSetRAM
(
'
test
'
,
logger
)
dataloader_Test
=
DataLoader
(
dataset
=
dataset_Test
,
batch_size
=
batchSize
,
shuffle
=
False
,
num_workers
=
1
,
worker_init_fn
=
worker_init_fn
)
logger
.
info
(
'
####### DATA LOADED - TRAINING STARTS... #######
'
)
# Utilize dice loss and weighted cross entropy loss, ignore index 8 as this is area outside the image included by augmentation, e.g. due to image rotation
Dice_Loss
=
DiceLoss
(
ignore_index
=
8
).
to
(
device
)
CE_Loss
=
nn
.
CrossEntropyLoss
(
weight
=
torch
.
FloatTensor
([
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
10.
]),
ignore_index
=
8
).
to
(
device
)
for
epoch
in
range
(
epochs
):
model
.
train
(
True
)
epochCELoss
=
0
epochDiceLoss
=
0
epochLoss
=
0
np
.
random
.
seed
()
start
=
time
.
time
()
for
batch
in
dataloader_Train
:
# get data and put onto device
imgBatch
,
segBatch
=
batch
imgBatch
=
imgBatch
.
to
(
device
)
segBatch
=
segBatch
.
to
(
device
)
optimizer
.
zero_grad
()
# forward image batch, compute loss and backprop
prediction
=
model
(
imgBatch
)
CEloss
=
CE_Loss
(
prediction
,
segBatch
)
diceLoss
=
Dice_Loss
(
prediction
,
segBatch
)
loss
=
CEloss
+
diceLoss
epochCELoss
+=
CEloss
.
item
()
epochDiceLoss
+=
diceLoss
.
item
()
epochLoss
+=
loss
.
item
()
loss
.
backward
()
optimizer
.
step
()
epochTrainLoss
=
epochLoss
/
dataloader_Train
.
__len__
()
end
=
time
.
time
()
# print current training loss
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
] Train-Loss:
'
+
str
(
round
(
epochTrainLoss
,
5
))
+
'
, DiceLoss:
'
+
str
(
round
(
epochDiceLoss
/
dataloader_Train
.
__len__
(),
5
))
+
'
, CELoss:
'
+
str
(
round
(
epochCELoss
/
dataloader_Train
.
__len__
(),
5
))
+
'
[took
'
+
str
(
round
(
end
-
start
,
3
))
+
'
s]
'
)
# use tensorboard for visualization of training progress
tbWriter
.
add_scalars
(
'
Plot/train
'
,
{
'
loss
'
:
epochTrainLoss
,
'
CEloss
'
:
epochCELoss
/
dataloader_Train
.
__len__
(),
'
DiceLoss
'
:
epochDiceLoss
/
dataloader_Train
.
__len__
()},
epoch
)
# each 50th epoch add prediction image to tensorboard
if
epoch
%
50
==
0
:
with
torch
.
no_grad
():
tbWriter
.
add_image
(
'
Train/_img
'
,
torch
.
round
((
imgBatch
[
0
,:,:,:]
+
1.
)
/
2.
*
255.0
).
byte
()
,
epoch
)
tbWriter
.
add_image
(
'
Train/GT
'
,
convert_labelmap_to_rgb
(
segBatch
[
0
,:,:].
cpu
()),
epoch
)
tbWriter
.
add_image
(
'
Train/pred
'
,
convert_labelmap_to_rgb
(
prediction
[
0
,:,:,:].
argmax
(
0
).
cpu
()),
epoch
)
if
epoch
%
100
==
0
:
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
]
'
+
parse_nvidia_smi
(
GPUno
))
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
]
'
+
parse_RAM_info
())
# if validation was included, compute dice scores on validation data
if
'
val
'
in
setting
:
model
.
train
(
False
)
diceScores_Val
=
[]
start
=
time
.
time
()
for
batch
in
dataloader_Val
:
imgBatch
,
segBatch
=
batch
imgBatch
=
imgBatch
.
to
(
device
)
# segBatch = segBatch.to(device)
with
torch
.
no_grad
():
prediction
=
model
(
imgBatch
).
to
(
'
cpu
'
)
diceScores_Val
.
append
(
getDiceScores
(
prediction
,
segBatch
))
diceScores_Val
=
np
.
concatenate
(
diceScores_Val
,
0
)
# <- all dice scores of val data (batchSize x amountClasses-1)
diceScores_Val
=
diceScores_Val
[:,
:
-
1
]
# ignore last coloum=border dice scores
mean_DiceScores_Val
,
epoch_val_mean_score
=
getMeanDiceScores
(
diceScores_Val
,
logger
)
end
=
time
.
time
()
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
] Val-Score (mean label dice scores):
'
+
str
(
np
.
round
(
mean_DiceScores_Val
,
4
))
+
'
, Mean:
'
+
str
(
round
(
epoch_val_mean_score
,
4
))
+
'
[took
'
+
str
(
round
(
end
-
start
,
3
))
+
'
s]
'
)
tbWriter
.
add_scalar
(
'
Plot/val
'
,
epoch_val_mean_score
,
epoch
)
if
epoch
%
50
==
0
:
with
torch
.
no_grad
():
tbWriter
.
add_image
(
'
Val/_img
'
,
torch
.
round
((
imgBatch
[
0
,:,:,:]
+
1.
)
/
2.
*
255.0
).
byte
(),
epoch
)
tbWriter
.
add_image
(
'
Val/GT
'
,
convert_labelmap_to_rgb
(
segBatch
[
0
,
:,
:].
cpu
()),
epoch
)
tbWriter
.
add_image
(
'
Val/pred
'
,
convert_labelmap_to_rgb
(
prediction
[
0
,
:,
:,
:].
argmax
(
0
).
cpu
()),
epoch
)
if
epoch
%
100
==
0
:
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
- After Validation]
'
+
parse_nvidia_smi
(
GPUno
))
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
- After Validation]
'
+
parse_RAM_info
())
# scheduler.step()
if
'
val
'
in
setting
:
endLoop
=
scheduler
.
stepTrainVal
(
epoch_val_mean_score
,
logger
)
else
:
endLoop
=
scheduler
.
stepTrain
(
epochTrainLoss
,
logger
)
if
epoch
==
(
epochs
-
1
):
#when no early stop is performed, load bestValModel into current model for later save
logger
.
info
(
'
### No early stop performed! Best val model loaded... ####
'
)
if
'
val
'
in
setting
:
scheduler
.
loadBestValIntoModel
()
# if test was included, compute global dice scores on test data (without postprocessing) for fast and coarse performance check
if
'
test
'
in
setting
:
model
.
train
(
False
);
model
.
eval
()
diceScores_Test
=
[]
start
=
time
.
time
()
for
batch
in
dataloader_Test
:
imgBatch
,
segBatch
=
batch
imgBatch
=
imgBatch
.
to
(
device
)
# segBatch = segBatch.to(device)
with
torch
.
no_grad
():
prediction
=
model
(
imgBatch
).
to
(
'
cpu
'
)
diceScores_Test
.
append
(
getDiceScores
(
prediction
,
segBatch
))
diceScores_Test
=
np
.
concatenate
(
diceScores_Test
,
0
)
# <- all dice scores of test data (amountTestData x amountClasses-1)
diceScores_Test
=
diceScores_Test
[:,:
-
1
]
#ignore last coloum=border dice scores
mean_DiceScores_Test
,
test_mean_score
=
getMeanDiceScores
(
diceScores_Test
,
logger
)
end
=
time
.
time
()
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
] Test-Score (mean label dice scores):
'
+
str
(
np
.
round
(
mean_DiceScores_Test
,
4
))
+
'
, Mean:
'
+
str
(
round
(
test_mean_score
,
4
))
+
'
[took
'
+
str
(
round
(
end
-
start
,
3
))
+
'
s]
'
)
tbWriter
.
add_scalar
(
'
Plot/test
'
,
test_mean_score
,
epoch
)
if
epoch
%
50
==
0
:
with
torch
.
no_grad
():
tbWriter
.
add_image
(
'
Test/_img
'
,
torch
.
round
((
imgBatch
[
0
,:,:,:]
+
1.
)
/
2.
*
255.0
).
byte
(),
epoch
)
tbWriter
.
add_image
(
'
Test/GT
'
,
convert_labelmap_to_rgb
(
segBatch
[
0
,
:,
:].
cpu
()),
epoch
)
tbWriter
.
add_image
(
'
Test/pred
'
,
convert_labelmap_to_rgb
(
prediction
[
0
,
:,
:,
:].
argmax
(
0
).
cpu
()),
epoch
)
if
epoch
%
100
==
0
:
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
- After Testing]
'
+
parse_nvidia_smi
(
GPUno
))
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
- After Testing]
'
+
parse_RAM_info
())
with
torch
.
no_grad
():
### if training is over, compute final performances using the instance-level dice score and average precision ###
if
endLoop
or
(
epoch
==
epochs
-
1
):
diceScores_Test
=
[]
diceScores_Test_TTA
=
[]
# iterate through all test images
test_idx
=
np
.
arange
(
dataset_Test
.
__len__
())
for
sampleNo
in
test_idx
:
imgBatch
,
segBatch
=
dataset_Test
.
__getitem__
(
sampleNo
)
imgBatch
=
imgBatch
.
unsqueeze
(
0
).
to
(
device
)
segBatch
=
segBatch
.
unsqueeze
(
0
)
# get prediction and postprocess it
prediction
=
model
(
imgBatch
)
postprocessedPrediction
,
outputPrediction
,
preprocessedGT
=
postprocessPredictionAndGT
(
prediction
,
segBatch
.
squeeze
(
0
).
numpy
(),
device
=
device
,
predictionsmoothing
=
False
,
holefilling
=
True
)
classInstancePredictionList
,
classInstanceGTList
,
finalPredictionRGB
,
preprocessedGTrgb
=
extractInstanceChannels
(
postprocessedPrediction
,
preprocessedGT
,
tubuliDilation
=
True
)
# here the evaluation is performed
# evaluate performance (TP, NP, FP counting and instance dice score computation)
for
i
in
range
(
6
):
#number classes to evaluate = 6
allClassEvaluators
[
0
][
i
].
add_example
(
classInstancePredictionList
[
i
],
classInstanceGTList
[
i
])
# there are regular dice similarity scores
diceScores_Test
.
append
(
getDiceScoresSinglePair
(
postprocessedPrediction
,
preprocessedGT
,
tubuliDilation
=
True
))
#dilates 'postprocessedPrediction' permanently
if
saveFinalTestResults
:
figFolder
=
resultsPath
if
not
os
.
path
.
exists
(
figFolder
):
os
.
makedirs
(
figFolder
)
imgBatchCPU
=
torch
.
round
((
imgBatch
[
0
,
:,
:,
:].
to
(
"
cpu
"
)
+
1.
)
/
2.
*
255.0
).
byte
().
numpy
().
transpose
(
1
,
2
,
0
)
# figPath = figFolder + '/test_idx_' + str(sampleNo) + '_result.png'
# saveFigureResults(imgBatchCPU, outputPrediction, postprocessedPrediction, finalPredictionRGB, segBatch.squeeze(0).numpy(), preprocessedGT, preprocessedGTrgb, fullResultPath=figPath, alpha=0.4)
if
applyTestTimeAugmentation
:
#perform test-time augmentation
prediction
=
torch
.
softmax
(
prediction
,
1
)
imgBatch
=
imgBatch
.
flip
(
2
)
prediction
+=
torch
.
softmax
(
model
(
imgBatch
),
1
).
flip
(
2
)
imgBatch
=
imgBatch
.
flip
(
3
)
prediction
+=
torch
.
softmax
(
model
(
imgBatch
),
1
).
flip
(
3
).
flip
(
2
)
imgBatch
=
imgBatch
.
flip
(
2
)
prediction
+=
torch
.
softmax
(
model
(
imgBatch
),
1
).
flip
(
3
)
prediction
/=
4.
postprocessedPrediction
,
outputPrediction
,
preprocessedGT
=
postprocessPredictionAndGT
(
prediction
,
segBatch
.
squeeze
(
0
).
numpy
(),
device
=
device
,
predictionsmoothing
=
False
,
holefilling
=
True
)
classInstancePredictionList
,
classInstanceGTList
,
finalPredictionRGB
,
preprocessedGTrgb
=
extractInstanceChannels
(
postprocessedPrediction
,
preprocessedGT
,
tubuliDilation
=
False
)
for
i
in
range
(
6
):
allClassEvaluators
[
1
][
i
].
add_example
(
classInstancePredictionList
[
i
],
classInstanceGTList
[
i
])
diceScores_Test_TTA
.
append
(
getDiceScoresSinglePair
(
postprocessedPrediction
,
preprocessedGT
,
tubuliDilation
=
True
))
#dilates 'postprocessedPrediction' permanently
if
saveFinalTestResults
:
figPath
=
figFolder
+
'
/test_idx_
'
+
str
(
sampleNo
)
+
'
_result_TTA.png
'
saveFigureResults
(
imgBatchCPU
,
outputPrediction
,
postprocessedPrediction
,
finalPredictionRGB
,
segBatch
.
squeeze
(
0
).
numpy
(),
preprocessedGT
,
preprocessedGTrgb
,
fullResultPath
=
figPath
,
alpha
=
0.4
)
logger
.
info
(
'
############################### RESULTS ###############################
'
)
# print regular dice similarity coefficients as coarse performance check
diceScores_Test
=
np
.
concatenate
(
diceScores_Test
,
0
)
# <- all dice scores of test data (amountTestData x amountClasses-1)
diceScores_Test
=
diceScores_Test
[:,
:
-
1
]
# ignore last coloum=border dice scores
mean_DiceScores_Test
,
test_mean_score
=
getMeanDiceScores
(
diceScores_Test
,
logger
)
logger
.
info
(
'
MEAN DICE SCORES:
'
+
str
(
np
.
round
(
mean_DiceScores_Test
,
4
))
+
'
, Overall mean:
'
+
str
(
round
(
test_mean_score
,
4
)))
np
.
savetxt
(
resultsPath
+
'
/allTestDiceScores.csv
'
,
diceScores_Test
,
delimiter
=
'
,
'
)
# print regular dice similarity coefficients as coarse performance check
if
applyTestTimeAugmentation
:
diceScores_Test_TTA
=
np
.
concatenate
(
diceScores_Test_TTA
,
0
)
# <- all dice scores of test data (amountTestData x amountClasses-1)
diceScores_Test_TTA
=
diceScores_Test_TTA
[:,
:
-
1
]
# ignore last coloum=border dice scores
mean_DiceScores_Test_TTA
,
test_mean_score_TTA
=
getMeanDiceScores
(
diceScores_Test_TTA
,
logger
)
logger
.
info
(
'
TTA - MEAN DICE SCORES:
'
+
str
(
np
.
round
(
mean_DiceScores_Test_TTA
,
4
))
+
'
, Overall mean:
'
+
str
(
round
(
test_mean_score_TTA
,
4
)))
np
.
savetxt
(
resultsPath
+
'
/allTestDiceScores_TTA.csv
'
,
diceScores_Test_TTA
,
delimiter
=
'
,
'
)
printResults
(
allClassEvaluators
=
allClassEvaluators
,
applyTestTimeAugmentation
=
applyTestTimeAugmentation
,
printOnlyTTAresults
=
True
,
logger
=
logger
,
saveNumpyResults
=
False
,
resultsPath
=
resultsPath
)
if
endLoop
:
logger
.
info
(
'
### Early network training stop at epoch
'
+
str
(
epoch
+
1
)
+
'
! ###
'
)
break
logger
.
info
(
'
[Epoch
'
+
str
(
epoch
+
1
)
+
'
] ### Training done! ###
'
)
return
model
def
set_up_training
(
modelString
,
setting
,
epochs
,
batchSize
,
lrate
,
weightDecay
,
logger
,
resultsPath
):
logger
.
info
(
'
### SETTING -> {} <- ###
'
.
format
(
setting
.
upper
()))
# class evaluation modules for each structure and with or w/o test-time augmentation
classEvaluators
=
[
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
()]
classEvaluatorsTTA
=
[
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
(),
ClassEvaluator
()]
allClassEvaluators
=
[
classEvaluators
,
classEvaluatorsTTA
]
resultsModelPath
=
resultsPath
+
'
/Model
'
if
not
os
.
path
.
exists
(
resultsModelPath
):
os
.
makedirs
(
resultsModelPath
)
# setting up tensorboard visualization
tensorboardPath
=
resultsPath
+
'
/TB
'
shutil
.
rmtree
(
tensorboardPath
,
ignore_errors
=
True
)
#<- remove existing TB events
tbWriter
=
SummaryWriter
(
log_dir
=
tensorboardPath
)
if
modelString
==
'
custom
'
:
model
=
Custom
(
input_ch
=
3
,
output_ch
=
8
,
modelDim
=
2
)
elif
modelString
==
'
nnunet
'
:
model
=
Generic_UNet
(
input_channels
=
3
,
num_classes
=
8
,
base_num_features
=
30
,
num_pool
=
7
,
final_nonlin
=
None
,
deep_supervision
=
False
,
dropout_op_kwargs
=
{
'
p
'
:
0.0
,
'
inplace
'
:
True
})
else
:
raise
ValueError
(
'
Given model >
'
+
modelString
+
'
< is invalid!
'
)
logger
.
info
(
model
)
logger
.
info
(
'
Model capacity: {} parameters.
'
.
format
(
countParam
(
model
)))
# set up optimizer
optimizer
=
RAdam
(
model
.
parameters
(),
lr
=
lrate
,
weight_decay
=
weightDecay
)
# set up scheduler
# scheduler = MultiStepLR(optimizer, milestones=[5, 15, 20, 25], gamma=0.3)
# scheduler = ExponentialLR(optimizer, gamma=0.99)
scheduler
=
MyLRScheduler
(
optimizer
,
model
,
resultsModelPath
,
setting
,
initLR
=
lrate
,
divideLRfactor
=
3.0
)
trained_model
=
train
(
model
,
setting
,
optimizer
,
scheduler
,
epochs
,
batchSize
,
logger
,
resultsPath
,
tbWriter
,
allClassEvaluators
)
# save final model (when validation is included, the model with lowest validation error is saved)
torch
.
save
(
trained_model
.
state_dict
(),
resultsModelPath
+
'
/finalModel.pt
'
)
if
'
__main__
'
==
__name__
:
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'
python training.py -m <model-type> -d <dataset> -s <train_valid_test> -e <epochs>
'
+
'
-b <batch-size> -r <learning-rate> -w <weight-decay>
'
)
parser
.
add_argument
(
'
-m
'
,
'
--model
'
,
default
=
'
custom
'
)
parser
.
add_argument
(
'
-s
'
,
'
--setting
'
,
default
=
'
train_val_test
'
)
parser
.
add_argument
(
'
-e
'
,
'
--epochs
'
,
default
=
500
,
type
=
int
)
parser
.
add_argument
(
'
-b
'
,
'
--batchSize
'
,
default
=
6
,
type
=
int
)
parser
.
add_argument
(
'
-r
'
,
'
--lrate
'
,
default
=
0.001
,
type
=
float
)
parser
.
add_argument
(
'
-w
'
,
'
--weightDecay
'
,
default
=
0.00001
,
type
=
float
)
options
=
parser
.
parse_args
()
assert
(
options
.
model
in
[
'
custom
'
,
'
unet
'
,
'
CEnet2D
'
,
'
CE_Net_Inception_Variants_2D
'
,
'
nnunet
'
])
assert
(
options
.
setting
in
[
'
train_val_test
'
,
'
train_test
'
,
'
train_val
'
,
'
train
'
])
assert
(
options
.
epochs
>
0
)
assert
(
options
.
batchSize
>
0
)
assert
(
options
.
lrate
>
0
)
assert
(
options
.
weightDecay
>
0
)
# Results path
resultsPath
=
'
SPECIFY RESULTS PATH
'
if
not
os
.
path
.
exists
(
resultsPath
):
os
.
makedirs
(
resultsPath
)
# Set up logger
log
.
basicConfig
(
level
=
log
.
INFO
,
format
=
'
%(asctime)s %(message)s
'
,
datefmt
=
'
%Y/%m/%d %I:%M:%S %p
'
,
handlers
=
[
log
.
FileHandler
(
resultsPath
+
'
/LOGS.log
'
,
'
w
'
),
log
.
StreamHandler
(
sys
.
stdout
)
])
logger
=
log
.
getLogger
()
logger
.
info
(
'
###### STARTED PROGRAM WITH OPTIONS: {} ######
'
.
format
(
str
(
options
)))
torch
.
backends
.
cudnn
.
benchmark
=
True
try
:
# start whole training and evaluation procedure
set_up_training
(
modelString
=
options
.
model
,
setting
=
options
.
setting
,
epochs
=
options
.
epochs
,
batchSize
=
options
.
batchSize
,
lrate
=
options
.
lrate
,
weightDecay
=
options
.
weightDecay
,
logger
=
logger
,
resultsPath
=
resultsPath
)
except
:
logger
.
exception
(
'
! Exception !
'
)
raise
log
.
info
(
'
%%%% Ended regularly ! %%%%
'
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment