From 307668c0f666b849cd638dba0e81b4d11408f562 Mon Sep 17 00:00:00 2001 From: Ycblue <yuchialan@gmail.com> Date: Mon, 30 May 2022 15:27:32 +0200 Subject: [PATCH] working --- .../{TransMIL_simple.yaml => Resnet50.yaml} | 5 +- DeepGraft/TransMIL_debug.yaml | 50 ++ DeepGraft/TransMIL_dino.yaml | 6 +- DeepGraft/TransMIL_efficientnet_no_other.yaml | 48 ++ DeepGraft/TransMIL_efficientnet_no_viral.yaml | 48 ++ .../TransMIL_efficientnet_tcmr_viral.yaml | 48 ++ DeepGraft/TransMIL_resnet18_all.yaml | 2 +- DeepGraft/TransMIL_resnet18_no_other.yaml | 2 +- DeepGraft/TransMIL_resnet18_no_viral.yaml | 2 +- DeepGraft/TransMIL_resnet18_tcmr_viral.yaml | 2 +- DeepGraft/TransMIL_resnet50_all.yaml | 2 +- DeepGraft/TransMIL_resnet50_no_other.yaml | 48 ++ DeepGraft/TransMIL_resnet50_no_viral.yaml | 48 ++ DeepGraft/TransMIL_resnet50_tcmr_viral.yaml | 48 ++ MyBackbone/__init__.py | 2 + MyBackbone/backbone_factory.py | 80 +++ .../__pycache__/loss_factory.cpython-39.pyc | Bin 2411 -> 2545 bytes MyLoss/__pycache__/poly_loss.cpython-39.pyc | Bin 0 -> 2726 bytes MyLoss/loss_factory.py | 5 +- MyLoss/poly_loss.py | 84 +++ README.md | 4 + .../custom_dataloader.cpython-39.pyc | Bin 7729 -> 15301 bytes .../__pycache__/data_interface.cpython-39.pyc | Bin 5690 -> 7011 bytes datasets/custom_dataloader.py | 506 +++++++++++++----- datasets/data_interface.py | 82 ++- fine_tune.py | 42 ++ models/TransMIL.py | 2 +- models/__pycache__/TransMIL.cpython-39.pyc | Bin 3333 -> 3328 bytes .../model_interface.cpython-39.pyc | Bin 10240 -> 11282 bytes models/__pycache__/resnet50.cpython-39.pyc | Bin 0 -> 8628 bytes models/model_interface.py | 164 +++--- models/resnet50.py | 293 ++++++++++ train.py | 27 +- utils/__pycache__/utils.cpython-39.pyc | Bin 2832 -> 3450 bytes utils/extract_features.py | 27 + utils/utils.py | 154 +++++- 36 files changed, 1583 insertions(+), 248 deletions(-) rename DeepGraft/{TransMIL_simple.yaml => Resnet50.yaml} (93%) create mode 100644 DeepGraft/TransMIL_debug.yaml create mode 100644 DeepGraft/TransMIL_efficientnet_no_other.yaml create mode 100644 DeepGraft/TransMIL_efficientnet_no_viral.yaml create mode 100644 DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml create mode 100644 DeepGraft/TransMIL_resnet50_no_other.yaml create mode 100644 DeepGraft/TransMIL_resnet50_no_viral.yaml create mode 100644 DeepGraft/TransMIL_resnet50_tcmr_viral.yaml create mode 100644 MyBackbone/__init__.py create mode 100644 MyBackbone/backbone_factory.py create mode 100644 MyLoss/__pycache__/poly_loss.cpython-39.pyc create mode 100644 MyLoss/poly_loss.py create mode 100644 fine_tune.py create mode 100644 models/__pycache__/resnet50.cpython-39.pyc create mode 100644 models/resnet50.py create mode 100644 utils/extract_features.py diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/Resnet50.yaml similarity index 93% rename from DeepGraft/TransMIL_simple.yaml rename to DeepGraft/Resnet50.yaml index 4501f2c..e6b780b 100644 --- a/DeepGraft/TransMIL_simple.yaml +++ b/DeepGraft/Resnet50.yaml @@ -5,7 +5,7 @@ General: amp_level: O2 precision: 16 multi_gpu_mode: dp - gpus: [1] + gpus: [0] epochs: &epoch 200 grad_acc: 2 frozen_bn: False @@ -32,9 +32,8 @@ Data: Model: - name: TransMIL + name: resnet50 n_classes: 2 - backbone: simple Optimizer: diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_debug.yaml new file mode 100644 index 0000000..d83ce0d --- /dev/null +++ b/DeepGraft/TransMIL_debug.yaml @@ -0,0 +1,50 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [1] + epochs: 
&epoch 200 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json' + fold: 0 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + + + +Model: + name: TransMIL + n_classes: 2 + backbone: resnet18 + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml index ffe987c..b7161ba 100644 --- a/DeepGraft/TransMIL_dino.yaml +++ b/DeepGraft/TransMIL_dino.yaml @@ -6,10 +6,10 @@ General: precision: 16 multi_gpu_mode: dp gpus: [0] - epochs: &epoch 1000 + epochs: &epoch 200 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ @@ -17,7 +17,7 @@ Data: dataset_name: custom data_shuffle: False data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' - label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' fold: 0 nfold: 4 diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml new file mode 100644 index 0000000..79d8ea8 --- /dev/null +++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 1000 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0001 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml new file mode 100644 index 0000000..8780060 --- /dev/null +++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 4 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml new file mode 100644 index 0000000..f69b5bf --- 
/dev/null +++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [3] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 200 + server: train #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 2 + backbone: efficientnet + + +Optimizer: + opt: lookahead_radam + lr: 0.0002 + opt_eps: null + opt_betas: null + momentum: null + weight_decay: 0.00001 + +Loss: + base_loss: CrossEntropyLoss + diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml index 8fa5818..f331e4e 100644 --- a/DeepGraft/TransMIL_resnet18_all.yaml +++ b/DeepGraft/TransMIL_resnet18_all.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml index 95a9bd6..c7a27f2 100644 --- a/DeepGraft/TransMIL_resnet18_no_other.yaml +++ b/DeepGraft/TransMIL_resnet18_no_other.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml index 155b676..93054bf 100644 --- a/DeepGraft/TransMIL_resnet18_no_viral.yaml +++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml index e7d9bf0..c26e1e9 100644 --- a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml +++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 500 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: test #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml index eba3a4f..e6959ea 100644 --- a/DeepGraft/TransMIL_resnet50_all.yaml +++ b/DeepGraft/TransMIL_resnet50_all.yaml @@ -9,7 +9,7 @@ General: epochs: &epoch 1000 grad_acc: 2 frozen_bn: False - patience: 200 + patience: 50 server: train #train #test log_path: logs/ diff --git a/DeepGraft/TransMIL_resnet50_no_other.yaml b/DeepGraft/TransMIL_resnet50_no_other.yaml new file mode 100644 index 0000000..d3cd2aa --- /dev/null +++ b/DeepGraft/TransMIL_resnet50_no_other.yaml @@ -0,0 +1,48 @@ +General: + comment: + seed: 2021 + fp16: True + amp_level: O2 + precision: 16 + multi_gpu_mode: dp + gpus: [0] + epochs: &epoch 500 + grad_acc: 2 + frozen_bn: False + patience: 50 + server: test #train #test + log_path: logs/ + +Data: + dataset_name: custom + data_shuffle: False + data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/' + label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json' + fold: 1 + nfold: 4 + + train_dataloader: + batch_size: 1 + num_workers: 8 + + test_dataloader: + batch_size: 1 + num_workers: 8 + +Model: + name: TransMIL + n_classes: 5 + backbone: resnet50 + + +Optimizer: + 
opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null
+    opt_betas: null
+    momentum: null
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_no_viral.yaml b/DeepGraft/TransMIL_resnet50_no_viral.yaml
new file mode 100644
index 0000000..2e3394a
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment:
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null
+    opt_betas: null
+    momentum: null
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000..a756616
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment:
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null
+    opt_betas: null
+    momentum: null
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/MyBackbone/__init__.py b/MyBackbone/__init__.py
new file mode 100644
index 0000000..1fca298
--- /dev/null
+++ b/MyBackbone/__init__.py
@@ -0,0 +1,2 @@
+
+from .backbone_factory import init_backbone
\ No newline at end of file
diff --git a/MyBackbone/backbone_factory.py b/MyBackbone/backbone_factory.py
new file mode 100644
index 0000000..ff770e5
--- /dev/null
+++ b/MyBackbone/backbone_factory.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoFeatureExtractor, ViTModel
+from torchvision import models
+
+
+class View(nn.Module):
+    """Reshape a tensor inside an nn.Sequential."""
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        return x.view(*self.shape)
+
+
+def init_backbone(**kargs):
+
+    backbone = kargs['backbone']
+    n_classes = kargs['n_classes']
+    out_features = kargs['out_features']
+
+    if backbone == 'dino' or backbone == 'vit':
+        # 'vit' currently falls back to the DINO checkpoint
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        vit = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+        def model_ft(input):
+            # preprocess, then forward through the ViT
+            input = feature_extractor(input, return_tensors='pt')
+            return vit(**input)
+
+    elif backbone == 'resnet18':
+        resnet18 = models.resnet18(pretrained=True)
+        modules = list(resnet18.children())[:-1]
+        # model_ft.fc = nn.Linear(512, out_features)
+        res18 = nn.Sequential(*modules)
+        # freeze the pretrained feature extractor
+        for param in res18.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res18,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 512)),
+            nn.Linear(512, out_features),
+            nn.GELU(),
+        )
+    elif backbone == 'resnet50':
+        resnet50 = models.resnet50(pretrained=True)
+        # drop layer4, avgpool and fc -> 1024-channel feature maps
+        modules = list(resnet50.children())[:-3]
+        res50 = nn.Sequential(*modules)
+        for param in res50.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res50,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.GELU()
+        )
+    elif backbone == 'efficientnet':
+        efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+        for param in efficientnet.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            efficientnet,
+            nn.Linear(1000, 512),
+            nn.GELU(),
+        )
+    elif backbone == 'simple': #mil-ab attention
+        model_ft = nn.Sequential(
+            nn.Conv2d(3, 20, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Conv2d(20, 50, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.ReLU(),
+        )
+    return model_ft
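For orientation, a minimal sketch of how the factory above is meant to be called. The keyword names follow the `kargs` lookups in `init_backbone`; the concrete values and the tile shape are illustrative, not taken from the patch:

```python
import torch
from MyBackbone import init_backbone

# resnet50 branch: frozen conv1..layer3 -> avgpool -> flatten -> Linear(1024, out_features)
model_ft = init_backbone(backbone='resnet50', n_classes=2, out_features=512)

tiles = torch.rand(8, 3, 256, 256)   # a bag of 8 tiles
feats = model_ft(tiles)              # -> torch.Size([8, 512])
```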
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index ed7437016bbcadc48f9bb0a97401529dc574c9b2..14452dde7b34fd11b5255f4ed07bdeb48a27e49f 100644
[GIT binary patch (compiled bytecode) omitted]
diff --git a/MyLoss/__pycache__/poly_loss.cpython-39.pyc b/MyLoss/__pycache__/poly_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08edf030976f28689763861889e7ace3d3747a59
[GIT binary patch (compiled bytecode) omitted]
diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py
index 1dffa61..f3bdceb 100755
--- a/MyLoss/loss_factory.py
+++ b/MyLoss/loss_factory.py
@@ -13,6 +13,7 @@ from .hausdorff import HausdorffDTLoss, HausdorffERLoss
 from .lovasz_loss import LovaszSoftmax
 from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
     WeightedCrossEntropyLossV2, DisPenalizedCE
+from .poly_loss import PolyLoss
 
 from pytorch_toolbelt import losses as L
 
@@ -22,7 +23,7 @@ def create_loss(args, w1=1.0, w2=0.5):
     # mode = 
args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE loss = None if hasattr(nn, conf_loss): - loss = getattr(nn, conf_loss)() + loss = getattr(nn, conf_loss)(label_smoothing=0.5) #binary loss elif conf_loss == "focal": loss = L.BinaryFocalLoss() @@ -46,6 +47,8 @@ def create_loss(args, w1=1.0, w2=0.5): loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryDiceLogLoss(), w1, w2) elif conf_loss == "reduced_focal": loss = L.BinaryFocalLoss(reduced=True) + elif conf_loss == "polyloss": + loss = PolyLoss(softmax=False) else: assert False and "Invalid loss" raise ValueError diff --git a/MyLoss/poly_loss.py b/MyLoss/poly_loss.py new file mode 100644 index 0000000..e370545 --- /dev/null +++ b/MyLoss/poly_loss.py @@ -0,0 +1,84 @@ +# From https://github.com/yiyixuxu/polyloss-pytorch + +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def to_one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = torch.reshape(labels, shape) + + sh = list(labels.shape) + + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + + sh[dim] = num_classes + + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) + + return labels + + +class PolyLoss(_Loss): + def __init__(self, + softmax: bool = False, + ce_weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', + epsilon: float = 1.0, + ) -> None: + super().__init__() + self.softmax = softmax + self.reduction = reduction + self.epsilon = epsilon + self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction='none') + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + You can pass logits or probabilities as input, if pass logit, must set softmax=True + target: the shape should be BNH[WD] (one-hot format) or B1H[WD], where N is the number of classes. + It should contain binary values + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
+ """ + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + # target not in one-hot encode format, has shape B1H[WD] + if n_pred_ch != n_target_ch: + # squeeze out the channel dimension of size 1 to calculate ce loss + self.ce_loss = self.cross_entropy(input, torch.squeeze(target, dim=1).long()) + # convert into one-hot format to calculate ce loss + target = to_one_hot(target, num_classes=n_pred_ch) + else: + # # target is in the one-hot format, convert to BH[WD] format to calculate ce loss + self.ce_loss = self.cross_entropy(input, torch.argmax(target, dim=1)) + + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + pt = (input * target).sum(dim=1) # BH[WD] + poly_loss = self.ce_loss + self.epsilon * (1 - pt) + + if self.reduction == 'mean': + polyl = torch.mean(poly_loss) # the batch and channel average + elif self.reduction == 'sum': + polyl = torch.sum(poly_loss) # sum over the batch and channel dims + elif self.reduction == 'none': + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + # BH[WD] -> BH1[WD] + polyl = poly_loss.unsqueeze(1) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return (polyl) \ No newline at end of file diff --git a/README.md b/README.md index 04c7dba..effc0fe 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,7 @@ python train.py --stage='train' --config='Camelyon/TransMIL.yaml' --gpus=0 --fo ```python python train.py --stage='test' --config='Camelyon/TransMIL.yaml' --gpus=0 --fold=0 ``` + + +### Changes Made: + diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc index 147b3fc3c4628d6eed5bebd6ff248d31aaeebb34..4a200bbb7d34328019d9bb9e604b0524127ae863 100644 GIT binary patch literal 15301 zcmYe~<>g{vU|?83yE%y^hJoQRh=Ytd7#J8F7#J9e3z!%fQW#Pga~Pr^G-DJ~3PTE0 z4s$L`6bmCnj5Ufig&~DGhb@;qiXAM*lEab98N~@^v*vK+az$~0*^D{dxja!kP&RKA zFW79pTz)17MuuF0C_$)<P?S&#Lke4taIQ#{2vkfoN))VGEJ_?qOGHV4X~`(56owS` z9O+z{C>gL?jvU!sxhOd>n=?l~S0PFf%x27yk5WouNa4y+&Q*z0VPtS;Na0T5X<<m= zNmXrTj#6`HNa0Q4Yhg&?OJ&beZ)T3taA!#2PZ4NgND+XFXu30`2&M?NFr)~9^=PGN zrLtyer|P7L%wb9qO%ZEhjnZ{zND)U?ld6{@k<B!LsVD~{?DV7bQ&|=mEM&+sOp%mi zh%!o%N;R6zkRm;oIm#HymYK^OWddc(&Sj1=g|g-5GDn%E>Svj!nx)96DD*NhGNdR* zS)^*DnxvYhD5WT;TBN9?sP?i&S*GfwTBTYxGp4ACFf=npS;P3L))FATO)6`aZ5I0i zyHv|m>r|U`riDx~Oi}hx4yl}}_9=|1DyjC(j4A3V8oex0ju5dF^%Tus<|xM$hA5{L zhIGbM3lWAW=Tt`*hFG;ImsFP+<|x-Fw<z~i=TxQz9;t2%8Ce-58B!$88JHPt7$z_k z8l-xrIyW;iGNgK@T7k?5u~Qg>88o$Dg3_X&CgUxZfW(pv5S@}(lBmgei_0akB+(~7 zF(tJKBwA9KlB>yhi`6qXF+CL|9g-iCnpd1(6lSE!cuT~wv?SjxHL;|$DAl#1q$n}D zBp;-WE3q^^H#M&$wWwH=@s@;VZc<`SVqS7;3dr!{)RN%D+=86cqGXUGkTDZ1Q;0Ay zFr+d>F{Us?F{Lo3Ftsp5F{dyGGib8h5(-MpOHcL7FUn2K$*f8&$;{772I+&D!_2_I z0OC7?G8f348pc|t8m1IRX@(Rg35FCVX@+11O=iDzu!aZu4$r|1h6na588n%1u@tA~ zq}^gG2D|VUb7pRO5ibJ+1DN>bs-Kaco2p-#lbEMpo?nz*T#%TYsvlC6m{;uU>7(nC znp)sql$chc4+_lU)RJQT<kI4j{M>jDKL;GBdIgoYIO5}z6LWIn<7Gf@WC7XF$ii49 zjp;Kzm`XoQ7Lc~Q#N5>Q_*-1@@wxdar8yurPkek~X<`mUrbr0pR<MI1gdj+an}LBr z8e|<PY&aN;Kzxkw2Rj9%C<!I}qnJ}9+8NRqQ`l12TR5XwQaDmLTNt8PQ@B#NTNt9) zQg~8$TNt9)Q}|N&TNt7^QUp>2TNt7^Q-o54TNt9aQbbZjTNt9aQ^ZolTNt8v+8J0F zqIiQDG$n5ddbqfm`g;0+5(XrI9{yrvU<gVrC`v6(%_}KZNX%16OezMaEd`Kb1yEQP zr<Uj`xaAiq6ldn8=cFolm?|Xa=anR8=A{-XDsU+%C_o6$yn@mag@VMQ#N1*lurNp} zEi)%Iz96wA!%6{^`V~s@6-qKv71Hu^Qc{ax24&`@gGHg%DQM*9r7GkXDdd*slw=lw 
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 147b3fc3c4628d6eed5bebd6ff248d31aaeebb34..4a200bbb7d34328019d9bb9e604b0524127ae863 100644
[GIT binary patch (compiled bytecode) omitted]
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 4af0141496c11d6f97255add8746fa0067e55636..9550db1509e8d477a1c15534a76c6f87976fbec4 100644
[GIT binary patch (compiled bytecode) omitted]
zy4cAtc`ev@K%V5Ctj=d44>Gez3B)%95oRF50z_DX2=U2vd=kb86J<eyAY+TT85kHq z#uS4L;9%kt<6z@b;9%xp=3wFAEaINLPSi}2bq;GgLmFcWTMBy%ZxmaKNGf|Ovm`?b z$K=1Fu9IW=MLc;K7#LnMfqbaRS_Dcex43c=lTvfy3ld8*ii|<FnSgxCnw*%Nk($g7 z%F|HH$-uw>(#a1>WigZM#6r1KIBFQ;`BE5y88j!W3#iFz@)v<*i$G2Un_d(OG7)0z z<Sc>w`l2vUBH)DNPmuk{DnM$ZK?F!qkr#*s3icv!Ab`UJ91<YYb3u-f0lAEki;)dG zWU3PL_4IK`EJ^guPbtkw)r*g>&&<m#iI4ZwWGd1Hd6TUuwWPEtFB#-7kYOOq#=yV; z!ptBmCNMBClrSt{T*#2cl)@;<P|F0$?qv)`NhQoR3|TBS3|XuxAQoE+OD{_;a|wG1 zM-6i`V+yMXLo;J73yfdGA_3ypvzBmzN|G8@cv-?#!w}C6Doe^3ikgrNWv^k#0vlSx z5YJP>+su%~SHqaX(F-aqIsJ-2F{8=kSELLIB^6MhGJ}Ki7E5k^N@@`}7*#>SOnyab zAT=OD(})r5Z%zIpP-$Hx0g{je5u6|b5g%af>L4+V$-9Jf>cO!JO0S?q4-y82UJ)q# zi}XNTeGp*)BGBRrq^BMfkf0Ksft!zwgHa4zD2Xv~FiJ3038TgkxERx9EQ)1dU`S>J z#ReEpKE<Kw04lpQnTkNkv&a<WGiY%X53&}L#6T=8MM)7Tlg*iYPB_UDt`1h1C4lt6 zMKQ`jkSr*7&Yzqq5^a)%WNjiy3cp5Bj$SnRhlo}^3pl=tK)J6d1f2UA{WSTC{6S^} zfQUd)YGf`hDJlYm`Yo2E{QMk9HUcHhB2X|w3SW>{ia;(Z0-0VE1kw!78ejq(`ru3w z%)r2q4GMTrL<@2-$p}fw@-g!<7l9-+#U?9BZH*T8$xP2E$;-@3M`RLkIRjP%Dbm2k kf(ekZw>WHa!FkCJ6gkD9+zF~(co;bt1sHi?kVz;W0My@>+5i9m diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py index ddc2ed3..02850f5 100644 --- a/datasets/custom_dataloader.py +++ b/datasets/custom_dataloader.py @@ -9,12 +9,25 @@ from torch.utils.data.dataloader import DataLoader from tqdm import tqdm # from histoTransforms import RandomHueSaturationValue import torchvision.transforms as transforms +import torchvision import torch.nn.functional as F import csv from PIL import Image import cv2 import pandas as pd import json +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from transformers import AutoFeatureExtractor +from imgaug import augmenters as iaa +import imgaug as ia +from torchsampler import ImbalancedDatasetSampler + + +class RangeNormalization(object): + def __call__(self, sample): + img = sample + return (img / 255.0 - 0.5) / 0.5 class HDF5MILDataloader(data.Dataset): """Represents an abstract HDF5 dataset. For single H5 container! @@ -28,68 +41,89 @@ class HDF5MILDataloader(data.Dataset): data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). """ - def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20): + def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024): super().__init__() self.data_info = [] self.data_cache = {} self.slideLabelDict = {} + self.files = [] self.data_cache_size = data_cache_size self.mode = mode self.file_path = file_path # self.csv_path = csv_path self.label_path = label_path self.n_classes = n_classes - self.bag_size = 120 + self.bag_size = bag_size + self.backbone = backbone # self.label_file = label_path recursive = True + # read labels and slide_path from csv - - # df = pd.read_csv(self.csv_path) - # labels = df.LABEL - # slides = df.FILENAME with open(self.label_path, 'r') as f: - self.slideLabelDict = json.load(f)[mode] - - self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict} - - - # if Path(slides[0]).suffix: - # slides = list(map(lambda x: Path(x).stem, slides)) - - # print(labels) - # print(slides) - # self.slideLabelDict = dict(zip(slides, labels)) - # print(self.slideLabelDict) - - #check if files in slideLabelDict, only take files that are available. 
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem
+                x_complete_path = Path(self.file_path)/Path(x + '.hdf5')
+                if x_complete_path.is_file():
+                    self.slideLabelDict[x] = y
+                    self.files.append(x_complete_path)
-        files_in_path = list(Path(self.file_path).rglob('*.hdf5'))
-        files_in_path = [x.stem for x in files_in_path]
-        # print(len(files_in_path))
-        # print(files_in_path)
-        # print(list(self.slideLabelDict.keys()))
-        # for x in list(self.slideLabelDict.keys()):
-        #     if x in files_in_path:
-        #         path = Path(self.file_path) / (x + '.hdf5')
-        #         print(path)
-        self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path]
-        print(len(self.files))
-        # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys())))
 
         for h5dataset_fp in tqdm(self.files):
-            # print(h5dataset_fp)
             self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
-        # print(self.data_info)
 
-        self.resize_transforms = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(256),
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
         ])
+        # probability wrappers for the augmenters below
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+
+        # earlier albumentations version of the same pipeline, kept for reference
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     A.RandomGamma(),
+        #     A.HorizontalFlip(),
+        #     A.VerticalFlip(),
+        #     A.RandomRotate90(),
+        #     A.OneOf([
+        #         A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #         A.Affine(scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=8),
+        #     ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+        ])
         self.img_transforms = transforms.Compose([
             transforms.RandomHorizontalFlip(p=1),
             transforms.RandomVerticalFlip(p=1),
@@ -100,32 +134,45 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
+        if self.backbone == 'dino':
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
         # self._add_data_infos(load_data)
-
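One detail worth flagging before `__getitem__` below: `to_deterministic()` freezes the sampled augmentation parameters, so every tile of a bag receives the identical transform. A minimal sketch of that behaviour (shapes and augmenters are illustrative):

```python
import numpy as np
from imgaug import augmenters as iaa

aug = iaa.Sequential([iaa.Fliplr(0.5), iaa.AddToHueAndSaturation((-13, 13))])
det = aug.to_deterministic()       # sample the random decisions once

tile_a = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
tile_b = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
out_a = det.augment_image(tile_a)  # the same flip / HSV shift
out_b = det.augment_image(tile_b)  # is applied to both tiles
```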
     def __getitem__(self, index):
         # get data
         batch, label, name = self.get_data(index)
         out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
 
         if self.mode == 'train':
-            for img in batch:
-                img = self.img_transforms(img)
-                img = self.hsv_transforms(img)
+            for img in batch:  # tiles are cached as uint8 tensors
+                img = img.numpy().astype(np.uint8)
+                # augment on the raw array first; the extractor returns tensors
+                img = seq_img_d.augment_image(img)
+                if self.backbone == 'dino':
+                    # the HuggingFace extractor does its own resizing/normalisation
+                    img = self.feature_extractor(images=img, return_tensors='pt')
+                else:
+                    img = self.val_transforms(img)
                 out_batch.append(img)
         else:
             for img in batch:
-                img = transforms.functional.to_tensor(img)
+                img = img.numpy().astype(np.uint8)
+                if self.backbone == 'dino':
+                    img = self.feature_extractor(images=img, return_tensors='pt')
+                else:
+                    # albumentations transforms take keyword arguments
+                    img = self.resize_transforms(image=img)['image']
+                    img = self.val_transforms(img)
                 out_batch.append(img)
+
         if len(out_batch) == 0:
             # print(name)
-            out_batch = torch.randn(100,3,256,256)
+            out_batch = torch.randn(self.bag_size,3,256,256)
         else:
             out_batch = torch.stack(out_batch)
-            out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+            # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
@@ -138,29 +185,7 @@ class HDF5MILDataloader(data.Dataset):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
             label = self.slideLabelDict[wsi_name]
-            wsi_batch = []
-            # with h5py.File(file_path, 'r') as h5_file:
-            #     numKeys = len(h5_file.keys())
-            #     sample = list(h5_file.keys())[0]
-            #     shape = (numKeys,) + h5_file[sample][:].shape
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-
-            # print(img)
-            # if type == 'images':
-            #     t = 'data'
-            # else:
-            #     t = 'label'
             idx = -1
-            # if load_data:
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-            #         img = img.astype(np.uint8)
-            #         img = self.resize_transforms(img)
-            #         wsi_batch.append(img)
-            #     idx = self._add_to_cache(wsi_batch, file_path)
-            #     wsi_batch.append(img)
-            # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
 
     def _load_data(self, file_path):
         """Load data to the cache given the file
         path and update the cache index in the
         data_info structure.
         """
         with h5py.File(file_path, 'r') as h5_file:
             wsi_batch = []
             for tile in h5_file.keys():
                 img = h5_file[tile][:]
                 img = img.astype(np.uint8)
-                img = self.resize_transforms(img)
+                img = torch.from_numpy(img)
+                # img = self.resize_transforms(img)
                 wsi_batch.append(img)
+            wsi_batch = torch.stack(wsi_batch)
+            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
             idx = self._add_to_cache(wsi_batch, file_path)
             file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
             self.data_info[file_idx + idx]['cache_idx'] = idx
 
-        # for type in ['images', 'labels']:
-        #     for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()):
-        #         img = h5_file[data_path][:]
-        #         idx = self._add_to_cache(img, data_path)
-        #         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path)
-        #         self.data_info[file_idx + idx]['cache_idx'] = idx
-        # for gname, group in h5_file.items():
-        #     for dname, ds in group.items():
-        #         # add data to the data cache and retrieve
-        #         # the cache index
-        #         idx = self._add_to_cache(ds.value, file_path)
-
-        #         # find the beginning index of the hdf5 file we are looking for
-        #         file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)
-
-        #         # the data info should have the same index since we loaded it in the 
same way
-        #         self.data_info[file_idx + idx]['cache_idx'] = idx
 
         # remove an element from data cache if size was exceeded
         if len(self.data_cache) > self.data_cache_size:
             # remove one item from the cache at random
@@ -223,6 +233,182 @@ class HDF5MILDataloader(data.Dataset):
         # data_info_type = [di for di in self.data_info if di['type'] == type]
         # return data_info_type
 
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+        return [self.data_info[i]['label'] for i in indices]
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+        dataset. This will make sure that the data is loaded in case it is
+        not part of the data cache.
+        i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
+class GenesisDataloader(data.Dataset):
+    """Represents an HDF5 dataset of single images (one tile per container).
+
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset fits into memory. Otherwise, leave this at false and
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+    """
+    def __init__(self, file_path, label_path, mode, load_data=False, data_cache_size=5, debug=False):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        # self.n_classes = n_classes
+        self.bag_size = 120
+        # self.transforms = transforms
+        self.input_size = 256
+
+        # self.files = list(Path(self.file_path).rglob('*.hdf5'))
+        home = Path.cwd().parts[1]
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for x in temp_slide_label_dict:
+                # rewrite the user part of the path for the current machine
+                if Path(x).parts[1] != home:
+                    x = x.replace(Path(x).parts[1], home)
+                self.files.append(Path(x))
+
+        for h5dataset_fp in tqdm(self.files):
+            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+
+
+    def __getitem__(self, index):
+        # get_data returns (img, label, name), as in HDF5MILDataloader
+        img, label, name = self.get_data(index)
+        # out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
+
+        if self.mode == 'train':
+            img = img.numpy().astype(np.uint8)
+            img = seq_img_d.augment_image(img)
+            img = self.val_transforms(img)
+        else:
+            img = img.numpy().astype(np.uint8)
+            img = self.val_transforms(img)
+        # out_batch.append(img)
+
+        return {'data': img, 'label': label}
+        # return out_batch, label
+
+    def __len__(self):
+        return len(self.data_info)
+
+    def _add_data_infos(self, file_path, load_data):
+        img_name = Path(file_path).stem
+        label = Path(file_path).parts[-2]
+
+        idx = -1
+        self.data_info.append({'data_path': file_path, 'label': label, 'name': img_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        with h5py.File(file_path, 'r') as h5_file:
+
+            tile = list(h5_file.keys())[0]
+            img = h5_file[tile][:]
+            img = img.astype(np.uint8)
+            img = torch.from_numpy(img)  # cache tensors so __getitem__ can call .numpy(), as in HDF5MILDataloader
+            # img = self.resize_transforms(img)
+            # wsi_batch.append(img)
+            idx = self._add_to_cache(img, file_path)
+            file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
+            self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # remove invalid cache_idx; keep 'label' so evicted entries stay usable
+            # self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'shape': di['shape'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+ """ + if data_path not in self.data_cache: + self.data_cache[data_path] = [data] + else: + self.data_cache[data_path].append(data) + return len(self.data_cache[data_path]) - 1 + def get_name(self, i): # name = self.get_data_infos(type)[i]['name'] name = self.data_info[i]['name'] @@ -248,6 +434,45 @@ class HDF5MILDataloader(data.Dataset): return self.data_cache[fp][cache_idx], label, name +class RandomHueSaturationValue(object): + + def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): + + self.hue_shift_limit = hue_shift_limit + self.sat_shift_limit = sat_shift_limit + self.val_shift_limit = val_shift_limit + self.p = p + + def __call__(self, sample): + + img = sample #,lbl + + if np.random.random() < self.p: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) #takes np.float32 + h, s, v = cv2.split(img) + hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) + hue_shift = np.uint8(hue_shift) + h += hue_shift + sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) + s = cv2.add(s, sat_shift) + val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) + v = cv2.add(v, val_shift) + img = cv2.merge((h, s, v)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img #, lbl + +def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512): + + # get up to bag_size elements + bag_idxs = torch.randperm(bag.shape[0])[:bag_size] + bag_samples = bag[bag_idxs] + + # zero-pad if we don't have enough samples + zero_padded = torch.cat((bag_samples, + torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3]))) + return zero_padded, min(bag_size, len(bag)) + + class RandomHueSaturationValue(object): def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5): @@ -285,48 +510,81 @@ if __name__ == '__main__': train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv' data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/' # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json' - label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json' - output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/' - # os.makedirs(output_path, exist_ok=True) - - - dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6) - data = DataLoader(dataset, batch_size=1) + label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json' + output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments' + os.makedirs(output_path, exist_ok=True) + + n_classes = 2 + + dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20) + # print(dataset.dataset) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b]) + dl = DataLoader(train_ds, None, sampler=ImbalancedDatasetSampler(train_ds), num_workers=5) + dl = DataLoader(train_ds, None, num_workers=5) + + # data = DataLoader(dataset, batch_size=1) # print(len(dataset)) - x = 0 + # # x = 0 c = 0 - for item in data: - if c >=10: - break + label_count = [0] *n_classes + for item in dl: + # if c >=10: + # break bag, label, name = item - print(bag) - # # print(bag.shape) - # if bag.shape[1] == 1: - # print(name) - # print(bag.shape) + label_count[np.argmax(label)] += 1 + 
print(label_count) + print(len(train_ds)) + # # # print(bag.shape) + # # if bag.shape[1] == 1: + # # print(name) + # # print(bag.shape) # print(bag.shape) - # print(name) - # out_dir = Path(output_path) / name - # os.makedirs(out_dir, exist_ok=True) - - # # print(item[2]) - # # print(len(item)) - # # print(item[1]) - # # print(data.shape) - # # data = data.squeeze() - # bag = item[0] - # bag = bag.squeeze() - # for i in range(bag.shape[0]): - # img = bag[i, :, :, :] - # img = img.squeeze() - # img = img*255 - # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + # # out_dir = Path(output_path) / name + # # os.makedirs(out_dir, exist_ok=True) + + # # # print(item[2]) + # # # print(len(item)) + # # # print(item[1]) + # # # print(data.shape) + # # # data = data.squeeze() + # # bag = item[0] + # bag = bag.squeeze() + # original = original.squeeze() + # for i in range(bag.shape[0]): + # img = bag[i, :, :, :] + # img = img.squeeze() + + # img = ((img-img.min())/(img.max() - img.min())) * 255 + # print(img) + # # print(img) + # img = img.numpy().astype(np.uint8).transpose(1,2,0) + + + # img = Image.fromarray(img) + # img = img.convert('RGB') + # img.save(f'{output_path}/{i}.png') + + - # img = Image.fromarray(img) - # img = img.convert('RGB') - # img.save(f'{out_dir}/{i}.png') - c += 1 + # o_img = original[i,:,:,:] + # o_img = o_img.squeeze() + # print(o_img.shape) + # o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255 + # o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0) + # o_img = Image.fromarray(o_img) + # o_img = o_img.convert('RGB') + # o_img.save(f'{output_path}/{i}_original.png') + # c += 1 + # break # else: break # print(data.shape) - # print(label) \ No newline at end of file + # print(label) + # a = [torch.Tensor((3,256,256))]*3 + # b = torch.stack(a) + # print(b) + # c = to_fixed_size_bag(b, 512) + # print(c) \ No newline at end of file diff --git a/datasets/data_interface.py b/datasets/data_interface.py index 12a0f8c..056e6ff 100644 --- a/datasets/data_interface.py +++ b/datasets/data_interface.py @@ -8,6 +8,8 @@ from torchvision import transforms from .camel_dataloader import FeatureBagLoader from .custom_dataloader import HDF5MILDataloader from pathlib import Path +from transformers import AutoFeatureExtractor +from torchsampler import ImbalancedDatasetSampler class DataInterface(pl.LightningDataModule): @@ -56,9 +58,10 @@ class DataInterface(pl.LightningDataModule): train=True) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) - print(a) - print(b) - self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) + # print(a) + # print(b) + self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset + # self.train_dataset = self.instancialize(state='train') # self.val_dataset = self.instancialize(state='val') @@ -72,7 +75,7 @@ class DataInterface(pl.LightningDataModule): def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=True) + return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False) def val_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False) @@ -106,7 +109,7 @@ class DataInterface(pl.LightningDataModule): class MILDataModule(pl.LightningDataModule): - def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, 
**kwargs): + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): super().__init__() self.data_root = data_root self.label_path = label_path @@ -121,41 +124,74 @@ class MILDataModule(pl.LightningDataModule): self.num_bags_test = 50 self.seed = 1 + self.backbone = backbone self.cache = True + self.fe_transform = None def setup(self, stage: Optional[str] = None) -> None: - # if self.n_classes == 2: - # if stage in (None, 'fit'): - # dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes) - # a = int(len(dataset)* 0.8) - # b = int(len(dataset) - a) - # self.train_data, self.valid_data = random_split(dataset, [a, b]) - - # if stage in (None, 'test'): - # self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes) - # else: home = Path.cwd().parts[1] - # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json' - # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv' - # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv' - if stage in (None, 'fit'): - dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes) - # print(len(dataset)) + dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone) a = int(len(dataset)* 0.8) b = int(len(dataset) - a) self.train_data, self.valid_data = random_split(dataset, [a, b]) if stage in (None, 'test'): - self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes) + self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone) return super().setup(stage=stage) def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, + #sampler=ImbalancedDatasetSampler(self.train_data) + def val_dataloader(self) -> DataLoader: + return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers) + +class DataModule(pl.LightningDataModule): + + def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs): + super().__init__() + self.data_root = data_root + self.label_path = label_path + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = 384 + self.n_classes = n_classes + self.target_number = 9 + self.mean_bag_length = 10 + self.var_bag_length = 2 + self.num_bags_train = 200 + self.num_bags_test = 50 + self.seed = 1 + + self.backbone = backbone + self.cache = True + self.fe_transform = None + + + def setup(self, stage: Optional[str] = None) -> None: + home = Path.cwd().parts[1] + + if stage in (None, 'fit'): + dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone) + a = int(len(dataset)* 0.8) + b = int(len(dataset) - a) + self.train_data, self.valid_data = 
random_split(dataset, [a, b])
+
+        if stage in (None, 'test'):
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+
+        return super().setup(stage=stage)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False,) #batch_transforms=self.transform, pseudo_batch_dim=True,
+        #sampler=ImbalancedDatasetSampler(self.train_data),
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
diff --git a/fine_tune.py b/fine_tune.py
new file mode 100644
index 0000000..5898e75
--- /dev/null
+++ b/fine_tune.py
@@ -0,0 +1,42 @@
+from transformers import AutoFeatureExtractor, ViTModel, ViTForImageClassification
+from transformers import Trainer, TrainingArguments
+from torchvision import models
+import torch
+from datasets.custom_dataloader import DinoDataloader
+
+
+
+def fine_tune_transformer(args):
+
+    data_path = args.data_path
+    model = args.model
+    n_classes = args.n_classes
+
+    if model == 'dino':
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        # ViTModel has no classification head and silently ignores num_labels;
+        # fine-tuning on n_classes needs the classification variant
+        model_ft = ViTForImageClassification.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+    training_args = TrainingArguments(
+        output_dir = f'logs/fine_tune/{model}',
+        per_device_train_batch_size=16,
+        evaluation_strategy="steps",
+        num_train_epochs=4,
+        fp16=True,
+        save_steps=100,
+        eval_steps=100,
+        logging_steps=10,
+        learning_rate=2e-4,
+        save_total_limit=2,
+        remove_unused_columns=False,
+        push_to_hub=False,
+        report_to='tensorboard',
+        load_best_model_at_end=True,
+    )
+
+    dataset = DinoDataloader(args.data_path, mode='train') #, transforms=transform
+
+    trainer = Trainer(
+        model = model_ft,
+        args=training_args,
+        train_dataset=dataset,  # the dataset built above was never passed to the Trainer
+    )
+    trainer.train()
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index ce40a26..69089de 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -86,7 +86,7 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits = self._fc2(torch.sigmoid(h)) #[B, n_classes]
+        logits = self._fc2(h) #[B, n_classes]
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
         results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index 4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8..cb0eb6fd5085feda9604ec0f24420d85eb1d939d 100644
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ
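The TransMIL change above removes the torch.sigmoid applied to the class token before the final classifier. A minimal sketch of why that matters when training against CrossEntropyLoss, with illustrative tensors (dimensions mirror the TransMIL head): squashing features into (0, 1) before the linear layer compresses the achievable logit range, while the loss already applies log-softmax to raw logits internally.

import torch
import torch.nn as nn

B, D, n_classes = 1, 512, 2
h = torch.randn(B, D)               # class-token features, as after self.norm(h)[:, 0]
fc2 = nn.Linear(D, n_classes)       # stands in for self._fc2

logits_old = fc2(torch.sigmoid(h))  # inputs squashed into (0, 1): compressed logit range
logits_new = fc2(h)                 # raw features: unbounded logits

criterion = nn.CrossEntropyLoss()   # applies log-softmax itself, expects raw logits
label = torch.tensor([1])
print(criterion(logits_new, label), criterion(logits_old, label))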
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 3e81c0ccdd33e6fe68a5ffe1b602990134944c80..0bf337675c76de4097e7f7723b99de0120f9594c 100644
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/models/__pycache__/resnet50.cpython-39.pyc b/models/__pycache__/resnet50.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfb660f9829304cb853031dd3d3274396afc38be
Binary files /dev/null and b/models/__pycache__/resnet50.cpython-39.pyc differ
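The model_interface.py diff below swaps the hand-loaded DINO checkpoint for Hugging Face's pretrained 'facebook/dino-vitb16'. A sketch of the feature path the new 'dino' branch implies, with the tile count and sizes as placeholder values (the extractor resizes and normalizes tiles to the ViT's 224x224 input; last_hidden_state is what training_step consumes):

import numpy as np
import torch
from transformers import AutoFeatureExtractor, ViTModel

feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')

# a bag of 4 tiles as uint8 HxWxC arrays (sizes illustrative)
tiles = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) for _ in range(4)]
inputs = feature_extractor(images=tiles, return_tensors='pt')  # dict with 'pixel_values'
with torch.no_grad():
    out = model_ft(**inputs)
features = out.last_hidden_state  # [4, 197, 768]: CLS token + 14x14 patch tokens per tile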
diff --git a/models/model_interface.py b/models/model_interface.py
index 1b0f6e1..60b5cc7 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -12,18 +12,24 @@ from matplotlib import pyplot as plt
 from MyOptimizer import create_optimizer
 from MyLoss import create_loss
 from utils.utils import cross_entropy_torch
+from timm.loss import AsymmetricLossSingleLabel
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 
 #---->
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torch import optim as optim
 
 #---->
 import pytorch_lightning as pl
 from .vision_transformer import vit_small
 from torchvision import models
 from torchvision.models import resnet
+from transformers import AutoFeatureExtractor, ViTModel
+
+from captum.attr import LayerGradCam
 
 
 class ModelInterface(pl.LightningModule):
@@ -33,13 +39,17 @@
         self.save_hyperparameters()
         self.load_model()
         self.loss = create_loss(loss)
+        # self.asl = AsymmetricLossSingleLabel()
+        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+
+        # self.loss = 
         self.optimizer = optimizer
         self.n_classes = model.n_classes
         self.log_path = kargs['log']
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-
+        # print(self.experiment)
         #---->Metrics
         if self.n_classes > 2:
             self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
@@ -73,35 +83,12 @@
         #--->random
         self.shuffle = kargs['data'].data_shuffle
         self.count = 0
+        self.backbone = kargs['backbone']
 
         self.out_features = 512
         if kargs['backbone'] == 'dino':
-            #---> dino feature extractor
-            arch = 'vit_small'
-            patch_size = 16
-            n_last_blocks = 4
-            # num_labels = 1000
-            avgpool_patchtokens = False
-            home = Path.cwd().parts[1]
-
-            weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth'
-            model = vit_small(patch_size, num_classes=0)
-            # model.eval()
-            # set_parameter_requires_grad(model, feature_extracting)
-            for param in model.parameters():
-                param.requires_grad = False
-            # print(model.embed_dim)
-            # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))
-            # model.eval()
-            # print(embed_dim)
-            linear = nn.Linear(model.embed_dim, self.out_features)
-            linear.weight.data.normal_(mean=0.0, std=0.01)
-            linear.bias.data.zero_()
-
-            self.model_ft = nn.Sequential(
-                model,
-                linear,
-            )
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
         elif kargs['backbone'] == 'resnet18':
             resnet18 = models.resnet18(pretrained=True)
             modules = list(resnet18.children())[:-1]
@@ -109,7 +96,6 @@
 
             res18 = nn.Sequential(
                 *modules,
-
             )
             for param in res18.parameters():
                 param.requires_grad = False
@@ -118,7 +104,7 @@
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 512)),
                 nn.Linear(512, self.out_features),
-                nn.ReLU(),
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'resnet50':
@@ -135,7 +121,17 @@ class
ModelInterface(pl.LightningModule): nn.AdaptiveAvgPool2d(1), View((-1, 1024)), nn.Linear(1024, self.out_features), - nn.ReLU() + # nn.GELU() + ) + elif kargs['backbone'] == 'efficientnet': + efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True) + for param in efficientnet.parameters(): + param.requires_grad = False + # efn = list(efficientnet.children())[:-1] + efficientnet.classifier.fc = nn.Linear(1280, self.out_features) + self.model_ft = nn.Sequential( + efficientnet, + nn.GELU(), ) elif kargs['backbone'] == 'simple': #mil-ab attention feature_extracting = False @@ -151,21 +147,19 @@ class ModelInterface(pl.LightningModule): nn.ReLU(), ) - #---->remove v_num - # def get_progress_bar_dict(self): - # # don't show the version number - # items = super().get_progress_bar_dict() - # items.pop("v_num", None) - # return items - def training_step(self, batch, batch_idx): #---->inference + data, label, _ = batch label = label.float() - data = data.squeeze(0) + data = data.squeeze(0).float() + # print(data) # print(data.shape) - features = self.model_ft(data) - + if self.backbone == 'dino': + features = self.model_ft(**data) + features = features.last_hidden_state + else: + features = self.model_ft(data) features = features.unsqueeze(0) # print(features.shape) # features = features.squeeze() @@ -177,21 +171,28 @@ class ModelInterface(pl.LightningModule): #---->loss loss = self.loss(logits, label) + # loss = self.asl(logits, label.squeeze()) #---->acc log # print(label) - Y_hat = int(Y_hat) + # Y_hat = int(Y_hat) # if self.n_classes == 2: # Y = int(label[0][1]) # else: Y = torch.argmax(label) # Y = int(label[0]) self.data[Y]["count"] += 1 - self.data[Y]["correct"] += (Y_hat == Y) + self.data[Y]["correct"] += (int(Y_hat) == Y) - return {'loss': loss} + return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} def training_epoch_end(self, training_step_outputs): + # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0) + probs = torch.cat([x['Y_prob'] for x in training_step_outputs]) + max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs]) + # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0) + target = torch.cat([x['label'] for x in training_step_outputs]) + target = torch.argmax(target, dim=1) for c in range(self.n_classes): count = self.data[c]["count"] correct = self.data[c]["correct"] @@ -202,12 +203,19 @@ class ModelInterface(pl.LightningModule): print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] + # print('max_probs: ', max_probs) + # print('probs: ', probs) + if self.current_epoch % 10 == 0: + self.log_confusion_matrix(probs, target, stage='train') + + self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + def validation_step(self, batch, batch_idx): data, label, _ = batch label = label.float() - data = data.squeeze(0) + data = data.squeeze(0).float() features = self.model_ft(data) features = features.unsqueeze(0) @@ -224,20 +232,23 @@ class ModelInterface(pl.LightningModule): self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} def validation_epoch_end(self, val_step_outputs): logits = torch.cat([x['logits'] for x in val_step_outputs], 
dim = 0) - probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) + # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0) + probs = torch.cat([x['Y_prob'] for x in val_step_outputs]) max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs]) - target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) + # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0) + target = torch.cat([x['label'] for x in val_step_outputs]) + target = torch.argmax(target, dim=1) #----> # logits = logits.long() # target = target.squeeze().long() # logits = logits.squeeze(0) self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True) - self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) + self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True) # print(max_probs.squeeze(0).shape) # print(target.shape) @@ -245,12 +256,8 @@ class ModelInterface(pl.LightningModule): on_epoch = True, logger = True) #----> log confusion matrix - confmat = self.confusion_matrix(max_probs.squeeze(), target) - df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) - plt.figure() - fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() - plt.close(fig_) - self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + self.log_confusion_matrix(probs, target, stage='val') + #---->acc log for c in range(self.n_classes): @@ -267,18 +274,12 @@ class ModelInterface(pl.LightningModule): if self.shuffle == True: self.count = self.count+1 random.seed(self.count*50) - - - - def configure_optimizers(self): - optimizer = create_optimizer(self.optimizer, self.model) - return [optimizer] def test_step(self, batch, batch_idx): data, label, _ = batch label = label.float() - data = data.squeeze(0) + data = data.squeeze(0).float() features = self.model_ft(data) features = features.unsqueeze(0) @@ -292,12 +293,14 @@ class ModelInterface(pl.LightningModule): self.data[Y]["count"] += 1 self.data[Y]["correct"] += (Y_hat.item() == Y) - return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y} + return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label} def test_epoch_end(self, output_results): - probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0) + probs = torch.cat([x['Y_prob'] for x in output_results]) max_probs = torch.stack([x['Y_hat'] for x in output_results]) - target = torch.stack([x['label'] for x in output_results], dim = 0) + # target = torch.stack([x['label'] for x in output_results], dim = 0) + target = torch.cat([x['label'] for x in output_results]) + target = torch.argmax(target, dim=1) #----> auc = self.AUROC(probs, target.squeeze()) @@ -326,19 +329,16 @@ class ModelInterface(pl.LightningModule): print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count)) self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] - confmat = self.confusion_matrix(max_probs.squeeze(), target) - df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) - plt.figure() - fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() - # plt.close(fig_) - # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) - plt.savefig(f'{self.log_path}/cm_test') - plt.close(fig_) - + self.log_confusion_matrix(probs, target, stage='test') #----> result = 
pd.DataFrame([metrics]) result.to_csv(self.log_path / 'result.csv') + def configure_optimizers(self): + # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1) + optimizer = create_optimizer(self.optimizer, self.model) + return optimizer + def load_model(self): name = self.hparams.model.name @@ -350,6 +350,7 @@ class ModelInterface(pl.LightningModule): else: camel_name = name try: + Model = getattr(importlib.import_module( f'models.{name}'), camel_name) except: @@ -371,6 +372,20 @@ class ModelInterface(pl.LightningModule): args1.update(other_args) return Model(**args1) + def log_confusion_matrix(self, max_probs, target, stage): + confmat = self.confusion_matrix(max_probs.squeeze(), target) + df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes)) + plt.figure() + fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure() + # plt.close(fig_) + # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}') + self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch) + + if stage == 'test': + plt.savefig(f'{self.log_path}/cm_test') + plt.close(fig_) + # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch) + class View(nn.Module): def __init__(self, shape): super().__init__() @@ -383,4 +398,5 @@ class View(nn.Module): # batch_size = input.size(0) # shape = (batch_size, *self.shape) out = input.view(*self.shape) - return out \ No newline at end of file + return out + diff --git a/models/resnet50.py b/models/resnet50.py new file mode 100644 index 0000000..89e23d7 --- /dev/null +++ b/models/resnet50.py @@ -0,0 +1,293 @@ +import torch +import torch.nn as nn +from .utils import load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not 
supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = 
self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return 
diff --git a/train.py b/train.py
index 182e3c0..036d5ed 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,8 @@ from pathlib import Path
 import numpy as np
 import glob
+from sklearn.model_selection import KFold
+
 from datasets.data_interface import DataInterface, MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
@@ -11,25 +13,32 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
+from utils.utils import KFoldLoop  # defined in utils/utils.py; pytorch_lightning.loops has no KFoldLoop
+import torch
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml', type=str)
-    # parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
     args = parser.parse_args()
     return args
 
 #---->main
 def main(cfg):
 
+    torch.set_num_threads(16)
+
     #---->Initialize seed
     pl.seed_everything(cfg.General.seed)
 
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
+    # print(cfg.load_loggers)
 
     #---->load callbacks
     cfg.callbacks = load_callbacks(cfg)
@@ -49,7 +58,10 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
                 }
+
     dm = MILDataModule(**DataInterface_dict)
 
@@ -70,9 +82,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [4],
-        # strategy='ddp',
+        # gpus=cfg.General.gpus,
+        gpus = [2,3],
+        strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,
         precision=cfg.General.precision,
@@ -85,6 +97,10 @@ def main(cfg):
 
     #---->train or test
     if cfg.General.server == 'train':
+        #---->wrap the default fit loop for k-fold CV; KFoldLoop takes (num_folds, export_path)
+        internal_fit_loop = trainer.fit_loop
+        trainer.fit_loop = KFoldLoop(3, export_path=str(cfg.log_path))
+        trainer.fit_loop.connect(internal_fit_loop)
         trainer.fit(model = model, datamodule = dm)
     else:
         model_paths = list(cfg.log_path.glob('*.ckpt'))
@@ -94,6 +110,7 @@ def main(cfg):
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
             trainer.test(model=new_model, datamodule=dm)
 
+
 if __name__ == '__main__':
 
     args = make_parse()
@@ -101,9 +118,12 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    # cfg.General.gpus = args.gpus
+    cfg.General.gpus = [args.gpus]
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+
     #---->main
     main(cfg)
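With the new flags wired into the config, a typical invocation of train.py looks like the following (illustrative only; note that --gpus takes a single device index, which the script wraps as cfg.General.gpus = [args.gpus]):

    python train.py --stage train \
        --config DeepGraft/TransMIL_resnet18_tcmr_viral.yaml \
        --gpus 2 --loss CrossEntropyLoss --fold 0 --bag_size 1024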
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
index e1eeb958497b23b1f9e9934764083531720a66d0..704aade8ba3e8bcbedf237b3fde5db80b84fcd32 100644
GIT binary patch
(binary delta omitted: compiled bytecode)
diff --git a/utils/extract_features.py b/utils/extract_features.py
new file mode 100644
index 0000000..fb040a1
--- /dev/null
+++ b/utils/extract_features.py
@@ -0,0 +1,26 @@
+## Choose Model and extract features from (augmented) image patches and save as .pt file
+
+from datasets.custom_dataloader import HDF5MILDataloader
+from models.resnet50 import resnet50
+
+
+def extract_features(input_dir, output_dir, model_name, batch_size, label_path, n_classes, device='cuda'):
+
+    dataset = HDF5MILDataloader(input_dir, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    if model_name == 'resnet50':
+        # ImageNet ResNet-50 from models/resnet50.py; strip the fc head downstream for bag features
+        model = resnet50(pretrained=True)
+    model = model.to(device)
+    model.eval()
+
+
+if __name__ == '__main__':
+
+    # input_dir, output_dir
+    # initiate data loader
+    # use data loader to load and augment images
+    # prediction from model
+    # choose save as bag or not (needed?)
+
+    # features = torch.from_numpy(features)
+    # torch.save(features, output_path + '.pt')
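The comments under __main__ outline the still-missing extraction loop: load each bag, embed its patches, and save one tensor per bag. A minimal sketch of that loop (the (bag, label, name) sample layout and the removal of the fc head are assumptions, not part of this patch):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader

    def save_bag_features(dataset, model, device, output_dir):
        # drop the classifier so the backbone emits pooled 2048-d features
        extractor = nn.Sequential(*list(model.children())[:-1]).to(device).eval()
        loader = DataLoader(dataset, batch_size=1, num_workers=8)
        with torch.no_grad():
            for bag, label, name in loader:             # assumed sample layout
                patches = bag.squeeze(0).to(device)     # [n_patches, 3, H, W]
                feats = extractor(patches).flatten(1)   # [n_patches, 2048]
                torch.save(feats.cpu(), f'{output_dir}/{name[0]}.pt')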
diff --git a/utils/utils.py b/utils/utils.py
index 96ed223..010eaab 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,4 +1,22 @@
 from pathlib import Path
+from abc import ABC, abstractclassmethod
+from copy import deepcopy
+from typing import Any, Dict
+import os.path as osp
+import torch
+import torchvision.transforms as T
+from sklearn.model_selection import KFold
+from torch.nn import functional as F
+from torch.utils.data import random_split
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, Subset
+from torchmetrics.classification.accuracy import Accuracy
+
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
 
 #---->read yaml
 import yaml
@@ -14,19 +32,32 @@ def load_loggers(cfg):
 
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     version_name = Path(cfg.config).name[:-5]
-    cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-    print(f'---->Log dir: {cfg.log_path}')
+
     #---->TensorBoard
-    tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                             name = version_name, version = f'fold{cfg.Data.fold}',
-                                             log_graph = True, default_hp_metric = False)
-    #---->CSV
-    csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                      name = version_name, version = f'fold{cfg.Data.fold}', )
+    if cfg.stage != 'test':
+        cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                 name = version_name, version = f'fold{cfg.Data.fold}',
+                                                 log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                          name = version_name, version = f'fold{cfg.Data.fold}', )
+    else:
+        cfg.log_path = Path(log_path) / log_name / version_name / 'test'
+        tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                                 name = version_name, version = 'test',
+                                                 log_graph = True, default_hp_metric = False)
+        #---->CSV
+        csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                          name = version_name, version = 'test', )
+
+    print(f'---->Log dir: {cfg.log_path}')
+
+    # return tb_logger
     return [tb_logger, csv_logger]
@@ -74,6 +105,14 @@ def load_callbacks(cfg):
                                          save_top_k = 1,
                                          mode = 'min',
                                          save_weights_only = True))
+        Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
+                                         dirpath = str(cfg.log_path),
+                                         filename = '{epoch:02d}-{val_auc:.4f}',
+                                         verbose = True,
+                                         save_last = True,
+                                         save_top_k = 1,
+                                         mode = 'max',
+                                         save_weights_only = True))
     return Mycallbacks
 
 #---->val loss
@@ -84,3 +123,102 @@ def cross_entropy_torch(x, y):
     x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
     loss = - torch.sum(x_log) / y.shape[0]
     return loss
+
+#-----> convert labels for task
+label_map = {
+    'bin': {'0': 0, '1': 1, '2': 1, '3': 1, '4': 1, '5': None},
+    'tcmr_viral': {'0': None, '1': 0, '2': None, '3': None, '4': 1, '5': None},
+    'no_viral': {'0': 0, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'no_other': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': None},
+    'no_stable': {'0': None, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'all': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5},
+}
+
+def convert_labels_for_task(task, label):
+    return label_map[task][label]
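A few concrete mappings implied by label_map (None means the slide is excluded from that task):

    assert convert_labels_for_task('tcmr_viral', '1') == 0     # kept as class 0
    assert convert_labels_for_task('tcmr_viral', '4') == 1     # kept as class 1
    assert convert_labels_for_task('tcmr_viral', '0') is None  # excluded from this task
    assert convert_labels_for_task('all', '5') == 5            # identity mapping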
+
+
+#-----> KFOLD LOOP
+# NOTE: expects BaseKFoldDataModule and EnsembleVotingModel to be importable here,
+# as in the PyTorch Lightning loop_examples/kfold.py this class follows.
+
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to run fitting and testing on the current fold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
+        voting_model.trainer = self.trainer
+        # This requires to connect the new model and move it to the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
--
GitLab
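KFoldLoop only runs against a datamodule that implements the BaseKFoldDataModule interface (setup_folds / setup_fold_index) and expects an EnsembleVotingModel for the final ensemble test; neither is defined in utils/utils.py itself. A minimal datamodule sketch in the spirit of the PyTorch Lightning k-fold loop example this class follows (the train_dataset attribute name is an assumption):

    from abc import ABC, abstractmethod
    from sklearn.model_selection import KFold
    from torch.utils.data import DataLoader, Subset
    from pytorch_lightning import LightningDataModule

    class BaseKFoldDataModule(LightningDataModule, ABC):
        @abstractmethod
        def setup_folds(self, num_folds: int) -> None: ...

        @abstractmethod
        def setup_fold_index(self, fold_index: int) -> None: ...

    class KFoldMILDataModule(BaseKFoldDataModule):
        def setup_folds(self, num_folds: int) -> None:
            self.splits = list(KFold(num_folds).split(range(len(self.train_dataset))))

        def setup_fold_index(self, fold_index: int) -> None:
            train_idx, val_idx = self.splits[fold_index]
            self.train_fold = Subset(self.train_dataset, train_idx)
            self.val_fold = Subset(self.train_dataset, val_idx)

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_fold, batch_size=1, num_workers=8)

        def val_dataloader(self) -> DataLoader:
            return DataLoader(self.val_fold, batch_size=1, num_workers=8)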