From 307668c0f666b849cd638dba0e81b4d11408f562 Mon Sep 17 00:00:00 2001
From: Ycblue <yuchialan@gmail.com>
Date: Mon, 30 May 2022 15:27:32 +0200
Subject: [PATCH] Add resnet50/efficientnet configs, backbone factory, PolyLoss, and dataloader rework

---
 .../{TransMIL_simple.yaml => Resnet50.yaml}   |   5 +-
 DeepGraft/TransMIL_debug.yaml                 |  50 ++
 DeepGraft/TransMIL_dino.yaml                  |   6 +-
 DeepGraft/TransMIL_efficientnet_no_other.yaml |  48 ++
 DeepGraft/TransMIL_efficientnet_no_viral.yaml |  48 ++
 .../TransMIL_efficientnet_tcmr_viral.yaml     |  48 ++
 DeepGraft/TransMIL_resnet18_all.yaml          |   2 +-
 DeepGraft/TransMIL_resnet18_no_other.yaml     |   2 +-
 DeepGraft/TransMIL_resnet18_no_viral.yaml     |   2 +-
 DeepGraft/TransMIL_resnet18_tcmr_viral.yaml   |   2 +-
 DeepGraft/TransMIL_resnet50_all.yaml          |   2 +-
 DeepGraft/TransMIL_resnet50_no_other.yaml     |  48 ++
 DeepGraft/TransMIL_resnet50_no_viral.yaml     |  48 ++
 DeepGraft/TransMIL_resnet50_tcmr_viral.yaml   |  48 ++
 MyBackbone/__init__.py                        |   2 +
 MyBackbone/backbone_factory.py                |  80 +++
 .../__pycache__/loss_factory.cpython-39.pyc   | Bin 2411 -> 2545 bytes
 MyLoss/__pycache__/poly_loss.cpython-39.pyc   | Bin 0 -> 2726 bytes
 MyLoss/loss_factory.py                        |   5 +-
 MyLoss/poly_loss.py                           |  84 +++
 README.md                                     |   4 +
 .../custom_dataloader.cpython-39.pyc          | Bin 7729 -> 15301 bytes
 .../__pycache__/data_interface.cpython-39.pyc | Bin 5690 -> 7011 bytes
 datasets/custom_dataloader.py                 | 506 +++++++++++++-----
 datasets/data_interface.py                    |  82 ++-
 fine_tune.py                                  |  42 ++
 models/TransMIL.py                            |   2 +-
 models/__pycache__/TransMIL.cpython-39.pyc    | Bin 3333 -> 3328 bytes
 .../model_interface.cpython-39.pyc            | Bin 10240 -> 11282 bytes
 models/__pycache__/resnet50.cpython-39.pyc    | Bin 0 -> 8628 bytes
 models/model_interface.py                     | 164 +++---
 models/resnet50.py                            | 293 ++++++++++
 train.py                                      |  27 +-
 utils/__pycache__/utils.cpython-39.pyc        | Bin 2832 -> 3450 bytes
 utils/extract_features.py                     |  27 +
 utils/utils.py                                | 154 +++++-
 36 files changed, 1583 insertions(+), 248 deletions(-)
 rename DeepGraft/{TransMIL_simple.yaml => Resnet50.yaml} (93%)
 create mode 100644 DeepGraft/TransMIL_debug.yaml
 create mode 100644 DeepGraft/TransMIL_efficientnet_no_other.yaml
 create mode 100644 DeepGraft/TransMIL_efficientnet_no_viral.yaml
 create mode 100644 DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
 create mode 100644 DeepGraft/TransMIL_resnet50_no_other.yaml
 create mode 100644 DeepGraft/TransMIL_resnet50_no_viral.yaml
 create mode 100644 DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
 create mode 100644 MyBackbone/__init__.py
 create mode 100644 MyBackbone/backbone_factory.py
 create mode 100644 MyLoss/__pycache__/poly_loss.cpython-39.pyc
 create mode 100644 MyLoss/poly_loss.py
 create mode 100644 fine_tune.py
 create mode 100644 models/__pycache__/resnet50.cpython-39.pyc
 create mode 100644 models/resnet50.py
 create mode 100644 utils/extract_features.py

diff --git a/DeepGraft/TransMIL_simple.yaml b/DeepGraft/Resnet50.yaml
similarity index 93%
rename from DeepGraft/TransMIL_simple.yaml
rename to DeepGraft/Resnet50.yaml
index 4501f2c..e6b780b 100644
--- a/DeepGraft/TransMIL_simple.yaml
+++ b/DeepGraft/Resnet50.yaml
@@ -5,7 +5,7 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [1]
+    gpus: [0]
     epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
@@ -32,9 +32,8 @@ Data:
         
 
 Model:
-    name: TransMIL
+    name: resnet50
     n_classes: 2
-    backbone: simple
 
 
 Optimizer:
diff --git a/DeepGraft/TransMIL_debug.yaml b/DeepGraft/TransMIL_debug.yaml
new file mode 100644
index 0000000..d83ce0d
--- /dev/null
+++ b/DeepGraft/TransMIL_debug.yaml
@@ -0,0 +1,50 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [1]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_debug.json'
+    fold: 0
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+        
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet18
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_dino.yaml b/DeepGraft/TransMIL_dino.yaml
index ffe987c..b7161ba 100644
--- a/DeepGraft/TransMIL_dino.yaml
+++ b/DeepGraft/TransMIL_dino.yaml
@@ -6,10 +6,10 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 1000 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
@@ -17,7 +17,7 @@ Data:
     dataset_name: custom
     data_shuffle: False
     data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_bin.json'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 0
     nfold: 4
 
diff --git a/DeepGraft/TransMIL_efficientnet_no_other.yaml b/DeepGraft/TransMIL_efficientnet_no_other.yaml
new file mode 100644
index 0000000..79d8ea8
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_no_other.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 1000 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 5
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0001
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
new file mode 100644
index 0000000..8780060
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
new file mode 100644
index 0000000..f69b5bf
--- /dev/null
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 200
+    server: train #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: efficientnet
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet18_all.yaml b/DeepGraft/TransMIL_resnet18_all.yaml
index 8fa5818..f331e4e 100644
--- a/DeepGraft/TransMIL_resnet18_all.yaml
+++ b/DeepGraft/TransMIL_resnet18_all.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_no_other.yaml b/DeepGraft/TransMIL_resnet18_no_other.yaml
index 95a9bd6..c7a27f2 100644
--- a/DeepGraft/TransMIL_resnet18_no_other.yaml
+++ b/DeepGraft/TransMIL_resnet18_no_other.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_no_viral.yaml b/DeepGraft/TransMIL_resnet18_no_viral.yaml
index 155b676..93054bf 100644
--- a/DeepGraft/TransMIL_resnet18_no_viral.yaml
+++ b/DeepGraft/TransMIL_resnet18_no_viral.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
index e7d9bf0..c26e1e9 100644
--- a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: test #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet50_all.yaml b/DeepGraft/TransMIL_resnet50_all.yaml
index eba3a4f..e6959ea 100644
--- a/DeepGraft/TransMIL_resnet50_all.yaml
+++ b/DeepGraft/TransMIL_resnet50_all.yaml
@@ -9,7 +9,7 @@ General:
     epochs: &epoch 1000 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 50
     server: train #train #test
     log_path: logs/
 
diff --git a/DeepGraft/TransMIL_resnet50_no_other.yaml b/DeepGraft/TransMIL_resnet50_no_other.yaml
new file mode 100644
index 0000000..d3cd2aa
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_no_other.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 5
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_no_viral.yaml b/DeepGraft/TransMIL_resnet50_no_viral.yaml
new file mode 100644
index 0000000..2e3394a
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_no_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 4
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000..a756616
--- /dev/null
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -0,0 +1,48 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [3]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 4
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransMIL
+    n_classes: 2
+    backbone: resnet50
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.00001
+
+Loss:
+    base_loss: CrossEntropyLoss
+
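
These configs feed the existing entry point shown in the README; assuming that CLI is unchanged, one of the new files would be launched as:

```bash
python train.py --stage='train' --config='DeepGraft/TransMIL_resnet50_tcmr_viral.yaml' --gpus=0 --fold=1
```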
diff --git a/MyBackbone/__init__.py b/MyBackbone/__init__.py
new file mode 100644
index 0000000..1fca298
--- /dev/null
+++ b/MyBackbone/__init__.py
@@ -0,0 +1,2 @@
+
+from .backbone_factory import init_backbone
\ No newline at end of file
diff --git a/MyBackbone/backbone_factory.py b/MyBackbone/backbone_factory.py
new file mode 100644
index 0000000..ff770e5
--- /dev/null
+++ b/MyBackbone/backbone_factory.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoFeatureExtractor, ViTModel
+from torchvision import models
+
+class View(nn.Module):
+    """Reshape activations inside an nn.Sequential."""
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, input):
+        return input.view(*self.shape)
+
+def init_backbone(**kargs):
+
+    backbone = kargs['backbone']
+    n_classes = kargs['n_classes']
+    out_features = kargs['out_features']
+
+    if backbone == 'dino':
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        vit = ViTModel.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+        def model_ft(input):
+            input = feature_extractor(input, return_tensors='pt')
+            return vit(**input)
+
+    elif backbone == 'resnet18':
+        resnet18 = models.resnet18(pretrained=True)
+        modules = list(resnet18.children())[:-1]  # drop the fc head
+        res18 = nn.Sequential(*modules)
+        for param in res18.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res18,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 512)),
+            nn.Linear(512, out_features),
+            nn.GELU(),
+        )
+    elif backbone == 'resnet50':
+        resnet50 = models.resnet50(pretrained=True)
+        modules = list(resnet50.children())[:-3]  # keep through layer3 (1024 channels)
+        res50 = nn.Sequential(*modules)
+        for param in res50.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            res50,
+            nn.AdaptiveAvgPool2d(1),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.GELU(),
+        )
+    elif backbone == 'efficientnet':
+        efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+        for param in efficientnet.parameters():
+            param.requires_grad = False
+        model_ft = nn.Sequential(
+            efficientnet,
+            nn.Linear(1000, 512),
+            nn.GELU(),
+        )
+    elif backbone == 'simple': #mil-ab attention
+        model_ft = nn.Sequential(
+            nn.Conv2d(3, 20, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Conv2d(20, 50, kernel_size=5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            View((-1, 1024)),
+            nn.Linear(1024, out_features),
+            nn.ReLU(),
+        )
+    else:
+        raise ValueError(f'Unknown backbone: {backbone}')
+
+    return model_ft
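
For orientation, a minimal sketch of how this factory would be called (the `out_features` value here is illustrative; nothing in this patch pins it down):

```python
import torch
from MyBackbone import init_backbone

# hypothetical 512-d tile features from the frozen ResNet-50 trunk
model_ft = init_backbone(backbone='resnet50', n_classes=2, out_features=512)
tiles = torch.randn(8, 3, 256, 256)  # one bag of 8 RGB tiles
features = model_ft(tiles)           # -> torch.Size([8, 512])
```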
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
index ed7437016bbcadc48f9bb0a97401529dc574c9b2..14452dde7b34fd11b5255f4ed07bdeb48a27e49f 100644
GIT binary patch
[base85-encoded binary payload omitted]

diff --git a/MyLoss/__pycache__/poly_loss.cpython-39.pyc b/MyLoss/__pycache__/poly_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08edf030976f28689763861889e7ace3d3747a59
GIT binary patch
[base85-encoded binary payload omitted]

diff --git a/MyLoss/loss_factory.py b/MyLoss/loss_factory.py
index 1dffa61..f3bdceb 100755
--- a/MyLoss/loss_factory.py
+++ b/MyLoss/loss_factory.py
@@ -13,6 +13,7 @@ from .hausdorff import HausdorffDTLoss, HausdorffERLoss
 from .lovasz_loss import LovaszSoftmax
 from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
      WeightedCrossEntropyLossV2, DisPenalizedCE
+from .poly_loss import PolyLoss
 
 from pytorch_toolbelt import losses as L
 
@@ -22,7 +23,7 @@ def create_loss(args, w1=1.0, w2=0.5):
     # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE 
     loss = None
     if hasattr(nn, conf_loss): 
-        loss = getattr(nn, conf_loss)() 
+        loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
     #binary loss
     elif conf_loss == "focal":
         loss = L.BinaryFocalLoss()
@@ -46,6 +47,8 @@ def create_loss(args, w1=1.0, w2=0.5):
         loss = L.JointLoss(BCEWithLogitsLoss(), L.BinaryDiceLogLoss(), w1, w2)
     elif conf_loss == "reduced_focal":
         loss = L.BinaryFocalLoss(reduced=True)
+    elif conf_loss == "polyloss":
+        loss = PolyLoss(softmax=False)
     else:
         assert False and "Invalid loss"
         raise ValueError
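
For context, a minimal sketch of reaching the new branch (this assumes `create_loss` reads the loss name from `args.base_loss`, as the commented-out line above suggests; `Namespace` stands in for the real config object):

```python
from argparse import Namespace
from MyLoss.loss_factory import create_loss

# 'polyloss' selects PolyLoss(softmax=False); any nn.* loss name
# (e.g. 'CrossEntropyLoss') is resolved through getattr(nn, ...) instead
loss_fn = create_loss(Namespace(base_loss='polyloss'))
```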
diff --git a/MyLoss/poly_loss.py b/MyLoss/poly_loss.py
new file mode 100644
index 0000000..e370545
--- /dev/null
+++ b/MyLoss/poly_loss.py
@@ -0,0 +1,84 @@
+# From https://github.com/yiyixuxu/polyloss-pytorch
+
+import warnings
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+
+def to_one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
+    # if `dim` is bigger, add singleton dim at the end
+    if labels.ndim < dim + 1:
+        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
+        labels = torch.reshape(labels, shape)
+
+    sh = list(labels.shape)
+
+    if sh[dim] != 1:
+        raise AssertionError("labels should have a channel with length equal to one.")
+
+    sh[dim] = num_classes
+
+    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
+    labels = o.scatter_(dim=dim, index=labels.long(), value=1)
+
+    return labels
+
+
+class PolyLoss(_Loss):
+    def __init__(self,
+                 softmax: bool = False,
+                 ce_weight: Optional[torch.Tensor] = None,
+                 reduction: str = 'mean',
+                 epsilon: float = 1.0,
+                 ) -> None:
+        super().__init__()
+        self.softmax = softmax
+        self.reduction = reduction
+        self.epsilon = epsilon
+        self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction='none')
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input: the shape should be BNH[WD], where N is the number of classes.
+                You can pass logits or probabilities as input; if you pass logits, set softmax=True.
+            target: the shape should be BNH[WD] (one-hot format) or B1H[WD], where N is the number of classes.
+                It should contain binary values
+        Raises:
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+       """
+        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
+        # target not in one-hot encode format, has shape B1H[WD]
+        if n_pred_ch != n_target_ch:
+            # squeeze out the channel dimension of size 1 to calculate ce loss
+            self.ce_loss = self.cross_entropy(input, torch.squeeze(target, dim=1).long())
+            # convert target to one-hot format for the pt term below
+            target = to_one_hot(target, num_classes=n_pred_ch)
+        else:
+            # target is already one-hot; convert to BH[WD] class indices to calculate ce loss
+            self.ce_loss = self.cross_entropy(input, torch.argmax(target, dim=1))
+
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, 1)
+
+        pt = (input * target).sum(dim=1)  # BH[WD]
+        poly_loss = self.ce_loss + self.epsilon * (1 - pt)
+
+        if self.reduction == 'mean':
+            polyl = torch.mean(poly_loss)  # the batch and channel average
+        elif self.reduction == 'sum':
+            polyl = torch.sum(poly_loss)  # sum over the batch and channel dims
+        elif self.reduction == 'none':
+            # If we are not computing voxelwise loss components at least
+            # make sure a none reduction maintains a broadcastable shape
+            # BH[WD] -> BH1[WD]
+            polyl = poly_loss.unsqueeze(1)
+        else:
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        return polyl
\ No newline at end of file
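
A quick self-contained check of `PolyLoss` under the shapes its docstring describes (logits of shape BN, integer class indices of shape B1; the values are arbitrary):

```python
import torch
from MyLoss.poly_loss import PolyLoss

logits = torch.randn(4, 3)            # B=4 samples, N=3 classes
target = torch.randint(0, 3, (4, 1))  # class indices, shape B1
loss_fn = PolyLoss(softmax=True)      # pass logits; softmax applied internally
print(loss_fn(logits, target))        # scalar, since reduction='mean'
```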
diff --git a/README.md b/README.md
index 04c7dba..effc0fe 100644
--- a/README.md
+++ b/README.md
@@ -9,3 +9,7 @@ python train.py --stage='train' --config='Camelyon/TransMIL.yaml'  --gpus=0 --fo
 ```python
 python train.py --stage='test' --config='Camelyon/TransMIL.yaml'  --gpus=0 --fold=0
 ```
+
+### Changes Made:
+
+- Added a `MyBackbone` factory (resnet18/50, efficientnet, dino, simple), a PolyLoss option in the loss factory, fixed-size bags plus imgaug-based augmentation in the custom dataloader, and new `fine_tune.py` / `utils/extract_features.py` scripts.
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_dataloader.cpython-39.pyc
index 147b3fc3c4628d6eed5bebd6ff248d31aaeebb34..4a200bbb7d34328019d9bb9e604b0524127ae863 100644
GIT binary patch
[base85-encoded binary payload omitted]

diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/datasets/__pycache__/data_interface.cpython-39.pyc
index 4af0141496c11d6f97255add8746fa0067e55636..9550db1509e8d477a1c15534a76c6f87976fbec4 100644
GIT binary patch
[base85-encoded binary payload omitted]

diff --git a/datasets/custom_dataloader.py b/datasets/custom_dataloader.py
index ddc2ed3..02850f5 100644
--- a/datasets/custom_dataloader.py
+++ b/datasets/custom_dataloader.py
@@ -9,12 +9,25 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 # from histoTransforms import RandomHueSaturationValue
 import torchvision.transforms as transforms
+import torchvision
 import torch.nn.functional as F
 import csv
 from PIL import Image
 import cv2
 import pandas as pd
 import json
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from transformers import AutoFeatureExtractor
+from imgaug import augmenters as iaa
+import imgaug as ia
+from torchsampler import ImbalancedDatasetSampler
+
+
+class RangeNormalization(object):
+    def __call__(self, sample):
+        img = sample
+        return (img / 255.0 - 0.5) / 0.5
 
 class HDF5MILDataloader(data.Dataset):
     """Represents an abstract HDF5 dataset. For single H5 container! 
@@ -28,68 +41,89 @@ class HDF5MILDataloader(data.Dataset):
         data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
 
     """
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=20):
+    def __init__(self, file_path, label_path, mode, n_classes, backbone=None, load_data=False, data_cache_size=20, bag_size=1024):
         super().__init__()
 
         self.data_info = []
         self.data_cache = {}
         self.slideLabelDict = {}
+        self.files = []
         self.data_cache_size = data_cache_size
         self.mode = mode
         self.file_path = file_path
         # self.csv_path = csv_path
         self.label_path = label_path
         self.n_classes = n_classes
-        self.bag_size = 120
+        self.bag_size = bag_size
+        self.backbone = backbone
         # self.label_file = label_path
         recursive = True
+        
 
         # read labels and slide_path from csv
-
-        # df = pd.read_csv(self.csv_path)
-        # labels = df.LABEL
-        # slides = df.FILENAME
         with open(self.label_path, 'r') as f:
-            self.slideLabelDict = json.load(f)[mode]
-
-        self.slideLabelDict = {Path(x).stem : y for (x,y) in self.slideLabelDict}
-
-            
-        # if Path(slides[0]).suffix:
-        #     slides = list(map(lambda x: Path(x).stem, slides))
-
-        # print(labels)
-        # print(slides)
-        # self.slideLabelDict = dict(zip(slides, labels))
-        # print(self.slideLabelDict)
-
-        #check if files in slideLabelDict, only take files that are available.
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+                x_complete_path = Path(self.file_path)/Path(x + '.hdf5')
+                if x_complete_path.is_file():
+                    self.slideLabelDict[x] = y
+                    self.files.append(x_complete_path)
 
-        files_in_path = list(Path(self.file_path).rglob('*.hdf5'))
-        files_in_path = [x.stem for x in files_in_path]
-        # print(len(files_in_path))
-        # print(files_in_path)
-        # print(list(self.slideLabelDict.keys()))
-        # for x in list(self.slideLabelDict.keys()):
-        #     if x in files_in_path:
-        #         path = Path(self.file_path) / (x + '.hdf5')
-        #         print(path)
-
-        self.files = [Path(self.file_path)/ (x + '.hdf5') for x in list(self.slideLabelDict.keys()) if x in files_in_path]
-
-        print(len(self.files))
-        # self.files = list(map(lambda x: Path(self.file_path) / (Path(x).stem + '.hdf5'), list(self.slideLabelDict.keys())))
 
         for h5dataset_fp in tqdm(self.files):
-            # print(h5dataset_fp)
             self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
 
-        # print(self.data_info)
-        self.resize_transforms = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(256),
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
         ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+
+        # self.train_transforms = A.Compose([
+        #     A.HueSaturationValue(hue_shift_limit=13, sat_shift_limit=2, val_shift_limit=0, always_apply=True, p=1.0),
+        #     # A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=0, val_shift_limit=0, always_apply=False, p=0.5),
+        #     # A.RandomGamma(),
+        #     # A.HorizontalFlip(),
+        #     # A.VerticalFlip(),
+        #     # A.RandomRotate90(),
+        #     # A.OneOf([
+        #     #     A.ElasticTransform(alpha=150, sigma=20, alpha_affine=50),
+        #     #     A.Affine(
+        #     #         scale={'x': (0.95, 1.05), 'y': (0.95, 1.05)},
+        #     #         rotate=(-45, 45),
+        #     #         shear=(-4, 4),
+        #     #         cval=8,
+        #     #         )
+        #     # ]),
+        #     A.Normalize(),
+        #     ToTensorV2(),
+        # ])
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
 
+        ])
         self.img_transforms = transforms.Compose([    
             transforms.RandomHorizontalFlip(p=1),
             transforms.RandomVerticalFlip(p=1),
@@ -100,32 +134,45 @@ class HDF5MILDataloader(data.Dataset):
             RandomHueSaturationValue(hue_shift_limit=(-13,13), sat_shift_limit=(-13,13), val_shift_limit=(-13,13)),
             transforms.ToTensor()
         ])
+        if self.backbone == 'dino':
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
 
         # self._add_data_infos(load_data)
 
-
     def __getitem__(self, index):
         # get data
         batch, label, name = self.get_data(index)
         out_batch = []
+        seq_img_d = self.train_transforms.to_deterministic()
         
         if self.mode == 'train':
             # print(img)
             # print(img.shape)
-            for img in batch: 
-                img = self.img_transforms(img)
-                img = self.hsv_transforms(img)
+            for img in batch: # tiles arrive as torch tensors
+                img = img.numpy().astype(np.uint8)
+                img = seq_img_d.augment_image(img)
+                if self.backbone == 'dino':
+                    # the ViT feature extractor resizes/normalizes and tensorizes itself
+                    img = self.feature_extractor(images=img, return_tensors='pt')['pixel_values'].squeeze(0)
+                else:
+                    img = self.val_transforms(img)
                 out_batch.append(img)
 
         else:
             for img in batch:
-                img = transforms.functional.to_tensor(img)
+                img = img.numpy().astype(np.uint8)
+                if self.backbone == 'dino':
+                    img = self.feature_extractor(images=img, return_tensors='pt')['pixel_values'].squeeze(0)
+                else:
+                    img = self.val_transforms(img)
                 out_batch.append(img)
+
         if len(out_batch) == 0:
             # print(name)
-            out_batch = torch.randn(100,3,256,256)
+            out_batch = torch.randn(self.bag_size,3,256,256)
         else: out_batch = torch.stack(out_batch)
-        out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
+        # print(out_batch.shape)
+        # out_batch = out_batch[torch.randperm(out_batch.shape[0])] #shuffle tiles within batch
 
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
@@ -138,29 +185,7 @@ class HDF5MILDataloader(data.Dataset):
         wsi_name = Path(file_path).stem
         if wsi_name in self.slideLabelDict:
             label = self.slideLabelDict[wsi_name]
-            wsi_batch = []
-            # with h5py.File(file_path, 'r') as h5_file:
-            #     numKeys = len(h5_file.keys())
-            #     sample = list(h5_file.keys())[0]
-            #     shape = (numKeys,) + h5_file[sample][:].shape
-                # for tile in h5_file.keys():
-                #     img = h5_file[tile][:]
-                    
-                    # print(img)
-                    # if type == 'images':
-                    #     t = 'data'
-                    # else: 
-                    #     t = 'label'
             idx = -1
-            # if load_data: 
-            #     for tile in h5_file.keys():
-            #         img = h5_file[tile][:]
-            #         img = img.astype(np.uint8)
-            #         img = self.resize_transforms(img)
-            #         wsi_batch.append(img)
-            #     idx = self._add_to_cache(wsi_batch, file_path)
-                #     wsi_batch.append(img)
-            # self.data_info.append({'data_path': file_path, 'label': label, 'shape': shape, 'name': wsi_name, 'cache_idx': idx})
             self.data_info.append({'data_path': file_path, 'label': label, 'name': wsi_name, 'cache_idx': idx})
 
     def _load_data(self, file_path):
@@ -173,30 +198,15 @@ class HDF5MILDataloader(data.Dataset):
             for tile in h5_file.keys():
                 img = h5_file[tile][:]
                 img = img.astype(np.uint8)
-                img = self.resize_transforms(img)
+                img = torch.from_numpy(img)
+                # img = self.resize_transforms(img)
                 wsi_batch.append(img)
+            wsi_batch = torch.stack(wsi_batch)
+            wsi_batch, _ = to_fixed_size_bag(wsi_batch, self.bag_size)
             idx = self._add_to_cache(wsi_batch, file_path)
             file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
             self.data_info[file_idx + idx]['cache_idx'] = idx
 
-            # for type in ['images', 'labels']:
-            #     for key in tqdm(h5_file[f'{self.mode}/{type}'].keys()):
-            #         img = h5_file[data_path][:]
-            #         idx = self._add_to_cache(img, data_path)
-            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == data_path)
-            #         self.data_info[file_idx + idx]['cache_idx'] = idx
-            # for gname, group in h5_file.items():
-            #     for dname, ds in group.items():
-            #         # add data to the data cache and retrieve
-            #         # the cache index
-            #         idx = self._add_to_cache(ds.value, file_path)
-
-            #         # find the beginning index of the hdf5 file we are looking for
-            #         file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)
-
-            #         # the data info should have the same index since we loaded it in the same way
-            #         self.data_info[file_idx + idx]['cache_idx'] = idx
-
         # remove an element from data cache if size was exceeded
         if len(self.data_cache) > self.data_cache_size:
             # remove one item from the cache at random
@@ -223,6 +233,182 @@ class HDF5MILDataloader(data.Dataset):
     #     data_info_type = [di for di in self.data_info if di['type'] == type]
     #     return data_info_type
 
+    def get_name(self, i):
+        # name = self.get_data_infos(type)[i]['name']
+        name = self.data_info[i]['name']
+        return name
+
+    def get_labels(self, indices):
+
+        return [self.data_info[i]['label'] for i in indices]
+
+    def get_data(self, i):
+        """Call this function anytime you want to access a chunk of data from the
+            dataset. This will make sure that the data is loaded in case it is
+            not part of the data cache.
+            i = index
+        """
+        # fp = self.get_data_infos(type)[i]['data_path']
+        fp = self.data_info[i]['data_path']
+        if fp not in self.data_cache:
+            self._load_data(fp)
+        
+        # get new cache_idx assigned by _load_data_info
+        # cache_idx = self.get_data_infos(type)[i]['cache_idx']
+        cache_idx = self.data_info[i]['cache_idx']
+        label = self.data_info[i]['label']
+        name = self.data_info[i]['name']
+        # print(self.data_cache[fp][cache_idx])
+        return self.data_cache[fp][cache_idx], label, name
+
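+# A sketch of the lazy-cache access pattern implemented above (paths are
+# hypothetical): the first access to a slide triggers _load_data(), which
+# fills data_cache via _add_to_cache(); later accesses are served from the
+# cache until a file is evicted once data_cache_size files are held.
+#
+#   ds = HDF5MILDataloader(data_root, label_path=label_path, mode='train',
+#                          n_classes=2, bag_size=120)
+#   bag, label, name = ds[0]   # loads from HDF5 and caches
+#   bag, label, name = ds[0]   # served from data_cache
+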
+class GenesisDataloader(data.Dataset):
+    """Represents an abstract HDF5 dataset. For single H5 container! 
+    
+    Input params:
+        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
+        mode: 'train' or 'test'
+        load_data: If True, loads all the data immediately into RAM. Use this if
+            the dataset is fits into memory. Otherwise, leave this at false and 
+            the data will load lazily.
+        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
+
+    """
+    def __init__(self, file_path, label_path, mode, load_data=False, data_cache_size=5, debug=False):
+        super().__init__()
+
+        self.data_info = []
+        self.data_cache = {}
+        self.slideLabelDict = {}
+        self.files = []
+        self.data_cache_size = data_cache_size
+        self.mode = mode
+        self.file_path = file_path
+        # self.csv_path = csv_path
+        self.label_path = label_path
+        # self.n_classes = n_classes
+        self.bag_size = 120
+        # self.transforms = transforms
+        self.input_size = 256
+
+        # self.files = list(Path(self.file_path).rglob('*.hdf5'))
+        home = Path.cwd().parts[1]
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for x in temp_slide_label_dict:
+                
+                if Path(x).parts[1] != home: 
+                    x = x.replace(Path(x).parts[1], home)
+                self.files.append(Path(x))
+                # x = Path(x).stem 
+                # x_complete_path = Path(self.file_path)/Path(x + '.hdf5')
+                # if x_complete_path.is_file():
+                #     self.slideLabelDict[x] = y
+                #     self.files.append(x_complete_path)
+
+        for h5dataset_fp in tqdm(self.files):
+            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
+
+        self.resize_transforms = A.Compose([
+            A.SmallestMaxSize(max_size=256)
+        ])
+        sometimes = lambda aug: iaa.Sometimes(0.5, aug, name="Random1")
+        sometimes2 = lambda aug: iaa.Sometimes(0.2, aug, name="Random2")
+        sometimes3 = lambda aug: iaa.Sometimes(0.9, aug, name="Random3")
+        sometimes4 = lambda aug: iaa.Sometimes(0.9, aug, name="Random4")
+        sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
+
+        self.train_transforms = iaa.Sequential([
+            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
+            iaa.Fliplr(0.5, name="MyFlipLR"),
+            iaa.Flipud(0.5, name="MyFlipUD"),
+            sometimes(iaa.Rot90(k=1, keep_size=True, name="MyRot90")),
+            iaa.OneOf([
+                sometimes3(iaa.PiecewiseAffine(scale=(0.015, 0.02), cval=0, name="MyPiece")),
+                sometimes4(iaa.ElasticTransformation(alpha=(100, 200), sigma=20, cval=0, name="MyElastic")),
+                sometimes5(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, rotate=(-45, 45), shear=(-4, 4), cval=0, name="MyAffine"))
+            ], name="MyOneOf")
+
+        ], name="MyAug")
+        self.val_transforms = transforms.Compose([
+            # A.Normalize(),
+            # ToTensorV2(),
+            RangeNormalization(),
+            transforms.ToTensor(),
+
+        ])
+
+
+    def __getitem__(self, index):
+        # get_data returns (img, label, name)
+        img, label, name = self.get_data(index)
+        seq_img_d = self.train_transforms.to_deterministic()
+
+        if self.mode == 'train':
+            img = img.astype(np.uint8)  # cached images are numpy arrays already
+            img = seq_img_d.augment_image(img)
+            img = self.val_transforms(img)
+        else:
+            img = img.astype(np.uint8)
+            img = self.val_transforms(img)
+
+        return {'data': img, 'label': label}
+
+    def __len__(self):
+        return len(self.data_info)
+    
+    def _add_data_infos(self, file_path, load_data):
+        img_name = Path(file_path).stem
+        label = Path(file_path).parts[-2]
+
+        idx = -1
+        self.data_info.append({'data_path': file_path, 'label': label, 'name': img_name, 'cache_idx': idx})
+
+    def _load_data(self, file_path):
+        """Load data to the cache given the file
+        path and update the cache index in the
+        data_info structure.
+        """
+        with h5py.File(file_path, 'r') as h5_file:
+            tile = list(h5_file.keys())[0]
+            img = h5_file[tile][:]
+            img = img.astype(np.uint8)
+            # img = self.resize_transforms(img)
+            idx = self._add_to_cache(img, file_path)
+            file_idx = next(i for i, v in enumerate(self.data_info) if v['data_path'] == file_path)
+            self.data_info[file_idx + idx]['cache_idx'] = idx
+
+        # remove an element from data cache if size was exceeded
+        if len(self.data_cache) > self.data_cache_size:
+            # remove one item from the cache at random
+            removal_keys = list(self.data_cache)
+            removal_keys.remove(file_path)
+            self.data_cache.pop(removal_keys[0])
+            # reset the cache_idx of evicted entries, keeping their labels
+            self.data_info = [{'data_path': di['data_path'], 'label': di['label'], 'name': di['name'], 'cache_idx': -1} if di['data_path'] == removal_keys[0] else di for di in self.data_info]
+
+    def _add_to_cache(self, data, data_path):
+        """Adds data to the cache and returns its index. There is one cache
+        list for every file_path, containing all datasets in that file.
+        """
+        if data_path not in self.data_cache:
+            self.data_cache[data_path] = [data]
+        else:
+            self.data_cache[data_path].append(data)
+        return len(self.data_cache[data_path]) - 1
+
     def get_name(self, i):
         # name = self.get_data_infos(type)[i]['name']
         name = self.data_info[i]['name']
@@ -248,6 +434,45 @@ class HDF5MILDataloader(data.Dataset):
         return self.data_cache[fp][cache_idx], label, name
 
 
+def to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512):
+
+    # get up to bag_size elements
+    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
+    bag_samples = bag[bag_idxs]
+    
+    # zero-pad if we don't have enough samples
+    zero_padded = torch.cat((bag_samples,
+                            torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
+    return zero_padded, min(bag_size, len(bag))
+
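+# Shape check for to_fixed_size_bag (illustrative): a bag of 3 tiles padded
+# to bag_size=5 comes back as a (5, C, H, W) tensor with 2 zero tiles
+# appended, and min(bag_size, len(bag)) == 3 as the valid-tile count.
+#
+#   bag = torch.randn(3, 3, 256, 256)
+#   padded, n = to_fixed_size_bag(bag, bag_size=5)
+#   assert padded.shape == (5, 3, 256, 256) and n == 3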
+
 class RandomHueSaturationValue(object):
 
     def __init__(self, hue_shift_limit=(-180, 180), sat_shift_limit=(-255, 255), val_shift_limit=(-255, 255), p=0.5):
@@ -285,48 +510,81 @@ if __name__ == '__main__':
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
     data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
-    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_all.json'
-    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/'
-    # os.makedirs(output_path, exist_ok=True)
-
-
-    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=6)
-    data = DataLoader(dataset, batch_size=1)
+    label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
+    os.makedirs(output_path, exist_ok=True)
+
+    n_classes = 2
+
+    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    # print(dataset.dataset)
+    a = int(len(dataset)* 0.8)
+    b = int(len(dataset) - a)
+    train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
+    # dl = DataLoader(train_ds, None, sampler=ImbalancedDatasetSampler(train_ds), num_workers=5)
+    dl = DataLoader(train_ds, None, num_workers=5)
+    
+    # data = DataLoader(dataset, batch_size=1)
 
     # print(len(dataset))
-    x = 0
+    # # x = 0
     c = 0
-    for item in data: 
-        if c >=10:
-            break
+    label_count = [0] * n_classes
+    for item in dl: 
+        # if c >=10:
+        #     break
         bag, label, name = item
-        print(bag)
-        # # print(bag.shape)
-        # if bag.shape[1] == 1:
-        #     print(name)
-        #     print(bag.shape)
+        label_count[np.argmax(label)] += 1
+    print(label_count)
+    print(len(train_ds))
+    #     # # print(bag.shape)
+    #     # if bag.shape[1] == 1:
+    #     #     print(name)
+    #     #     print(bag.shape)
         # print(bag.shape)
-        # print(name)
-        # out_dir = Path(output_path) / name
-        # os.makedirs(out_dir, exist_ok=True)
-
-        # # print(item[2])
-        # # print(len(item))
-        # # print(item[1])
-        # # print(data.shape)
-        # # data = data.squeeze()
-        # bag = item[0]
-        # bag = bag.squeeze()
-        # for i in range(bag.shape[0]):
-        #     img = bag[i, :, :, :]
-        #     img = img.squeeze()
-        #     img = img*255
-        #     img = img.numpy().astype(np.uint8).transpose(1,2,0)
+        
+    #     # out_dir = Path(output_path) / name
+    #     # os.makedirs(out_dir, exist_ok=True)
+
+    #     # # print(item[2])
+    #     # # print(len(item))
+    #     # # print(item[1])
+    #     # # print(data.shape)
+    #     # # data = data.squeeze()
+    #     # bag = item[0]
+    #     bag = bag.squeeze()
+    #     original = original.squeeze()
+    #     for i in range(bag.shape[0]):
+    #         img = bag[i, :, :, :]
+    #         img = img.squeeze()
+            
+    #         img = ((img-img.min())/(img.max() - img.min())) * 255
+    #         print(img)
+    #         # print(img)
+    #         img = img.numpy().astype(np.uint8).transpose(1,2,0)
+
+            
+    #         img = Image.fromarray(img)
+    #         img = img.convert('RGB')
+    #         img.save(f'{output_path}/{i}.png')
+
+
             
-        #     img = Image.fromarray(img)
-        #     img = img.convert('RGB')
-        #     img.save(f'{out_dir}/{i}.png')
-        c += 1
+    #         o_img = original[i,:,:,:]
+    #         o_img = o_img.squeeze()
+    #         print(o_img.shape)
+    #         o_img = ((o_img-o_img.min())/(o_img.max()-o_img.min()))*255
+    #         o_img = o_img.numpy().astype(np.uint8).transpose(1,2,0)
+    #         o_img = Image.fromarray(o_img)
+    #         o_img = o_img.convert('RGB')
+    #         o_img.save(f'{output_path}/{i}_original.png')
+        # c += 1
+    #     break
         # else: break
         # print(data.shape)
-        # print(label)
\ No newline at end of file
+        # print(label)
+    # a = [torch.Tensor((3,256,256))]*3
+    # b = torch.stack(a)
+    # print(b)
+    # c = to_fixed_size_bag(b, 512)
+    # print(c)
\ No newline at end of file
diff --git a/datasets/data_interface.py b/datasets/data_interface.py
index 12a0f8c..056e6ff 100644
--- a/datasets/data_interface.py
+++ b/datasets/data_interface.py
@@ -8,6 +8,8 @@ from torchvision import transforms
 from .camel_dataloader import FeatureBagLoader
 from .custom_dataloader import HDF5MILDataloader
 from pathlib import Path
+from transformers import AutoFeatureExtractor
+from torchsampler import ImbalancedDatasetSampler
 
 class DataInterface(pl.LightningDataModule):
 
@@ -56,9 +58,10 @@ class DataInterface(pl.LightningDataModule):
                                                 train=True)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
-            print(a)
-            print(b)
-            self.train_dataset, self.val_dataset = random_split(dataset, [a, b])
+            # print(a)
+            # print(b)
+            self.train_dataset, self.val_dataset = random_split(dataset, [a, b]) # returns data.Subset
+
             # self.train_dataset = self.instancialize(state='train')
             # self.val_dataset = self.instancialize(state='val')
  
@@ -72,7 +75,7 @@ class DataInterface(pl.LightningDataModule):
 
 
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=True)
+        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.train_batch_size, num_workers=self.train_num_workers, shuffle=False)
@@ -106,7 +109,7 @@ class DataInterface(pl.LightningDataModule):
 
 class MILDataModule(pl.LightningDataModule):
 
-    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, *args, **kwargs):
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
         super().__init__()
         self.data_root = data_root
         self.label_path = label_path
@@ -121,41 +124,74 @@ class MILDataModule(pl.LightningDataModule):
         self.num_bags_test = 50
         self.seed = 1
 
+        self.backbone = backbone
         self.cache = True
+        self.fe_transform = None
 
 
     def setup(self, stage: Optional[str] = None) -> None:
-        # if self.n_classes == 2:
-        #     if stage in (None, 'fit'):
-        #         dataset = HDF5Dataset(self.data_root, mode='train', n_classes=self.n_classes)
-        #         a = int(len(dataset)* 0.8)
-        #         b = int(len(dataset) - a)
-        #         self.train_data, self.valid_data = random_split(dataset, [a, b])
-
-        #     if stage in (None, 'test'):
-        #         self.test_data = HDF5Dataset(self.data_root, mode='test', n_classes=self.n_classes)
-        # else:
         home = Path.cwd().parts[1]
-        # self.label_path = f'{home}/ylan/DeepGraft_project/code/split_debug.json'
-        # train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train_small.csv'
-        # test_csv = f'/{home}/ylan/DeepGraft_project/code/debug_test_small.csv'
-        
 
         if stage in (None, 'fit'):
-            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes)
-            # print(len(dataset))
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
             a = int(len(dataset)* 0.8)
             b = int(len(dataset) - a)
             self.train_data, self.valid_data = random_split(dataset, [a, b])
 
         if stage in (None, 'test'):
-            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes)
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
 
         return super().setup(stage=stage)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True) #batch_transforms=self.transform, pseudo_batch_dim=True, 
+        # alternative: sampler=ImbalancedDatasetSampler(self.train_data)
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
+    
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_data, batch_size = self.batch_size, num_workers=self.num_workers)
     
+
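+# Minimal usage sketch (hypothetical values; in this repo the modules are
+# wired up in train.py). Assuming `model` is a configured ModelInterface:
+#
+#   dm = MILDataModule(data_root='/data/hdf5', label_path='labels.json',
+#                      batch_size=1, n_classes=2, backbone='resnet18')
+#   trainer = pl.Trainer(max_epochs=1)
+#   trainer.fit(model, datamodule=dm)
+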
+class DataModule(pl.LightningDataModule):
+
+    def __init__(self, data_root: str, label_path: str, batch_size: int=1, num_workers: int=8, n_classes=2, cache: bool=True, backbone=None, *args, **kwargs):
+        super().__init__()
+        self.data_root = data_root
+        self.label_path = label_path
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.image_size = 384
+        self.n_classes = n_classes
+        self.target_number = 9
+        self.mean_bag_length = 10
+        self.var_bag_length = 2
+        self.num_bags_train = 200
+        self.num_bags_test = 50
+        self.seed = 1
+
+        self.backbone = backbone
+        self.cache = True
+        self.fe_transform = None
+
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        home = Path.cwd().parts[1]
+        
+        if stage in (None, 'fit'):
+            dataset = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='train', n_classes=self.n_classes, backbone=self.backbone)
+            a = int(len(dataset)* 0.8)
+            b = int(len(dataset) - a)
+            self.train_data, self.valid_data = random_split(dataset, [a, b])
+
+        if stage in (None, 'test'):
+            self.test_data = HDF5MILDataloader(self.data_root, label_path=self.label_path, mode='test', n_classes=self.n_classes, backbone=self.backbone)
+
+        return super().setup(stage=stage)
+
+    def train_dataloader(self) -> DataLoader:
+        # alternative: sampler=ImbalancedDatasetSampler(self.train_data)
+        return DataLoader(self.train_data, self.batch_size, num_workers=self.num_workers, shuffle=False)
+
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.valid_data, batch_size = self.batch_size, num_workers=self.num_workers)
     
diff --git a/fine_tune.py b/fine_tune.py
new file mode 100644
index 0000000..5898e75
--- /dev/null
+++ b/fine_tune.py
@@ -0,0 +1,42 @@
+from transformers import AutoFeatureExtractor, ViTForImageClassification
+from transformers import Trainer, TrainingArguments
+from torchvision import models
+import torch
+from datasets.custom_dataloader import DinoDataloader
+
+
+
+def fine_tune_transformer(args):
+
+    data_path = args.data_path
+    model = args.model
+    n_classes = args.n_classes
+
+    if model == 'dino':
+        feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+        # ViTModel has no classification head; ViTForImageClassification accepts num_labels
+        model_ft = ViTForImageClassification.from_pretrained('facebook/dino-vitb16', num_labels=n_classes)
+
+    training_args = TrainingArguments(
+        output_dir = f'logs/fine_tune/{model}',
+        per_device_train_batch_size=16,
+        evaluation_strategy="steps",
+        num_train_epochs=4,
+        fp16=True,
+        save_steps=100,
+        eval_steps=100,
+        logging_steps=10,
+        learning_rate=2e-4,
+        save_total_limit=2,
+        remove_unused_columns=False,
+        push_to_hub=False,
+        report_to='tensorboard',
+        load_best_model_at_end=True,
+    )
+
+    dataset = DinoDataloader(args.data_path, mode='train') #, transforms=transform
+
+    # note: evaluation_strategy='steps' and load_best_model_at_end also require
+    # an eval_dataset; a held-out split of `dataset` is needed before training.
+    trainer = Trainer(
+        model=model_ft,
+        args=training_args,
+        train_dataset=dataset,
+    )
+
+    trainer.train()
\ No newline at end of file
diff --git a/models/TransMIL.py b/models/TransMIL.py
index ce40a26..69089de 100755
--- a/models/TransMIL.py
+++ b/models/TransMIL.py
@@ -86,7 +86,7 @@ class TransMIL(nn.Module):
         h = self.norm(h)[:,0]
 
         #---->predict
-        logits = self._fc2(torch.sigmoid(h)) #[B, n_classes]
+        logits = self._fc2(h) #[B, n_classes]
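+        # h is already LayerNorm-ed above; the softmax below yields Y_prob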
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
         results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
index 4e1ddff6d6f3cbadfd7f1c0c4686a7806f5896e8..cb0eb6fd5085feda9604ec0f24420d85eb1d939d 100644
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and b/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
index 3e81c0ccdd33e6fe68a5ffe1b602990134944c80..0bf337675c76de4097e7f7723b99de0120f9594c 100644
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and b/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/models/__pycache__/resnet50.cpython-39.pyc b/models/__pycache__/resnet50.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfb660f9829304cb853031dd3d3274396afc38be
Binary files /dev/null and b/models/__pycache__/resnet50.cpython-39.pyc differ
diff --git a/models/model_interface.py b/models/model_interface.py
index 1b0f6e1..60b5cc7 100755
--- a/models/model_interface.py
+++ b/models/model_interface.py
@@ -12,18 +12,24 @@ from matplotlib import pyplot as plt
 from MyOptimizer import create_optimizer
 from MyLoss import create_loss
 from utils.utils import cross_entropy_torch
+from timm.loss import AsymmetricLossSingleLabel
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 
 #---->
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torch import optim as optim
 
 #---->
 import pytorch_lightning as pl
 from .vision_transformer import vit_small
 from torchvision import models
 from torchvision.models import resnet
+from transformers import AutoFeatureExtractor, ViTModel
+
+from captum.attr import LayerGradCam
 
 class ModelInterface(pl.LightningModule):
 
@@ -33,13 +39,17 @@ class ModelInterface(pl.LightningModule):
         self.save_hyperparameters()
         self.load_model()
         self.loss = create_loss(loss)
+        # self.asl = AsymmetricLossSingleLabel()
+        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
         self.optimizer = optimizer
         self.n_classes = model.n_classes
         self.log_path = kargs['log']
 
         #---->acc
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-        
+        # print(self.experiment)
         #---->Metrics
         if self.n_classes > 2: 
             self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
@@ -73,35 +83,12 @@ class ModelInterface(pl.LightningModule):
         #--->random
         self.shuffle = kargs['data'].data_shuffle
         self.count = 0
+        self.backbone = kargs['backbone']
 
         self.out_features = 512
         if kargs['backbone'] == 'dino':
-            #---> dino feature extractor
-            arch = 'vit_small'
-            patch_size = 16
-            n_last_blocks = 4
-            # num_labels = 1000
-            avgpool_patchtokens = False
-            home = Path.cwd().parts[1]
-
-            weight_path = f'/{home}/ylan/workspace/dino/output/Aachen_2/checkpoint.pth'
-            model = vit_small(patch_size, num_classes=0)
-            # model.eval()
-            # set_parameter_requires_grad(model, feature_extracting)
-            for param in model.parameters():
-                param.requires_grad = False
-            # print(model.embed_dim)
-            # embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))
-            # model.eval()
-            # print(embed_dim)
-            linear = nn.Linear(model.embed_dim, self.out_features)
-            linear.weight.data.normal_(mean=0.0, std=0.01)
-            linear.bias.data.zero_()
-            
-            self.model_ft = nn.Sequential(
-                model,
-                linear, 
-            )
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
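+            # the extractor handles resizing/normalisation; ViTModel itself has no classification head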
         elif kargs['backbone'] == 'resnet18':
             resnet18 = models.resnet18(pretrained=True)
             modules = list(resnet18.children())[:-1]
@@ -109,7 +96,6 @@ class ModelInterface(pl.LightningModule):
 
             res18 = nn.Sequential(
                 *modules,
-                
             )
             for param in res18.parameters():
                 param.requires_grad = False
@@ -118,7 +104,7 @@ class ModelInterface(pl.LightningModule):
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 512)),
                 nn.Linear(512, self.out_features),
-                nn.ReLU(),
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'resnet50':
 
@@ -135,7 +121,17 @@ class ModelInterface(pl.LightningModule):
                 nn.AdaptiveAvgPool2d(1),
                 View((-1, 1024)),
                 nn.Linear(1024, self.out_features),
-                nn.ReLU()
+                # nn.GELU()
+            )
+        elif kargs['backbone'] == 'efficientnet':
+            efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+            for param in efficientnet.parameters():
+                param.requires_grad = False
+            # efn = list(efficientnet.children())[:-1]
+            efficientnet.classifier.fc = nn.Linear(1280, self.out_features)
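+            # the replacement fc is created after the freeze loop, so it remains trainable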
+            self.model_ft = nn.Sequential(
+                efficientnet,
+                nn.GELU(),
             )
         elif kargs['backbone'] == 'simple': #mil-ab attention
             feature_extracting = False
@@ -151,21 +147,19 @@ class ModelInterface(pl.LightningModule):
                 nn.ReLU(),
             )
 
-    #---->remove v_num
-    # def get_progress_bar_dict(self):
-    #     # don't show the version number
-    #     items = super().get_progress_bar_dict()
-    #     items.pop("v_num", None)
-    #     return items
-
     def training_step(self, batch, batch_idx):
         #---->inference
+        
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
+        # print(data)
         # print(data.shape)
-        features = self.model_ft(data)
-        
+        if self.backbone == 'dino':
+            features = self.model_ft(**data)
+            features = features.last_hidden_state
+        else:
+            features = self.model_ft(data)
         features = features.unsqueeze(0)
         # print(features.shape)
         # features = features.squeeze()
@@ -177,21 +171,28 @@ class ModelInterface(pl.LightningModule):
 
         #---->loss
         loss = self.loss(logits, label)
+        # loss = self.asl(logits, label.squeeze())
 
         #---->acc log
         # print(label)
-        Y_hat = int(Y_hat)
+        # Y_hat = int(Y_hat)
         # if self.n_classes == 2:
         #     Y = int(label[0][1])
         # else: 
         Y = torch.argmax(label)
             # Y = int(label[0])
         self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat == Y)
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
 
-        return {'loss': loss} 
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
 
     def training_epoch_end(self, training_step_outputs):
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
+        max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
+        # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in training_step_outputs])
+        target = torch.argmax(target, dim=1)
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -202,12 +203,19 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
+        # print('max_probs: ', max_probs)
+        # print('probs: ', probs)
+        if self.current_epoch % 10 == 0:
+            self.log_confusion_matrix(probs, target, stage='train')
+
+        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+
     def validation_step(self, batch, batch_idx):
 
         data, label, _ = batch
 
         label = label.float()
-        data = data.squeeze(0)
-        features = self.model_ft(data)
+        data = data.squeeze(0).float()
+        # mirror training_step: the dino backbone expects keyword inputs
+        if self.backbone == 'dino':
+            features = self.model_ft(**data).last_hidden_state
+        else:
+            features = self.model_ft(data)
         features = features.unsqueeze(0)
 
@@ -224,20 +232,23 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
-        probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        # probs = torch.cat([x['Y_prob'] for x in val_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        # target = torch.stack([x['label'] for x in val_step_outputs], dim = 0)
+        target = torch.cat([x['label'] for x in val_step_outputs])
+        target = torch.argmax(target, dim=1)
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
-        self.log('auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
 
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
@@ -245,12 +256,8 @@ class ModelInterface(pl.LightningModule):
                           on_epoch = True, logger = True)
 
         #----> log confusion matrix
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        plt.close(fig_)
-        self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+        self.log_confusion_matrix(probs, target, stage='val')
+
 
         #---->acc log
         for c in range(self.n_classes):
@@ -267,18 +274,12 @@ class ModelInterface(pl.LightningModule):
         if self.shuffle == True:
             self.count = self.count+1
             random.seed(self.count*50)
-    
-
-
-    def configure_optimizers(self):
-        optimizer = create_optimizer(self.optimizer, self.model)
-        return [optimizer]
 
     def test_step(self, batch, batch_idx):
 
         data, label, _ = batch
         label = label.float()
-        data = data.squeeze(0)
+        data = data.squeeze(0).float()
         features = self.model_ft(data)
         features = features.unsqueeze(0)
 
@@ -292,12 +293,14 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
 
     def test_epoch_end(self, output_results):
-        probs = torch.cat([x['Y_prob'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        target = torch.stack([x['label'] for x in output_results], dim = 0)
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.cat([x['label'] for x in output_results])
+        target = torch.argmax(target, dim=1)
         
         #---->
         auc = self.AUROC(probs, target.squeeze())
@@ -326,19 +329,16 @@ class ModelInterface(pl.LightningModule):
             print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
 
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
-        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        # plt.close(fig_)
-        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
-        plt.savefig(f'{self.log_path}/cm_test')
-        plt.close(fig_)
-
+        self.log_confusion_matrix(probs, target, stage='test')
         #---->
         result = pd.DataFrame([metrics])
         result.to_csv(self.log_path / 'result.csv')
 
+    def configure_optimizers(self):
+        # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
+        optimizer = create_optimizer(self.optimizer, self.model)
+        return optimizer
+
 
     def load_model(self):
         name = self.hparams.model.name
@@ -371,6 +372,20 @@ class ModelInterface(pl.LightningModule):
         args1.update(other_args)
         return Model(**args1)
 
+    def log_confusion_matrix(self, probs, target, stage):
+        confmat = self.confusion_matrix(probs.squeeze(), target)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        plt.figure()
+        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+        # plt.close(fig_)
+        # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+        self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        if stage == 'test':
+            plt.savefig(f'{self.log_path}/cm_test')
+        plt.close(fig_)
+        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+
 class View(nn.Module):
     def __init__(self, shape):
         super().__init__()
@@ -383,4 +398,5 @@ class View(nn.Module):
         # batch_size = input.size(0)
         # shape = (batch_size, *self.shape)
         out = input.view(*self.shape)
-        return out
\ No newline at end of file
+        return out
+
diff --git a/models/resnet50.py b/models/resnet50.py
new file mode 100644
index 0000000..89e23d7
--- /dev/null
+++ b/models/resnet50.py
@@ -0,0 +1,293 @@
+import torch
+import torch.nn as nn
+try:
+    from .utils import load_state_dict_from_url  # local helper, if the repo provides one
+except ImportError:
+    from torch.hub import load_state_dict_from_url
+
+
+# only the architectures actually defined below are exported
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
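+# Usage sketch (assumption: ImageNet-style input, default num_classes=1000):
+#   model = resnet50(pretrained=True)
+#   logits = model(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 1000])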
+
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
\ No newline at end of file
diff --git a/train.py b/train.py
index 182e3c0..036d5ed 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,8 @@ from pathlib import Path
 import numpy as np
 import glob
 
+from sklearn.model_selection import KFold
+
 from datasets.data_interface import DataInterface, MILDataModule
 from models.model_interface import ModelInterface
 import models.vision_transformer as vits
@@ -11,25 +13,32 @@ from utils.utils import *
 # pytorch_lightning
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
+from utils.utils import KFoldLoop
+import torch
 
 #--->Setting parameters
 def make_parse():
     parser = argparse.ArgumentParser()
     parser.add_argument('--stage', default='train', type=str)
     parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
-    # parser.add_argument('--gpus', default = [2])
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
     parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
     args = parser.parse_args()
     return args
 
 #---->main
 def main(cfg):
 
+    torch.set_num_threads(16)
+
     #---->Initialize seed
     pl.seed_everything(cfg.General.seed)
 
     #---->load loggers
     cfg.load_loggers = load_loggers(cfg)
+    # print(cfg.load_loggers)
 
     #---->load callbacks
     cfg.callbacks = load_callbacks(cfg)
@@ -49,7 +58,10 @@ def main(cfg):
                 'batch_size': cfg.Data.train_dataloader.batch_size,
                 'num_workers': cfg.Data.train_dataloader.num_workers,
                 'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
                 }
+
     dm = MILDataModule(**DataInterface_dict)
     
 
@@ -70,9 +82,9 @@ def main(cfg):
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
         min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [4],
-        # strategy='ddp',
+        # gpus=cfg.General.gpus,
+        gpus = [2,3],
+        strategy='ddp',
         amp_backend='native',
         # amp_level=cfg.General.amp_level,  
         precision=cfg.General.precision,  
@@ -85,6 +97,7 @@ def main(cfg):
 
     #---->train or test
     if cfg.General.server == 'train':
+        # K-fold wiring: KFoldLoop(num_folds, export_path) wraps the default FitLoop (see utils/utils.py)
+        internal_fit_loop = trainer.fit_loop
+        trainer.fit_loop = KFoldLoop(3, export_path=str(cfg.log_path))
+        trainer.fit_loop.connect(internal_fit_loop)
         trainer.fit(model = model, datamodule = dm)
     else:
         model_paths = list(cfg.log_path.glob('*.ckpt'))
@@ -94,6 +107,7 @@ def main(cfg):
             new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
             trainer.test(model=new_model, datamodule=dm)
 
+
 if __name__ == '__main__':
 
     args = make_parse()
@@ -101,9 +115,12 @@ if __name__ == '__main__':
 
     #---->update
     cfg.config = args.config
-    # cfg.General.gpus = args.gpus
+    cfg.General.gpus = [args.gpus]
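+    # NOTE: the Trainer above currently hardcodes gpus=[2,3], so this CLI value
+    # only takes effect once that override is removed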
     cfg.General.server = args.stage
     cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
 
     #---->main
     main(cfg)
diff --git a/utils/extract_features.py b/utils/extract_features.py
new file mode 100644
index 0000000..fb040a1
--- /dev/null
+++ b/utils/extract_features.py
@@ -0,0 +1,37 @@
+## Choose a model, extract features from (augmented) image patches and save them as .pt files
+
+import torch
+
+from datasets.custom_dataloader import HDF5MILDataloader
+# NOTE: the first draft referenced Resnet50_baseline, which this patch does not define;
+# the ResNet-50 from models/resnet50.py is used here instead.
+from models.resnet50 import resnet50
+
+
+def extract_features(data_root, label_path, output_dir, model_name, n_classes, device='cuda'):
+
+    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+    if model_name == 'resnet50':
+        model = resnet50(pretrained=True)
+    else:
+        raise ValueError(f'Unsupported model: {model_name}')
+    model = model.to(device)
+    model.eval()
+
+    # assumption: the dataset yields (bag, label, name) triples, matching the batches
+    # unpacked in ModelInterface.training_step
+    with torch.no_grad():
+        for bag, label, name in dataset:
+            features = model(bag.to(device))
+            torch.save(features.cpu(), f'{output_dir}/{name}.pt')
+
+
+if __name__ == '__main__':
+
+    # input_dir, output_dir
+    # initiate data loader
+    # use data loader to load and augment images
+    # prediction from model
+    # choose save as bag or not (needed?)
+    # hypothetical invocation; replace the placeholder paths:
+    extract_features('/path/to/data', '/path/to/labels.json', '/path/to/output', 'resnet50', n_classes=2)
diff --git a/utils/utils.py b/utils/utils.py
index 96ed223..010eaab 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,4 +1,19 @@
 from pathlib import Path
+from abc import ABC, abstractclassmethod
+from copy import deepcopy
+from typing import Any, Dict
+import os.path as osp
+import torch
+import torchvision.transforms as T
+from sklearn.model_selection import KFold
+from torch.nn import functional as F
+from torch.utils.data import random_split
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, Subset
+from torchmetrics.classification.accuracy import Accuracy
+
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
 
 #---->read yaml
 import yaml
@@ -14,19 +29,32 @@ def load_loggers(cfg):
 
     log_path = cfg.General.log_path
     Path(log_path).mkdir(exist_ok=True, parents=True)
-    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}'
+    log_name = str(Path(cfg.config).parent) + f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     version_name = Path(cfg.config).name[:-5]
-    cfg.log_path = Path(log_path) / log_name / version_name / f'fold{cfg.Data.fold}'
-    print(f'---->Log dir: {cfg.log_path}')
     
     #---->TensorBoard
-    tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
-                                             name = version_name, version = f'fold{cfg.Data.fold}',
-                                             log_graph = True, default_hp_metric = False)
-    #---->CSV
-    csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
-                                      name = version_name, version = f'fold{cfg.Data.fold}', )
+    # cfg.stage is never set (train.py stores the stage in cfg.General.server), so key off that instead
+    version = f'fold{cfg.Data.fold}' if cfg.General.server != 'test' else 'test'
+    cfg.log_path = Path(log_path) / log_name / version_name / version
+    tb_logger = pl_loggers.TensorBoardLogger(log_path+str(log_name),
+                                             name = version_name, version = version,
+                                             log_graph = True, default_hp_metric = False)
+    #---->CSV
+    csv_logger = pl_loggers.CSVLogger(log_path+str(log_name),
+                                      name = version_name, version = version, )
     
+    print(f'---->Log dir: {cfg.log_path}')
+
+    # return tb_logger
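+    # NOTE: list order matters — ModelInterface.log_confusion_matrix writes figures via
+    # self.loggers[0], so the TensorBoard logger must stay first.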
     return [tb_logger, csv_logger]
 
 
@@ -74,6 +102,14 @@ def load_callbacks(cfg):
                                          save_top_k = 1,
                                          mode = 'min',
                                          save_weights_only = True))
+        Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
+                                         dirpath = str(cfg.log_path),
+                                         filename = '{epoch:02d}-{val_auc:.4f}',
+                                         verbose = True,
+                                         save_last = True,
+                                         save_top_k = 1,
+                                         mode = 'max',
+                                         save_weights_only = True))
     return Mycallbacks
 
 #---->val loss
@@ -84,3 +120,103 @@ def cross_entropy_torch(x, y):
     x_log = torch.tensor([torch.log(x_softmax[i][y[i]]) for i in range(y.shape[0])])
     loss = - torch.sum(x_log) / y.shape[0]
     return loss
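+# e.g. cross_entropy_torch(torch.tensor([[2.0, 0.5]]), torch.tensor([0])) ≈ 0.2014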
+
+#-----> convert labels for task
+label_map = {
+    'bin': {'0': 0, '1': 1, '2': 1, '3': 1, '4': 1, '5': None},
+    'tcmr_viral': {'0': None, '1': 0, '2': None, '3': None, '4': 1, '5': None},
+    'no_viral': {'0': 0, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'no_other': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': None},
+    'no_stable': {'0': None, '1': 1, '2': 2, '3': 3, '4': None, '5': None},
+    'all': {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5},
+
+}
+def convert_labels_for_task(task, label):
+
+    return label_map[task][label]
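+
+# Example: under the 'tcmr_viral' task, raw label '4' maps to class 1, while classes
+# excluded from the task map to None and should be filtered out by the caller:
+#   convert_labels_for_task('tcmr_viral', '4')  # -> 1
+#   convert_labels_for_task('tcmr_viral', '0')  # -> None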
+
+
+#-----> KFOLD LOOP
+
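+# NOTE: BaseKFoldDataModule and EnsembleVotingModel below are assumptions carried over from
+# the official PyTorch Lightning K-fold example; they are not defined in this diff and must
+# be provided (or imported) before this loop can run.
+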
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
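+        # wiring sketch (assumption, following the Lightning K-fold example):
+        #   internal_fit_loop = trainer.fit_loop
+        #   trainer.fit_loop = KFoldLoop(num_folds=3, export_path=str(cfg.log_path))
+        #   trainer.fit_loop.connect(internal_fit_loop)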
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to the run a fitting and testing on the current hold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
+        voting_model.trainer = self.trainer
+        # This requires to connect the new model and move it the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # requires to be overridden as attributes of the wrapped loop are being accessed.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
-- 
GitLab