diff --git a/DeepGraft/AttMIL_simple_no_other.yaml b/DeepGraft/AttMIL_simple_no_other.yaml
index ae90a80ee152a2bd9d5b04c059fff38b1e77c06f..344c2041cda7d9a3b68e7d6bb5b7fc9a2c888812 100644
--- a/DeepGraft/AttMIL_simple_no_other.yaml
+++ b/DeepGraft/AttMIL_simple_no_other.yaml
@@ -16,7 +16,7 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_other.json'
     fold: 1
     nfold: 3
diff --git a/DeepGraft/AttMIL_simple_no_viral.yaml b/DeepGraft/AttMIL_simple_no_viral.yaml
index 37ee07479e014d9277f582196eb5a50f4d74e2de..ac18f9cea8c08d0cf2f82f06f7cec31c7c6cec34 100644
--- a/DeepGraft/AttMIL_simple_no_viral.yaml
+++ b/DeepGraft/AttMIL_simple_no_viral.yaml
@@ -6,7 +6,7 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [3]
-    epochs: &epoch 500 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
     patience: 50
@@ -16,7 +16,7 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
     fold: 1
     nfold: 4
@@ -41,7 +41,7 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
diff --git a/DeepGraft/AttMIL_simple_tcmr_viral.yaml b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
index c982d3ad1bae365a2497a599a805dadf9874a9c2..2aa4ed853c9a324ce4ce6c3ca938c91c637b1e97 100644
--- a/DeepGraft/AttMIL_simple_tcmr_viral.yaml
+++ b/DeepGraft/AttMIL_simple_tcmr_viral.yaml
@@ -6,7 +6,7 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [3]
-    epochs: &epoch 300 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
     patience: 20
@@ -16,7 +16,7 @@ General:
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
     nfold: 3
@@ -34,6 +34,7 @@ Model:
     name: AttMIL
     n_classes: 2
     backbone: simple
+    in_features: 512
 
 
 Optimizer:
@@ -42,7 +43,7 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
diff --git a/DeepGraft/DTFDMIL_resnet50_tcmr_viral.yaml b/DeepGraft/DTFDMIL_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7493b37b5251c6cd224154032a6a3d276c321e3a
--- /dev/null
+++ b/DeepGraft/DTFDMIL_resnet50_tcmr_viral.yaml
@@ -0,0 +1,51 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 200 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: DTFDMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+
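For orientation, a minimal sketch (not part of the commit) of how a config like the one above could be loaded; the logged hparams under code/lightning_logs suggest configs end up as addict dictionaries, but the exact loading code in train.py may differ:

```python
# Hypothetical loader sketch: parse the YAML above into an attribute-style dict.
# The file name and the use of addict are assumptions based on the logged hparams.
import yaml
from addict import Dict

with open('DeepGraft/DTFDMIL_resnet50_tcmr_viral.yaml', 'r') as f:
    cfg = Dict(yaml.safe_load(f))

print(cfg.Model.name, cfg.Model.backbone)            # DTFDMIL resnet50
print(cfg.Optimizer.lr, cfg.Optimizer.weight_decay)  # 0.0002 0.01
print(cfg.Data.data_dir)                             # /home/ylan/data/DeepGraft/224_128um/
```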
diff --git a/DeepGraft/TransMIL_efficientnet_no_viral.yaml b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
index 98fe3778c9528e38d027b1f86d4d4b18631b0fe8..cffd5008f5daf838afb2f1e24c344c1088d7db96 100644
--- a/DeepGraft/TransMIL_efficientnet_no_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_no_viral.yaml
@@ -11,15 +11,16 @@ General:
     frozen_bn: False
     patience: 20
     server: test #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
+    data_dir: '/home/ylan/data/DeepGraft/256_256um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_no_viral.json'
     fold: 1
     nfold: 4
+    cross_val: False
 
     train_dataloader:
         batch_size: 1 
diff --git a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
index 52230329255872d150dbc4d0552d1265dc3abc80..0dc5fa1ea633d4ccd6d746a53ddf2c784bb7567a 100644
--- a/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_efficientnet_tcmr_viral.yaml
@@ -6,21 +6,21 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 500 
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
-    patience: 200
+    patience: 20
     server: train #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/data/DeepGraft/256_256um/'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
     label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
     nfold: 3
-    cross_val: True
+    cross_val: False
 
     train_dataloader:
         batch_size: 1 
@@ -35,6 +35,7 @@ Model:
     n_classes: 2
     backbone: efficientnet
     in_features: 512
+    out_features: 512
 
 
 Optimizer:
@@ -43,7 +44,7 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
diff --git a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
index c26e1e9ff0329f32efbfd96334c6d1ac957d90bb..961f819772a39cbe332cebdd21c12a4f50b4dc36 100644
--- a/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet18_tcmr_viral.yaml
@@ -5,21 +5,22 @@ General:
     amp_level: O2
     precision: 16 
     multi_gpu_mode: dp
-    gpus: [3]
-    epochs: &epoch 500 
+    gpus: [0]
+    epochs: &epoch 200 
     grad_acc: 2
     frozen_bn: False
     patience: 50
     server: test #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
     fold: 1
-    nfold: 4
+    nfold: 3
+    cross_val: False
 
     train_dataloader:
         batch_size: 1 
@@ -33,6 +34,8 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: resnet18
+    in_features: 512
+    out_features: 512
 
 
 Optimizer:
@@ -41,7 +44,7 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
diff --git a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
index f6e469763b0a6df74a2ce2306ae7e94f53d9770b..df205a2342386af0e48aedb777b8e5a5e6aa7aba 100644
--- a/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+++ b/DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
@@ -6,21 +6,21 @@ General:
     precision: 16 
     multi_gpu_mode: dp
     gpus: [0]
-    epochs: &epoch 200 
+    epochs: &epoch 500 
     grad_acc: 2
     frozen_bn: False
     patience: 50
     server: test #train #test
-    log_path: logs/
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
 
 Data:
     dataset_name: custom
     data_shuffle: False
-    data_dir: '/home/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
-    label_file: '/home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json'
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
     fold: 1
     nfold: 3
-    cross_val: True
+    cross_val: False
 
     train_dataloader:
         batch_size: 1 
@@ -34,6 +34,8 @@ Model:
     name: TransMIL
     n_classes: 2
     backbone: resnet50
+    in_features: 512
+    out_features: 512
 
 
 Optimizer:
@@ -42,8 +44,9 @@ Optimizer:
     opt_eps: null 
     opt_betas: null
     momentum: null 
-    weight_decay: 0.00001
+    weight_decay: 0.01
 
 Loss:
     base_loss: CrossEntropyLoss
+    loss_weight: [1., 1.]
 
diff --git a/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml b/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..896bef13fb4cde9841f9afcb7e87a7fe6d72dd91
--- /dev/null
+++ b/DeepGraft/TransformerMIL_resnet50_tcmr_viral.yaml
@@ -0,0 +1,52 @@
+General:
+    comment: 
+    seed: 2021
+    fp16: True
+    amp_level: O2
+    precision: 16 
+    multi_gpu_mode: dp
+    gpus: [0]
+    epochs: &epoch 500 
+    grad_acc: 2
+    frozen_bn: False
+    patience: 50
+    server: test #train #test
+    log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+
+Data:
+    dataset_name: custom
+    data_shuffle: False
+    data_dir: '/home/ylan/data/DeepGraft/224_128um/'
+    label_file: '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    fold: 1
+    nfold: 3
+    cross_val: False
+
+    train_dataloader:
+        batch_size: 1 
+        num_workers: 8
+
+    test_dataloader:
+        batch_size: 1
+        num_workers: 8
+
+Model:
+    name: TransformerMIL
+    n_classes: 2
+    backbone: resnet50
+    in_features: 512
+    out_features: 512
+
+
+Optimizer:
+    opt: lookahead_radam
+    lr: 0.0002
+    opt_eps: null 
+    opt_betas: null
+    momentum: null 
+    weight_decay: 0.01
+
+Loss:
+    base_loss: CrossEntropyLoss
+    
+
diff --git a/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/MyLoss/__pycache__/loss_factory.cpython-39.pyc
deleted file mode 100644
index 14452dde7b34fd11b5255f4ed07bdeb48a27e49f..0000000000000000000000000000000000000000
Binary files a/MyLoss/__pycache__/loss_factory.cpython-39.pyc and /dev/null differ
diff --git a/README.md b/README.md
index effc0fe369c45153ec681403569c710ea114baf3..c8594bc7edf8bcade5f8d1c9749978fd0c71f98e 100644
--- a/README.md
+++ b/README.md
@@ -13,3 +13,26 @@ python train.py --stage='test' --config='Camelyon/TransMIL.yaml'  --gpus=0 --fol
 
 ### Changes Made: 
 
+### Baseline: 
+
+lr = 0.0002
+wd = 0.01
+
+| task        | main | backbone | train_auc | val_auc | epochs | version |  
+|---|---|---|---|---|---|---|
+| tcmr_viral | TransMIL | resnet50 |  0.997 | 0.871 | 200 | 4 |
+|            |          | resnet18 |  0.999 | 0.687 | 200 | 0 |
+|            |          | efficientnet | 0.99 | 0.76 | 200 | 107 |
+|            | DTFD     | resnet50 | 0.989 | 0.621 | 200 | 44 |
+|            | AttMIL   | simple | 0.513 | 0.518 | 200 | 50 |
+
+
+159	28639			0.9222221970558167	0.19437336921691895	0.5906432867050171	0.56540447473526	0.7159091234207153	0.8709122538566589	0.30908203125
+
+### Ablation
+
+image dropout: 
+tcmr_viral TCMR efficientnet: version 0
+
+wd increase: 
+tcmr_viral TCMR efficientnet: version 110
\ No newline at end of file
diff --git a/MyBackbone/__init__.py b/code/MyBackbone/__init__.py
similarity index 100%
rename from MyBackbone/__init__.py
rename to code/MyBackbone/__init__.py
diff --git a/MyBackbone/backbone_factory.py b/code/MyBackbone/backbone_factory.py
similarity index 99%
rename from MyBackbone/backbone_factory.py
rename to code/MyBackbone/backbone_factory.py
index ff770e583fc3ddc4712424f9381ed9adeb8b7742..31bc17bd566126365af5cc5f4e4eb9d8d69bb417 100644
--- a/MyBackbone/backbone_factory.py
+++ b/code/MyBackbone/backbone_factory.py
@@ -25,7 +25,6 @@ def init_backbone(**kargs):
         resnet18 = models.resnet18(pretrained=True)
         modules = list(resnet18.children())[:-1]
         # model_ft.fc = nn.Linear(512, out_features)
-
         res18 = nn.Sequential(
             *modules,
         )
diff --git a/MyLoss/ND_Crossentropy.py b/code/MyLoss/ND_Crossentropy.py
similarity index 100%
rename from MyLoss/ND_Crossentropy.py
rename to code/MyLoss/ND_Crossentropy.py
diff --git a/MyLoss/__init__.py b/code/MyLoss/__init__.py
similarity index 100%
rename from MyLoss/__init__.py
rename to code/MyLoss/__init__.py
diff --git a/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc b/code/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc
rename to code/MyLoss/__pycache__/ND_Crossentropy.cpython-39.pyc
diff --git a/MyLoss/__pycache__/__init__.cpython-39.pyc b/code/MyLoss/__pycache__/__init__.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/__init__.cpython-39.pyc
rename to code/MyLoss/__pycache__/__init__.cpython-39.pyc
diff --git a/MyLoss/__pycache__/boundary_loss.cpython-39.pyc b/code/MyLoss/__pycache__/boundary_loss.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/boundary_loss.cpython-39.pyc
rename to code/MyLoss/__pycache__/boundary_loss.cpython-39.pyc
diff --git a/MyLoss/__pycache__/dice_loss.cpython-39.pyc b/code/MyLoss/__pycache__/dice_loss.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/dice_loss.cpython-39.pyc
rename to code/MyLoss/__pycache__/dice_loss.cpython-39.pyc
diff --git a/MyLoss/__pycache__/focal_loss.cpython-39.pyc b/code/MyLoss/__pycache__/focal_loss.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/focal_loss.cpython-39.pyc
rename to code/MyLoss/__pycache__/focal_loss.cpython-39.pyc
diff --git a/MyLoss/__pycache__/hausdorff.cpython-39.pyc b/code/MyLoss/__pycache__/hausdorff.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/hausdorff.cpython-39.pyc
rename to code/MyLoss/__pycache__/hausdorff.cpython-39.pyc
diff --git a/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a56099c87d6626e35a3f5d5f82502c95139ee5b
Binary files /dev/null and b/code/MyLoss/__pycache__/loss_factory.cpython-39.pyc differ
diff --git a/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc b/code/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/lovasz_loss.cpython-39.pyc
rename to code/MyLoss/__pycache__/lovasz_loss.cpython-39.pyc
diff --git a/MyLoss/__pycache__/poly_loss.cpython-39.pyc b/code/MyLoss/__pycache__/poly_loss.cpython-39.pyc
similarity index 100%
rename from MyLoss/__pycache__/poly_loss.cpython-39.pyc
rename to code/MyLoss/__pycache__/poly_loss.cpython-39.pyc
diff --git a/MyLoss/boundary_loss.py b/code/MyLoss/boundary_loss.py
similarity index 100%
rename from MyLoss/boundary_loss.py
rename to code/MyLoss/boundary_loss.py
diff --git a/MyLoss/dice_loss.py b/code/MyLoss/dice_loss.py
similarity index 100%
rename from MyLoss/dice_loss.py
rename to code/MyLoss/dice_loss.py
diff --git a/MyLoss/focal_loss.py b/code/MyLoss/focal_loss.py
similarity index 100%
rename from MyLoss/focal_loss.py
rename to code/MyLoss/focal_loss.py
diff --git a/MyLoss/hausdorff.py b/code/MyLoss/hausdorff.py
similarity index 100%
rename from MyLoss/hausdorff.py
rename to code/MyLoss/hausdorff.py
diff --git a/MyLoss/loss_factory.py b/code/MyLoss/loss_factory.py
similarity index 91%
rename from MyLoss/loss_factory.py
rename to code/MyLoss/loss_factory.py
index f3bdcebb96480a0ed6de1e33a0de2defaf1792ea..8ff6d1814af1bb43a9e8e16a243f6fe87fe476cf 100755
--- a/MyLoss/loss_factory.py
+++ b/code/MyLoss/loss_factory.py
@@ -19,11 +19,15 @@ from pytorch_toolbelt import losses as L
 
 def create_loss(args, w1=1.0, w2=0.5):
     conf_loss = args.base_loss
+    if args.loss_weight: 
+        weight = torch.tensor(args.loss_weight)
+    else:
+        weight = None
     ### MulticlassJaccardLoss(classes=np.arange(11)
     # mode = args.base_loss #BINARY_MODE \MULTICLASS_MODE \MULTILABEL_MODE 
     loss = None
     if hasattr(nn, conf_loss): 
-        loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
+        loss = getattr(nn, conf_loss)(weight=weight, label_smoothing=0.5) 
+        # loss = getattr(nn, conf_loss)(label_smoothing=0.5) 
     #binary loss
     elif conf_loss == "focal":
         loss = L.BinaryFocalLoss()
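The patched create_loss above now turns an optional `Loss.loss_weight` entry from the YAML into a per-class weight tensor and passes it, together with label smoothing, into the torch loss. A minimal stand-alone sketch of that behavior (values are illustrative):

```python
# Sketch of the weighted, label-smoothed loss the patched create_loss builds
# when base_loss is CrossEntropyLoss and loss_weight is set (e.g. [1., 1.]).
import torch
import torch.nn as nn

loss_weight = [1.0, 1.0]                      # Loss.loss_weight from the config
weight = torch.tensor(loss_weight) if loss_weight else None
criterion = nn.CrossEntropyLoss(weight=weight, label_smoothing=0.5)

logits = torch.randn(4, 2)                    # (batch, n_classes)
targets = torch.tensor([0, 1, 1, 0])
print(criterion(logits, targets))
```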
diff --git a/MyLoss/lovasz_loss.py b/code/MyLoss/lovasz_loss.py
similarity index 100%
rename from MyLoss/lovasz_loss.py
rename to code/MyLoss/lovasz_loss.py
diff --git a/MyLoss/poly_loss.py b/code/MyLoss/poly_loss.py
similarity index 100%
rename from MyLoss/poly_loss.py
rename to code/MyLoss/poly_loss.py
diff --git a/MyOptimizer/__init__.py b/code/MyOptimizer/__init__.py
similarity index 100%
rename from MyOptimizer/__init__.py
rename to code/MyOptimizer/__init__.py
diff --git a/MyOptimizer/__pycache__/__init__.cpython-39.pyc b/code/MyOptimizer/__pycache__/__init__.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/__init__.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/__init__.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/adafactor.cpython-39.pyc b/code/MyOptimizer/__pycache__/adafactor.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/adafactor.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/adafactor.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/adahessian.cpython-39.pyc b/code/MyOptimizer/__pycache__/adahessian.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/adahessian.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/adahessian.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/adamp.cpython-39.pyc b/code/MyOptimizer/__pycache__/adamp.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/adamp.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/adamp.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/adamw.cpython-39.pyc b/code/MyOptimizer/__pycache__/adamw.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/adamw.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/adamw.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/lookahead.cpython-39.pyc b/code/MyOptimizer/__pycache__/lookahead.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/lookahead.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/lookahead.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/nadam.cpython-39.pyc b/code/MyOptimizer/__pycache__/nadam.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/nadam.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/nadam.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/novograd.cpython-39.pyc b/code/MyOptimizer/__pycache__/novograd.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/novograd.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/novograd.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc b/code/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/nvnovograd.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc b/code/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/optim_factory.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/optim_factory.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/radam.cpython-39.pyc b/code/MyOptimizer/__pycache__/radam.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/radam.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/radam.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc b/code/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/rmsprop_tf.cpython-39.pyc
diff --git a/MyOptimizer/__pycache__/sgdp.cpython-39.pyc b/code/MyOptimizer/__pycache__/sgdp.cpython-39.pyc
similarity index 100%
rename from MyOptimizer/__pycache__/sgdp.cpython-39.pyc
rename to code/MyOptimizer/__pycache__/sgdp.cpython-39.pyc
diff --git a/MyOptimizer/adafactor.py b/code/MyOptimizer/adafactor.py
similarity index 100%
rename from MyOptimizer/adafactor.py
rename to code/MyOptimizer/adafactor.py
diff --git a/MyOptimizer/adahessian.py b/code/MyOptimizer/adahessian.py
similarity index 100%
rename from MyOptimizer/adahessian.py
rename to code/MyOptimizer/adahessian.py
diff --git a/MyOptimizer/adamp.py b/code/MyOptimizer/adamp.py
similarity index 100%
rename from MyOptimizer/adamp.py
rename to code/MyOptimizer/adamp.py
diff --git a/MyOptimizer/adamw.py b/code/MyOptimizer/adamw.py
similarity index 100%
rename from MyOptimizer/adamw.py
rename to code/MyOptimizer/adamw.py
diff --git a/MyOptimizer/lookahead.py b/code/MyOptimizer/lookahead.py
similarity index 100%
rename from MyOptimizer/lookahead.py
rename to code/MyOptimizer/lookahead.py
diff --git a/MyOptimizer/nadam.py b/code/MyOptimizer/nadam.py
similarity index 100%
rename from MyOptimizer/nadam.py
rename to code/MyOptimizer/nadam.py
diff --git a/MyOptimizer/novograd.py b/code/MyOptimizer/novograd.py
similarity index 100%
rename from MyOptimizer/novograd.py
rename to code/MyOptimizer/novograd.py
diff --git a/MyOptimizer/nvnovograd.py b/code/MyOptimizer/nvnovograd.py
similarity index 100%
rename from MyOptimizer/nvnovograd.py
rename to code/MyOptimizer/nvnovograd.py
diff --git a/MyOptimizer/optim_factory.py b/code/MyOptimizer/optim_factory.py
similarity index 100%
rename from MyOptimizer/optim_factory.py
rename to code/MyOptimizer/optim_factory.py
diff --git a/MyOptimizer/radam.py b/code/MyOptimizer/radam.py
similarity index 100%
rename from MyOptimizer/radam.py
rename to code/MyOptimizer/radam.py
diff --git a/MyOptimizer/rmsprop_tf.py b/code/MyOptimizer/rmsprop_tf.py
similarity index 100%
rename from MyOptimizer/rmsprop_tf.py
rename to code/MyOptimizer/rmsprop_tf.py
diff --git a/MyOptimizer/sgdp.py b/code/MyOptimizer/sgdp.py
similarity index 100%
rename from MyOptimizer/sgdp.py
rename to code/MyOptimizer/sgdp.py
diff --git a/code/__pycache__/test_visualize.cpython-39.pyc b/code/__pycache__/test_visualize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f22258b82070b835ec1875cf4e46622cbad9198a
Binary files /dev/null and b/code/__pycache__/test_visualize.cpython-39.pyc differ
diff --git a/code/__pycache__/train_loop.cpython-39.pyc b/code/__pycache__/train_loop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee4cf9ea9af6982a3d224e17d28e5ff5d4ef98c
Binary files /dev/null and b/code/__pycache__/train_loop.cpython-39.pyc differ
diff --git a/code/datasets/__init__.py b/code/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb1fe74d53130bd211a4475d70a12507e9cf879
--- /dev/null
+++ b/code/datasets/__init__.py
@@ -0,0 +1,3 @@
+
+from .custom_jpg_dataloader import JPGMILDataloader
+from .data_interface import MILDataModule
diff --git a/code/datasets/__pycache__/__init__.cpython-39.pyc b/code/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3735aa6b4ec5cd68471afdec6d2068bc6293a2ed
Binary files /dev/null and b/code/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/datasets/__pycache__/camel_data.cpython-39.pyc b/code/datasets/__pycache__/camel_data.cpython-39.pyc
similarity index 100%
rename from datasets/__pycache__/camel_data.cpython-39.pyc
rename to code/datasets/__pycache__/camel_data.cpython-39.pyc
diff --git a/datasets/__pycache__/camel_dataloader.cpython-39.pyc b/code/datasets/__pycache__/camel_dataloader.cpython-39.pyc
similarity index 100%
rename from datasets/__pycache__/camel_dataloader.cpython-39.pyc
rename to code/datasets/__pycache__/camel_dataloader.cpython-39.pyc
diff --git a/datasets/__pycache__/custom_dataloader.cpython-39.pyc b/code/datasets/__pycache__/custom_dataloader.cpython-39.pyc
similarity index 100%
rename from datasets/__pycache__/custom_dataloader.cpython-39.pyc
rename to code/datasets/__pycache__/custom_dataloader.cpython-39.pyc
diff --git a/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cd6ef3a6a996cd23622c3e29f8ed89d4d0d9038
Binary files /dev/null and b/code/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc differ
diff --git a/datasets/__pycache__/data_interface.cpython-39.pyc b/code/datasets/__pycache__/data_interface.cpython-39.pyc
similarity index 100%
rename from datasets/__pycache__/data_interface.cpython-39.pyc
rename to code/datasets/__pycache__/data_interface.cpython-39.pyc
diff --git a/datasets/camel_data.py b/code/datasets/camel_data.py
similarity index 100%
rename from datasets/camel_data.py
rename to code/datasets/camel_data.py
diff --git a/datasets/camel_dataloader.py b/code/datasets/camel_dataloader.py
similarity index 100%
rename from datasets/camel_dataloader.py
rename to code/datasets/camel_dataloader.py
diff --git a/datasets/custom_dataloader.py b/code/datasets/custom_dataloader.py
similarity index 100%
rename from datasets/custom_dataloader.py
rename to code/datasets/custom_dataloader.py
diff --git a/datasets/custom_jpg_dataloader.py b/code/datasets/custom_jpg_dataloader.py
similarity index 87%
rename from datasets/custom_jpg_dataloader.py
rename to code/datasets/custom_jpg_dataloader.py
index 722b27593ee28ebbbcf42249cd9545bdde6ed85c..b47fc7fbe2b461a5e446106f3d8c5a9edb55ba4a 100644
--- a/datasets/custom_jpg_dataloader.py
+++ b/code/datasets/custom_jpg_dataloader.py
@@ -25,18 +25,8 @@ class RangeNormalization(object):
         return (img / 255.0 - 0.5) / 0.5
 
 class JPGMILDataloader(data.Dataset):
-    """Represents an abstract HDF5 dataset. For single H5 container! 
     
-    Input params:
-        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
-        mode: 'train' or 'test'
-        load_data: If True, loads all the data immediately into RAM. Use this if
-            the dataset is fits into memory. Otherwise, leave this at false and 
-            the data will load lazily.
-        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
-
-    """
-    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, bag_size=1024):
+    def __init__(self, file_path, label_path, mode, n_classes, load_data=False, data_cache_size=10, max_bag_size=1296):
         super().__init__()
 
         self.data_info = []
@@ -49,7 +39,8 @@ class JPGMILDataloader(data.Dataset):
         # self.csv_path = csv_path
         self.label_path = label_path
         self.n_classes = n_classes
-        self.bag_size = bag_size
+        self.max_bag_size = max_bag_size
+        self.min_bag_size = 120
         self.empty_slides = []
         # self.label_file = label_path
         recursive = True
@@ -64,7 +55,7 @@ class JPGMILDataloader(data.Dataset):
                 for cohort in Path(self.file_path).iterdir():
                     x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
                     if x_complete_path.is_dir():
-                        if len(list(x_complete_path.iterdir())) > 50:
+                        if len(list(x_complete_path.iterdir())) > self.min_bag_size:
                         # print(x_complete_path)
                             self.slideLabelDict[x] = y
                             self.files.append(x_complete_path)
@@ -87,7 +78,7 @@ class JPGMILDataloader(data.Dataset):
         sometimes5 = lambda aug: iaa.Sometimes(0.9, aug, name="Random5")
 
         self.train_transforms = iaa.Sequential([
-            iaa.AddToHueAndSaturation(value=(-13, 13), name="MyHSV"),
+            iaa.AddToHueAndSaturation(value=(-30, 30), name="MyHSV"), #13
             sometimes2(iaa.GammaContrast(gamma=(0.85, 1.15), name="MyGamma")),
             iaa.Fliplr(0.5, name="MyFlipLR"),
             iaa.Flipud(0.5, name="MyFlipUD"),
@@ -139,7 +130,7 @@ class JPGMILDataloader(data.Dataset):
 
     def __getitem__(self, index):
         # get data
-        batch, label, name = self.get_data(index)
+        (batch, batch_names), label, name = self.get_data(index)
         out_batch = []
         seq_img_d = self.train_transforms.to_deterministic()
         
@@ -174,7 +165,7 @@ class JPGMILDataloader(data.Dataset):
         label = torch.as_tensor(label)
         label = torch.nn.functional.one_hot(label, num_classes=self.n_classes)
         # print(out_batch)
-        return out_batch, label, name #, name_batch
+        return out_batch, label, (name, batch_names) #, name_batch
 
     def __len__(self):
         return len(self.data_info)
@@ -194,7 +185,7 @@ class JPGMILDataloader(data.Dataset):
         data_info structure.
         """
         wsi_batch = []
-        tile_names = []
+        name_batch = []
         # print(wsi_batch)
         # for tile_path in Path(file_path).iterdir():
         #     print(tile_path)
@@ -206,7 +197,7 @@ class JPGMILDataloader(data.Dataset):
             # print(wsi_batch)
             wsi_batch.append(img)
             
-            tile_names.append(tile_path.stem)
+            name_batch.append(tile_path.stem)
                 
         # if wsi_batch:
         wsi_batch = torch.stack(wsi_batch)
@@ -220,8 +211,9 @@ class JPGMILDataloader(data.Dataset):
         # if wsi_batch.shape[1:] != torch.Size([3, 256, 256]) and wsi_batch.shape[1:] != torch.Size([256, 256, 3]):
         #     print(file_path)
         #     print(wsi_batch.shape)
-        # wsi_batch, name_batch = to_fixed_size_bag(wsi_batch, self.bag_size)
-        idx = self._add_to_cache(wsi_batch, file_path)
+        if wsi_batch.size(0) > self.max_bag_size:
+            wsi_batch, name_batch, _ = to_fixed_size_bag(wsi_batch, name_batch, self.max_bag_size)
+        idx = self._add_to_cache((wsi_batch,name_batch), file_path)
         file_idx = next(i for i,v in enumerate(self.data_info) if v['data_path'] == file_path)
         self.data_info[file_idx + idx]['cache_idx'] = idx
 
@@ -309,25 +301,26 @@ class RandomHueSaturationValue(object):
             img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
         return img #, lbl
 
-def to_fixed_size_bag(bag, bag_size: int = 512):
+def to_fixed_size_bag(bag, names, bag_size: int = 512):
 
     # duplicate bag instances until bag_size is reached (duplication is disabled below)
 
     # get up to bag_size elements
     bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
     bag_samples = bag[bag_idxs]
+    name_samples = [names[i] for i in bag_idxs]
     # bag_sample_names = [bag_names[i] for i in bag_idxs]
-    q, r  = divmod(bag_size, bag_samples.shape[0])
-    if q > 0:
-        bag_samples = torch.cat([bag_samples]*q, 0)
+    # q, r  = divmod(bag_size, bag_samples.shape[0])
+    # if q > 0:
+    #     bag_samples = torch.cat([bag_samples]*q, 0)
     
-    self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
+    # self_padded = torch.cat([bag_samples, bag_samples[:r,:, :, :]])
 
     # zero-pad if we don't have enough samples
     # zero_padded = torch.cat((bag_samples,
                             # torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1], bag_samples.shape[2], bag_samples.shape[3])))
 
-    return self_padded, min(bag_size, len(bag))
+    return bag_samples, name_samples, min(bag_size, len(bag))
 
 
 class RandomHueSaturationValue(object):
@@ -365,23 +358,23 @@ if __name__ == '__main__':
 
     home = Path.cwd().parts[1]
     train_csv = f'/{home}/ylan/DeepGraft_project/code/debug_train.csv'
-    data_root = f'/{home}/ylan/data/DeepGraft/256_256um'
+    data_root = f'/{home}/ylan/data/DeepGraft/224_128um'
     # data_root = f'/{home}/ylan/DeepGraft/dataset/hdf5/256_256um_split/'
     # label_path = f'/{home}/ylan/DeepGraft_project/code/split_PAS_bin.json'
     label_path = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
-    output_path = f'/{home}/ylan/DeepGraft/dataset/check/256_256um_split/augments'
-    os.makedirs(output_path, exist_ok=True)
+    output_dir = f'{data_root}/debug/augments'
+    os.makedirs(output_dir, exist_ok=True)
 
     n_classes = 2
 
-    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes, bag_size=20)
+    dataset = JPGMILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
     # print(dataset.dataset)
     # a = int(len(dataset)* 0.8)
     # b = int(len(dataset) - a)
     # train_ds, val_ds = torch.utils.data.random_split(dataset, [a, b])
-    dl = DataLoader(dataset,  None, num_workers=1)
-    print(len(dl))
-    dl = DataLoader(dataset,  None, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
+    dl = DataLoader(dataset, None, num_workers=1, shuffle=False)
+    # print(len(dl))
+    # dl = DataLoader(dataset, batch_size=1, sampler=ImbalancedDatasetSampler(dataset), num_workers=5)
 
     
     
@@ -394,18 +387,16 @@ if __name__ == '__main__':
     label_count = [0] *n_classes
     print(len(dl))
     for item in dl: 
-        # if c >=10:
-        #     break
-        bag, label, name = item
-        # print(label)
+        if c >= 5:
+            break
+        bag, label, (name, _) = item
         label_count[torch.argmax(label)] += 1
         # print(name)
         # if name == 'RU0248_PASD_jke_PASD_20200201_195900_BIG':
         
             # print(bag)
             # print(label)
-        c += 1
-    print(label_count)
+        
     #     # # print(bag.shape)
     #     # if bag.shape[1] == 1:
     #     #     print(name)
@@ -423,21 +414,23 @@ if __name__ == '__main__':
     #     # bag = item[0]
     #     bag = bag.squeeze()
     #     original = original.squeeze()
-    #     for i in range(bag.shape[0]):
-    #         img = bag[i, :, :, :]
-    #         img = img.squeeze()
+        output_path = Path(output_dir) / name
+        output_path.mkdir(exist_ok=True)
+        for i in range(bag.shape[0]):
+            img = bag[i, :, :, :]
+            img = img.squeeze()
             
-    #         img = ((img-img.min())/(img.max() - img.min())) * 255
-    #         print(img)
-    #         # print(img)
-    #         img = img.numpy().astype(np.uint8).transpose(1,2,0)
+            img = ((img-img.min())/(img.max() - img.min())) * 255
+            # print(img)
+            # print(img)
+            img = img.numpy().astype(np.uint8).transpose(1,2,0)
 
             
-    #         img = Image.fromarray(img)
-    #         img = img.convert('RGB')
-    #         img.save(f'{output_path}/{i}.png')
-
+            img = Image.fromarray(img)
+            img = img.convert('RGB')
+            img.save(f'{output_path}/{i}.png')
 
+        c += 1
             
     #         o_img = original[i,:,:,:]
     #         o_img = o_img.squeeze()
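The main behavioral change in this dataloader is that to_fixed_size_bag no longer duplicates or zero-pads small bags; it only randomly subsamples bags that exceed max_bag_size and keeps the matching tile names. A self-contained sketch of that logic with toy tensors:

```python
# Sketch of the revised to_fixed_size_bag: random subsampling of oversized bags,
# with tile names kept in sync; smaller bags are passed through by the caller.
import torch

def to_fixed_size_bag(bag, names, bag_size=512):
    bag_idxs = torch.randperm(bag.shape[0])[:bag_size]
    bag_samples = bag[bag_idxs]
    name_samples = [names[i] for i in bag_idxs]
    return bag_samples, name_samples, min(bag_size, len(bag))

bag = torch.randn(1500, 3, 16, 16)            # toy stand-in for 1500 tiles
names = [f'tile_{i}' for i in range(1500)]
sub_bag, sub_names, n = to_fixed_size_bag(bag, names, bag_size=1296)
print(sub_bag.shape, len(sub_names), n)       # torch.Size([1296, 3, 16, 16]) 1296 1296
```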
diff --git a/code/datasets/dali_dataloader.py b/code/datasets/dali_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0851613c97ebc3705d324e03cb22ce6060e12315
--- /dev/null
+++ b/code/datasets/dali_dataloader.py
@@ -0,0 +1,139 @@
+from nvidia.dali import pipeline_def
+from nvidia.dali.pipeline import Pipeline
+import nvidia.dali.fn as fn
+import nvidia.dali.types as types
+
+from pathlib import Path
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import math
+import json
+import numpy as np
+import cupy as cp
+import torch
+import imageio
+
+batch_size = 10
+home = Path.cwd().parts[1]
+# image_filename = f"/{home}/ylan/data/DeepGraft/224_128um/Aachen_Biopsy_Slides/BLOCKS/"
+
+class ExternalInputIterator(object):
+    def __init__(self, batch_size):
+        self.file_path = f"/{home}/ylan/data/DeepGraft/224_128um/"
+        # self.label_file = f'/{home}/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+        self.label_path = f'/{home}/ylan/DeepGraft/training_tables/split_Aachen_PAS_tcmr_viral.json'
+        self.batch_size = batch_size
+        
+        mode = 'test'
+        # with open(self.images_dir + "file_list.txt", 'r') as f:
+        #     self.files = [line.rstrip() for line in f if line != '']
+        self.slideLabelDict = {}
+        self.files = []
+        self.empty_slides = []
+        with open(self.label_path, 'r') as f:
+            temp_slide_label_dict = json.load(f)[mode]
+            for (x, y) in temp_slide_label_dict:
+                x = Path(x).stem 
+
+                # x_complete_path = Path(self.file_path)/Path(x)
+                for cohort in Path(self.file_path).iterdir():
+                    x_complete_path = Path(self.file_path) / cohort / 'BLOCKS' / Path(x)
+                    if x_complete_path.is_dir():
+                        if len(list(x_complete_path.iterdir())) > 50:
+                        # print(x_complete_path)
+                            # self.slideLabelDict[x] = y
+                            self.files.append((x_complete_path, y))
+                        else: self.empty_slides.append(x_complete_path)
+        
+        # shuffle(self.files)
+
+    def __iter__(self):
+        self.i = 0
+        self.n = len(self.files)
+        return self
+
+    def __next__(self):
+        batch = []
+        labels = []
+        file_names = []
+
+        for _ in range(self.batch_size):
+            wsi_batch = []
+            wsi_filename, label = self.files[self.i]
+            # jpeg_filename, label = self.files[self.i].split(' ')
+            name = Path(wsi_filename).stem
+            file_names.append(name)
+            for img_path in Path(wsi_filename).iterdir():
+                # f = open(img, 'rb')
+
+                f = imageio.imread(img_path)
+                img = cp.asarray(f)
+                wsi_batch.append(img.astype(cp.uint8))
+            wsi_batch = cp.stack(wsi_batch)
+            batch.append(wsi_batch)
+            labels.append(cp.array([label], dtype = cp.uint8))
+            
+            self.i = (self.i + 1) % self.n
+        # print(batch)
+        # print(labels)
+        return (batch, labels)
+
+
+eii = ExternalInputIterator(batch_size=10)
+
+@pipeline_def()
+def hsv_pipeline(device, hue, saturation, value):
+    # files, labels = fn.readers.file(file_root=image_filename)
+    files, labels = fn.external_source(source=eii, num_outputs=2, dtype=types.UINT8)
+    images = fn.decoders.image(files, device = 'cpu' if device == 'cpu' else 'mixed')
+    converted = fn.hsv(images, hue=hue, saturation=saturation, value=value)
+    return images, converted
+
+def display(outputs, idx, columns=2, captions=None, cpu=True):
+    rows = int(math.ceil(len(outputs) / columns))
+    fig = plt.figure()
+    fig.set_size_inches(16, 6 * rows)
+    gs = gridspec.GridSpec(rows, columns)
+    row = 0
+    col = 0
+    for i, out in enumerate(outputs):
+        plt.subplot(gs[i])
+        plt.axis("off")
+        if captions is not None:
+            plt.title(captions[i])
+        plt.imshow(out.at(idx) if cpu else out.as_cpu().at(idx))
+
+# pipe_cpu = hsv_pipeline(device='cpu', hue=120, saturation=1, value=0.4, batch_size=batch_size, num_threads=1, device_id=0)
+# pipe_cpu.build()
+# cpu_output = pipe_cpu.run()
+
+# display(cpu_output, 3, captions=["Original", "Hue=120, Saturation=1, Value=0.4"])
+
+# pipe_gpu = hsv_pipeline(device='gpu', hue=120, saturation=2, value=1, batch_size=batch_size, num_threads=1, device_id=0)
+# pipe_gpu.build()
+# gpu_output = pipe_gpu.run()
+
+# display(gpu_output, 0, cpu=False, captions=["Original", "Hue=120, Saturation=2, Value=1"])
+
+
+pipe_gpu = Pipeline(batch_size=batch_size, num_threads=2, device_id=0)
+with pipe_gpu:
+    images, labels = fn.external_source(source=eii, num_outputs=2, device="gpu", dtype=types.UINT8)
+    enhance = fn.brightness_contrast(images, contrast=2)
+    pipe_gpu.set_outputs(enhance, labels)
+
+pipe_gpu.build()
+pipe_out_gpu = pipe_gpu.run()
+batch_gpu = pipe_out_gpu[0].as_cpu()
+labels_gpu = pipe_out_gpu[1].as_cpu()
+# file_names = pipe_out_gpu[2].as_cpu()
+
+output_path = Path(f"/{home}/ylan/data/DeepGraft/224_128um/debug/augments/")
+output_path.mkdir(parents=True, exist_ok=True)  # str has no mkdir(); wrap in Path and create parent dirs
+
+
+img = batch_gpu.at(2)
+print(img.shape)
+print(labels_gpu.at(2))
+plt.axis('off')
+plt.imsave(f'{output_path}/0.jpg', img[0, :, :, :])
\ No newline at end of file
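The new DALI script above builds the GPU pipeline and runs it once; a hypothetical continuation (not in the commit) showing how further augmented batches could be pulled from the same pipeline:

```python
# Hypothetical continuation: run the already-built pipe_gpu a few more times
# and inspect the augmented output after copying it back to the CPU.
for _ in range(3):
    enhanced, labels = pipe_gpu.run()
    batch_cpu = enhanced.as_cpu()
    print(batch_cpu.at(0).shape, labels.as_cpu().at(0))
```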
diff --git a/datasets/data_interface.py b/code/datasets/data_interface.py
similarity index 100%
rename from datasets/data_interface.py
rename to code/datasets/data_interface.py
diff --git a/code/lightning_logs/version_0/cm_test.png b/code/lightning_logs/version_0/cm_test.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fb672f0640e83968af9558a7d438cbf42b51815
Binary files /dev/null and b/code/lightning_logs/version_0/cm_test.png differ
diff --git a/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0 b/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0
new file mode 100644
index 0000000000000000000000000000000000000000..734f0a9372e6608b0ab5aed218222431f0044e31
Binary files /dev/null and b/code/lightning_logs/version_0/events.out.tfevents.1657535217.dgx2.2080039.0 differ
diff --git a/code/lightning_logs/version_0/hparams.yaml b/code/lightning_logs/version_0/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..de11b2861a631025df79cacf8d8a13415bbe1769
--- /dev/null
+++ b/code/lightning_logs/version_0/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/256_256um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/256_256um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_1/cm_test.png b/code/lightning_logs/version_1/cm_test.png
new file mode 100644
index 0000000000000000000000000000000000000000..e2ba8d1036ec9fa3e57078dd47b9a747a2809fc0
Binary files /dev/null and b/code/lightning_logs/version_1/cm_test.png differ
diff --git a/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0 b/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0
new file mode 100644
index 0000000000000000000000000000000000000000..b11a9e985744cff77e2d36a3e726bb4b8b9ac89c
Binary files /dev/null and b/code/lightning_logs/version_1/events.out.tfevents.1657535625.dgx2.2086189.0 differ
diff --git a/code/lightning_logs/version_1/hparams.yaml b/code/lightning_logs/version_1/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3027219ffc70ee8a4f037c88505b459131ec11e4
--- /dev/null
+++ b/code/lightning_logs/version_1/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/256_256um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/256_256um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0 b/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0
new file mode 100644
index 0000000000000000000000000000000000000000..afa287ceab85beb3b739a44d7e230d7ada1d726f
Binary files /dev/null and b/code/lightning_logs/version_10/events.out.tfevents.1657546166.dgx1.47613.0 differ
diff --git a/code/lightning_logs/version_10/hparams.yaml b/code/lightning_logs/version_10/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_10/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0 b/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0
new file mode 100644
index 0000000000000000000000000000000000000000..328e3163a4bab058ba6a473d8d58e58f2fb51ac5
Binary files /dev/null and b/code/lightning_logs/version_11/events.out.tfevents.1657546322.dgx1.48740.0 differ
diff --git a/code/lightning_logs/version_11/hparams.yaml b/code/lightning_logs/version_11/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_11/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0 b/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0
new file mode 100644
index 0000000000000000000000000000000000000000..692f51dfebf86309232688295f7c9f36fc3096a5
Binary files /dev/null and b/code/lightning_logs/version_12/events.out.tfevents.1657546521.dgx1.50053.0 differ
diff --git a/code/lightning_logs/version_12/hparams.yaml b/code/lightning_logs/version_12/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_12/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0 b/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0
new file mode 100644
index 0000000000000000000000000000000000000000..db43abd93eaa392fca46ac8cac0cf6ea91076b76
Binary files /dev/null and b/code/lightning_logs/version_13/events.out.tfevents.1657546918.dgx1.52290.0 differ
diff --git a/code/lightning_logs/version_13/hparams.yaml b/code/lightning_logs/version_13/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_13/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0 b/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0
new file mode 100644
index 0000000000000000000000000000000000000000..3ac295e5023bf3d8ac7d7a060eb0aa861cfe9c80
Binary files /dev/null and b/code/lightning_logs/version_14/events.out.tfevents.1657546992.dgx1.53435.0 differ
diff --git a/code/lightning_logs/version_14/hparams.yaml b/code/lightning_logs/version_14/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_14/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0 b/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0
new file mode 100644
index 0000000000000000000000000000000000000000..797ee6e3689ccb965f4e97b808281e986ef758b0
Binary files /dev/null and b/code/lightning_logs/version_15/events.out.tfevents.1657547134.dgx1.54703.0 differ
diff --git a/code/lightning_logs/version_15/hparams.yaml b/code/lightning_logs/version_15/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_15/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0 b/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0
new file mode 100644
index 0000000000000000000000000000000000000000..b4cf4197bf688807ead868952cb4b66fdee7b451
Binary files /dev/null and b/code/lightning_logs/version_16/events.out.tfevents.1657547198.dgx1.55641.0 differ
diff --git a/code/lightning_logs/version_16/hparams.yaml b/code/lightning_logs/version_16/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_16/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
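
Note: the hparams.yaml files added in this diff are the configuration dumps written alongside the TensorBoard event files above. Because the run configuration is held in addict.Dict objects (and the log path in a pathlib.PosixPath), PyYAML serializes them with !!python/object tags rather than plain mappings, so yaml.safe_load will refuse to read them back. A minimal sketch for inspecting one of these files, assuming PyYAML >= 5.1 and addict are installed and the path is taken relative to the repository root:

    import yaml  # PyYAML

    # one of the files introduced above
    with open("code/lightning_logs/version_16/hparams.yaml") as f:
        # the python/object tags need the unsafe loader; only use it on trusted files
        hparams = yaml.unsafe_load(f)

    cfg = hparams["cfg"]                 # addict.Dict, usable like a plain dict
    print(cfg["Model"]["backbone"])      # resnet50
    print(cfg["Optimizer"]["lr"])        # 0.0002
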
diff --git a/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0 b/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0
new file mode 100644
index 0000000000000000000000000000000000000000..8236270b84bc8d5cbd60a4f1b08f8752fb734a6a
Binary files /dev/null and b/code/lightning_logs/version_17/events.out.tfevents.1657623153.dgx1.41577.0 differ
diff --git a/code/lightning_logs/version_17/hparams.yaml b/code/lightning_logs/version_17/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..399847bb2e810ee8b9dde0c0723e1615ac8d92dc
--- /dev/null
+++ b/code/lightning_logs/version_17/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 7
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
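
The log and log_path entries in these dumps use the !!python/object/apply:pathlib.PosixPath tag, which simply calls PosixPath with the listed path segments when the file is loaded. A short illustration of the same reconstruction done by hand (POSIX systems only; the segments are copied from the dump above):

    from pathlib import PosixPath

    # equivalent to what yaml's python/object/apply tag does for the "log" entry
    log = PosixPath("/", "home", "ylan", "workspace", "TransMIL-DeepGraft", "logs",
                    "DeepGraft", "TransMIL", "tcmr_viral", "_resnet50_CrossEntropyLoss")
    print(log)
    # /home/ylan/workspace/TransMIL-DeepGraft/logs/DeepGraft/TransMIL/tcmr_viral/_resnet50_CrossEntropyLoss
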
diff --git a/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0 b/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0
new file mode 100644
index 0000000000000000000000000000000000000000..e2c425bbcbeaabdd74a5233ee4679ec8c0b467cf
Binary files /dev/null and b/code/lightning_logs/version_18/events.out.tfevents.1657624768.dgx1.72352.0 differ
diff --git a/code/lightning_logs/version_18/hparams.yaml b/code/lightning_logs/version_18/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5d3068796e8af6c6aacc06462015239f67239118
--- /dev/null
+++ b/code/lightning_logs/version_18/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0 b/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0
new file mode 100644
index 0000000000000000000000000000000000000000..da36871f7274cacd3a80d00f386c7f1036749c2c
Binary files /dev/null and b/code/lightning_logs/version_19/events.out.tfevents.1657624869.dgx1.76706.0 differ
diff --git a/code/lightning_logs/version_19/hparams.yaml b/code/lightning_logs/version_19/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5d3068796e8af6c6aacc06462015239f67239118
--- /dev/null
+++ b/code/lightning_logs/version_19/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0 b/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0
new file mode 100644
index 0000000000000000000000000000000000000000..0e8f866c6b4bb783b058864f37241e2a7966e66a
Binary files /dev/null and b/code/lightning_logs/version_2/events.out.tfevents.1657543650.dgx1.33429.0 differ
diff --git a/code/lightning_logs/version_2/hparams.yaml b/code/lightning_logs/version_2/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3978f85d8fc605d6a8b9c42d1590620769fbdf5d
--- /dev/null
+++ b/code/lightning_logs/version_2/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 0
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/256_256um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/256_256um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
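
The &idNNN / *idNNN markers that recur throughout these dumps are ordinary YAML anchors and aliases: each shared object is emitted once and referenced everywhere else, which is why a Dict's args, dictitems and state entries all point at the same node. A self-contained sketch of that sharing, assuming only PyYAML (the keys below are invented for the example, not taken from the config):

    import yaml

    doc = """
    loader: &dl {batch_size: 1, num_workers: 8}
    train_dataloader: *dl
    test_dataloader: *dl
    """
    cfg = yaml.safe_load(doc)
    # an alias is a reference to the anchored node, not a copy
    assert cfg["train_dataloader"] is cfg["test_dataloader"]
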
diff --git a/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0 b/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0
new file mode 100644
index 0000000000000000000000000000000000000000..30bdf9ae84865abecf1bbf3fe104f84a7e9f681e
Binary files /dev/null and b/code/lightning_logs/version_20/events.out.tfevents.1657625133.dgx1.5999.0 differ
diff --git a/code/lightning_logs/version_20/hparams.yaml b/code/lightning_logs/version_20/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_20/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0 b/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0
new file mode 100644
index 0000000000000000000000000000000000000000..580ae5449c6c891452ad6bb188dade7852059ad3
Binary files /dev/null and b/code/lightning_logs/version_21/events.out.tfevents.1657625248.dgx1.11114.0 differ
diff --git a/code/lightning_logs/version_21/hparams.yaml b/code/lightning_logs/version_21/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_21/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
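
The hparams.yaml dumps above are written by Lightning when the module saves its hyperparameters; because they serialize addict.Dict and pathlib.PosixPath objects with `!!python/object` tags, `yaml.safe_load` rejects them. A minimal sketch for inspecting such a file, assuming PyYAML and addict are importable in the same environment (the path below is illustrative):

```python
import yaml  # PyYAML

# The file uses !!python/object tags (addict.addict.Dict, pathlib.PosixPath),
# so the safe loader refuses it; the unsafe loader reconstructs the objects.
path = "code/lightning_logs/version_21/hparams.yaml"  # illustrative
with open(path) as f:
    hparams = yaml.unsafe_load(f)   # i.e. yaml.load(f, Loader=yaml.UnsafeLoader)

cfg = hparams["cfg"]                       # an addict.Dict, attribute access works
print(cfg.Model.name, cfg.Model.backbone)  # TransMIL resnet50
print(cfg.Optimizer.lr, cfg.Optimizer.weight_decay)
print(hparams["log"])                      # a pathlib.PosixPath
```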
diff --git a/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0 b/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0
new file mode 100644
index 0000000000000000000000000000000000000000..259e2f747c61500222f0b1faf00b2d6b3a41f41e
Binary files /dev/null and b/code/lightning_logs/version_22/events.out.tfevents.1657625470.dgx1.20071.0 differ
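
The `events.out.tfevents.*` files added here are binary TensorBoard logs; their scalars can be read back with TensorBoard's `EventAccumulator`. A short sketch, assuming the `tensorboard` package is installed; the scalar tag name used at the end is hypothetical, so pick one from the printed tag list:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

run_dir = "code/lightning_logs/version_22"  # point at the run directory
acc = EventAccumulator(run_dir)
acc.Reload()                                # parses the events.out.tfevents.* file(s)

print(acc.Tags()["scalars"])                # which scalar tags were logged
for event in acc.Scalars("train_loss"):     # hypothetical tag name
    print(event.step, event.value)
```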
diff --git a/code/lightning_logs/version_22/hparams.yaml b/code/lightning_logs/version_22/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_22/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
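
Note that the `index` lines for version_21 and version_22 (and the versions that follow) all show the same blob hash (d8ca6a09…), i.e. the trainer wrote byte-identical hparams.yaml files for each of these runs. A quick local check of that, hashing every copy under the layout added by this diff:

```python
import hashlib
from pathlib import Path

base = Path("code/lightning_logs")
digests = {
    p.parent.name: hashlib.sha256(p.read_bytes()).hexdigest()
    for p in sorted(base.glob("version_*/hparams.yaml"))
}
for version, digest in digests.items():
    print(f"{version}: {digest[:12]}")
print("all identical:", len(set(digests.values())) == 1)
```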
diff --git a/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0 b/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0
new file mode 100644
index 0000000000000000000000000000000000000000..c74c353abdb31d4bab1671520ec1793298923045
Binary files /dev/null and b/code/lightning_logs/version_23/events.out.tfevents.1657625510.dgx1.22295.0 differ
diff --git a/code/lightning_logs/version_23/hparams.yaml b/code/lightning_logs/version_23/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_23/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
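
The Optimizer block in these dumps (`opt: lookahead_radam`, `lr: 0.0002`, `weight_decay: 0.01`) names a RAdam optimizer wrapped in Lookahead. How this repository actually builds it is not part of this hunk; the sketch below is one plausible reading using PyTorch's RAdam plus timm's Lookahead wrapper (the use of timm, and the helper name, are assumptions):

```python
import torch
from timm.optim.lookahead import Lookahead  # assumed source of the Lookahead wrapper

def build_optimizer(model, lr=2e-4, weight_decay=0.01):
    # "lookahead_radam": RAdam base optimizer wrapped by Lookahead,
    # with lr/weight_decay taken from the Optimizer block above.
    base = torch.optim.RAdam(model.parameters(), lr=lr, weight_decay=weight_decay)
    return Lookahead(base)  # Lookahead's alpha/k stay at library defaults
```

`opt_eps`, `opt_betas` and `momentum` are all null in the dump, so they are left at the optimizer defaults here as well.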
diff --git a/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0 b/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0
new file mode 100644
index 0000000000000000000000000000000000000000..d23806a37c9d31bcee1beeafa2689732bbe3ab9c
Binary files /dev/null and b/code/lightning_logs/version_24/events.out.tfevents.1657625570.dgx1.25099.0 differ
diff --git a/code/lightning_logs/version_24/hparams.yaml b/code/lightning_logs/version_24/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_24/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
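
The dumped `log` path and `cfg.log_path` make the run-directory layout explicit: `General.log_path`, then the project folder, `Model.name`, the task name, an `_<backbone>_<loss>` suffix, and finally Lightning's `lightning_logs/version_N`. A sketch that rebuilds the path from a loaded cfg; the join order is inferred from the dumped paths, not from the training code, and the task string is taken literally from them:

```python
from pathlib import Path

def run_dir(cfg, task="tcmr_viral", version=4):
    # -> .../logs/DeepGraft/TransMIL/tcmr_viral/_resnet50_CrossEntropyLoss/
    #    lightning_logs/version_4  (as dumped in the hparams above)
    root = Path(cfg["General"]["log_path"])
    suffix = f"_{cfg['Model']['backbone']}_{cfg['Loss']['base_loss']}"
    return (root / "DeepGraft" / cfg["Model"]["name"] / task
            / suffix / "lightning_logs" / f"version_{version}")
```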
diff --git a/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0 b/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0
new file mode 100644
index 0000000000000000000000000000000000000000..a51495644c8c24447f5a9d0bce559fe1c0f7e857
Binary files /dev/null and b/code/lightning_logs/version_25/events.out.tfevents.1657625613.dgx1.27343.0 differ
diff --git a/code/lightning_logs/version_25/hparams.yaml b/code/lightning_logs/version_25/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_25/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
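
Besides the config sections, each dump records `config` (the YAML it was launched from), `version: 4` and `epoch: '159'` — apparently the training run and checkpoint that these `server: test` runs point back to, with the epoch stored as a string. A small helper sketch that pulls those fields out, reusing the unsafe-load approach from the earlier sketch:

```python
import yaml
from pathlib import Path

def tested_checkpoint(hparams_path):
    hp = yaml.unsafe_load(Path(hparams_path).read_text())
    cfg = hp["cfg"]
    return {
        "config": cfg["config"],         # ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
        "version": int(cfg["version"]),  # 4
        "epoch": int(cfg["epoch"]),      # '159' -> 159
        "log_dir": cfg["log_path"],      # PosixPath .../lightning_logs/version_4
    }
```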
diff --git a/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0 b/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0
new file mode 100644
index 0000000000000000000000000000000000000000..19376c22a488f8e753dbdbff5f0e810aa1ae4057
Binary files /dev/null and b/code/lightning_logs/version_26/events.out.tfevents.1657625763.dgx1.33397.0 differ
diff --git a/code/lightning_logs/version_26/hparams.yaml b/code/lightning_logs/version_26/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_26/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
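
The Data block pairs single-slide bags (`batch_size: 1`, `bag_size: 1024`) with 8 worker processes for both splits; presumably the dataset itself caps each bag at 1024 tile features. The dataset classes are not part of this diff, so the sketch below only shows how such a block would typically be wired into torch DataLoaders (function and argument names are hypothetical):

```python
from torch.utils.data import DataLoader

def build_dataloaders(train_set, test_set, data_cfg):
    train = DataLoader(
        train_set,
        batch_size=data_cfg["train_dataloader"]["batch_size"],  # 1 bag (slide) per step
        num_workers=data_cfg["train_dataloader"]["num_workers"],
        shuffle=bool(data_cfg["data_shuffle"]),                 # false in this config
    )
    test = DataLoader(
        test_set,
        batch_size=data_cfg["test_dataloader"]["batch_size"],
        num_workers=data_cfg["test_dataloader"]["num_workers"],
        shuffle=False,
    )
    return train, test
```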
diff --git a/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0 b/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0
new file mode 100644
index 0000000000000000000000000000000000000000..63e1dd8a5203e7fb1ff819e29310446d3cd72067
Binary files /dev/null and b/code/lightning_logs/version_27/events.out.tfevents.1657625819.dgx1.36236.0 differ
diff --git a/code/lightning_logs/version_27/hparams.yaml b/code/lightning_logs/version_27/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_27/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0 b/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0
new file mode 100644
index 0000000000000000000000000000000000000000..6b6c0ae6852731226a72c9746e71c67c5b96361b
Binary files /dev/null and b/code/lightning_logs/version_28/events.out.tfevents.1657625859.dgx1.38333.0 differ
diff --git a/code/lightning_logs/version_28/hparams.yaml b/code/lightning_logs/version_28/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_28/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
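Note on the lightning_logs/*/hparams.yaml files added in this diff: they are the hyperparameter dumps PyTorch Lightning writes next to each run, and because the configs are addict.Dict instances and pathlib paths, the dump carries Python-specific tags (!!python/object/new:addict.addict.Dict, !!python/object/apply:pathlib.PosixPath) that yaml.safe_load() refuses to parse. Below is a minimal sketch of reading one back, assuming PyYAML >= 5.1 and addict are importable; the version_28 path is only an illustrative choice, not something this diff prescribes.

    # Hypothetical example: load one of the auto-generated hparams.yaml files back into Python.
    # The !!python/object/* tags require an unsafe loader and the addict package on the path.
    import yaml  # PyYAML

    with open("code/lightning_logs/version_28/hparams.yaml") as f:
        hparams = yaml.unsafe_load(f)  # resolves the python/object tags into live objects

    model_cfg = hparams["model"]                 # an addict.Dict, so attribute access works
    print(model_cfg.name, model_cfg.backbone)    # TransMIL resnet50
    print(hparams["optimizer"]["lr"])            # 0.0002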
diff --git a/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0 b/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0
new file mode 100644
index 0000000000000000000000000000000000000000..8783be4dc429493cadb5cc21bca2dc1a24e98d8f
Binary files /dev/null and b/code/lightning_logs/version_29/events.out.tfevents.1657625903.dgx1.40628.0 differ
diff --git a/code/lightning_logs/version_29/hparams.yaml b/code/lightning_logs/version_29/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_29/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_3/cm_test.png b/code/lightning_logs/version_3/cm_test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a523c98c4bd81c8e446deba57c6e24579f4911a
Binary files /dev/null and b/code/lightning_logs/version_3/cm_test.png differ
diff --git a/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0 b/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0
new file mode 100644
index 0000000000000000000000000000000000000000..0726f25de4e54bfa2acdf91a638a4d86787f60ac
Binary files /dev/null and b/code/lightning_logs/version_3/events.out.tfevents.1657543830.dgx1.34643.0 differ
diff --git a/code/lightning_logs/version_3/hparams.yaml b/code/lightning_logs/version_3/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a662047b861838dd7c0c87b5d45111c35a9a254d
--- /dev/null
+++ b/code/lightning_logs/version_3/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 0
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0 b/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0
new file mode 100644
index 0000000000000000000000000000000000000000..5a9ead248d726804faeead94babbb17b53639f97
Binary files /dev/null and b/code/lightning_logs/version_30/events.out.tfevents.1657625960.dgx1.43463.0 differ
diff --git a/code/lightning_logs/version_30/hparams.yaml b/code/lightning_logs/version_30/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_30/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0 b/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0
new file mode 100644
index 0000000000000000000000000000000000000000..125ad0c8b29dcda71447e15108aa5dc4d5d085aa
Binary files /dev/null and b/code/lightning_logs/version_31/events.out.tfevents.1657626120.dgx1.48909.0 differ
diff --git a/code/lightning_logs/version_31/hparams.yaml b/code/lightning_logs/version_31/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_31/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0 b/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0
new file mode 100644
index 0000000000000000000000000000000000000000..32f354dc772b3edd2b278e47b884bfb45af1ef7e
Binary files /dev/null and b/code/lightning_logs/version_32/events.out.tfevents.1657626628.dgx1.65460.0 differ
diff --git a/code/lightning_logs/version_32/hparams.yaml b/code/lightning_logs/version_32/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_32/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0 b/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0
new file mode 100644
index 0000000000000000000000000000000000000000..1c2bc9d73a885ec5d45e076338de61fb63e3475d
Binary files /dev/null and b/code/lightning_logs/version_33/events.out.tfevents.1657626794.dgx1.71384.0 differ
diff --git a/code/lightning_logs/version_33/hparams.yaml b/code/lightning_logs/version_33/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8ca6a0997effb2e3843c368c93978c2621870cb
--- /dev/null
+++ b/code/lightning_logs/version_33/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - test
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: test
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0 b/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0
new file mode 100644
index 0000000000000000000000000000000000000000..04c542e7c7c8be7ebf3a1c040e54855ccc321392
Binary files /dev/null and b/code/lightning_logs/version_4/events.out.tfevents.1657545281.dgx1.41143.0 differ
diff --git a/code/lightning_logs/version_4/hparams.yaml b/code/lightning_logs/version_4/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a662047b861838dd7c0c87b5d45111c35a9a254d
--- /dev/null
+++ b/code/lightning_logs/version_4/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 0
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
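Because of that serialization, the hparams.yaml files added here cannot be read back with yaml.safe_load. A minimal read-back sketch, assuming PyYAML >= 5.1 and an importable addict package (the embedded !!python/object tags are resolved only by the unsafe loader, which is acceptable for these self-generated files):

    import yaml

    # Hedged sketch: load the version_4 hparams file added above back into Python objects.
    with open("code/lightning_logs/version_4/hparams.yaml") as f:
        hparams = yaml.unsafe_load(f)

    cfg = hparams["cfg"]                                             # addict.Dict mirroring the training config
    print(cfg["Model"]["backbone"])                                  # resnet50
    print(cfg["Optimizer"]["lr"], cfg["Optimizer"]["weight_decay"])  # 0.0002 0.01
    print(hparams["log"])                                            # pathlib.PosixPath of the log directory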
diff --git a/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0 b/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0
new file mode 100644
index 0000000000000000000000000000000000000000..de560d26010186b57ca59ac3412d2603077d9c84
Binary files /dev/null and b/code/lightning_logs/version_5/events.out.tfevents.1657545407.dgx1.42159.0 differ
diff --git a/code/lightning_logs/version_5/hparams.yaml b/code/lightning_logs/version_5/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_5/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0 b/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0
new file mode 100644
index 0000000000000000000000000000000000000000..b96b3cc792298ad93aea695082529868f03f3aa0
Binary files /dev/null and b/code/lightning_logs/version_6/events.out.tfevents.1657545574.dgx1.43290.0 differ
diff --git a/code/lightning_logs/version_6/hparams.yaml b/code/lightning_logs/version_6/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_6/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0 b/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0
new file mode 100644
index 0000000000000000000000000000000000000000..7094035f824ab5b403ca76dc451d320b83f282b6
Binary files /dev/null and b/code/lightning_logs/version_7/events.out.tfevents.1657545679.dgx1.44154.0 differ
diff --git a/code/lightning_logs/version_7/hparams.yaml b/code/lightning_logs/version_7/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_7/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0 b/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0
new file mode 100644
index 0000000000000000000000000000000000000000..feb362684898989fdf7cfbdf3ce96d9eeef16d25
Binary files /dev/null and b/code/lightning_logs/version_8/events.out.tfevents.1657545837.dgx1.45339.0 differ
diff --git a/code/lightning_logs/version_8/hparams.yaml b/code/lightning_logs/version_8/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_8/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0 b/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0
new file mode 100644
index 0000000000000000000000000000000000000000..1f6b05a4e88ebf59fe40d3af43cf85a4991a21c0
Binary files /dev/null and b/code/lightning_logs/version_9/events.out.tfevents.1657545985.dgx1.46385.0 differ
diff --git a/code/lightning_logs/version_9/hparams.yaml b/code/lightning_logs/version_9/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..385e52b4b6623ae9f0b42aa6e66bf1642b66bf4f
--- /dev/null
+++ b/code/lightning_logs/version_9/hparams.yaml
@@ -0,0 +1,368 @@
+backbone: resnet50
+cfg: &id010 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - General
+    - &id002 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - comment
+        - null
+      - !!python/tuple
+        - seed
+        - 2021
+      - !!python/tuple
+        - fp16
+        - true
+      - !!python/tuple
+        - amp_level
+        - O2
+      - !!python/tuple
+        - precision
+        - 16
+      - !!python/tuple
+        - multi_gpu_mode
+        - dp
+      - !!python/tuple
+        - gpus
+        - &id001
+          - 1
+      - !!python/tuple
+        - epochs
+        - 200
+      - !!python/tuple
+        - grad_acc
+        - 2
+      - !!python/tuple
+        - frozen_bn
+        - false
+      - !!python/tuple
+        - patience
+        - 50
+      - !!python/tuple
+        - server
+        - train
+      - !!python/tuple
+        - log_path
+        - /home/ylan/workspace/TransMIL-DeepGraft/logs/
+      dictitems:
+        amp_level: O2
+        comment: null
+        epochs: 200
+        fp16: true
+        frozen_bn: false
+        gpus: *id001
+        grad_acc: 2
+        log_path: /home/ylan/workspace/TransMIL-DeepGraft/logs/
+        multi_gpu_mode: dp
+        patience: 50
+        precision: 16
+        seed: 2021
+        server: train
+      state: *id002
+  - !!python/tuple
+    - Data
+    - &id005 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - dataset_name
+        - custom
+      - !!python/tuple
+        - data_shuffle
+        - false
+      - !!python/tuple
+        - data_dir
+        - /home/ylan/data/DeepGraft/224_128um/
+      - !!python/tuple
+        - label_file
+        - /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+      - !!python/tuple
+        - fold
+        - 0
+      - !!python/tuple
+        - nfold
+        - 3
+      - !!python/tuple
+        - cross_val
+        - false
+      - !!python/tuple
+        - train_dataloader
+        - &id003 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id003
+      - !!python/tuple
+        - test_dataloader
+        - &id004 !!python/object/new:addict.addict.Dict
+          args:
+          - !!python/tuple
+            - batch_size
+            - 1
+          - !!python/tuple
+            - num_workers
+            - 8
+          dictitems:
+            batch_size: 1
+            num_workers: 8
+          state: *id004
+      - !!python/tuple
+        - bag_size
+        - 1024
+      dictitems:
+        bag_size: 1024
+        cross_val: false
+        data_dir: /home/ylan/data/DeepGraft/224_128um/
+        data_shuffle: false
+        dataset_name: custom
+        fold: 0
+        label_file: /home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral_test.json
+        nfold: 3
+        test_dataloader: *id004
+        train_dataloader: *id003
+      state: *id005
+  - !!python/tuple
+    - Model
+    - &id006 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - name
+        - TransMIL
+      - !!python/tuple
+        - n_classes
+        - 2
+      - !!python/tuple
+        - backbone
+        - resnet50
+      - !!python/tuple
+        - in_features
+        - 512
+      - !!python/tuple
+        - out_features
+        - 512
+      dictitems:
+        backbone: resnet50
+        in_features: 512
+        n_classes: 2
+        name: TransMIL
+        out_features: 512
+      state: *id006
+  - !!python/tuple
+    - Optimizer
+    - &id007 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - opt
+        - lookahead_radam
+      - !!python/tuple
+        - lr
+        - 0.0002
+      - !!python/tuple
+        - opt_eps
+        - null
+      - !!python/tuple
+        - opt_betas
+        - null
+      - !!python/tuple
+        - momentum
+        - null
+      - !!python/tuple
+        - weight_decay
+        - 0.01
+      dictitems:
+        lr: 0.0002
+        momentum: null
+        opt: lookahead_radam
+        opt_betas: null
+        opt_eps: null
+        weight_decay: 0.01
+      state: *id007
+  - !!python/tuple
+    - Loss
+    - &id008 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - base_loss
+        - CrossEntropyLoss
+      dictitems:
+        base_loss: CrossEntropyLoss
+      state: *id008
+  - !!python/tuple
+    - config
+    - ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+  - !!python/tuple
+    - version
+    - 4
+  - !!python/tuple
+    - epoch
+    - '159'
+  - !!python/tuple
+    - log_path
+    - &id009 !!python/object/apply:pathlib.PosixPath
+      - /
+      - home
+      - ylan
+      - workspace
+      - TransMIL-DeepGraft
+      - logs
+      - DeepGraft
+      - TransMIL
+      - tcmr_viral
+      - _resnet50_CrossEntropyLoss
+      - lightning_logs
+      - version_4
+  dictitems:
+    Data: *id005
+    General: *id002
+    Loss: *id008
+    Model: *id006
+    Optimizer: *id007
+    config: ../DeepGraft/TransMIL_resnet50_tcmr_viral.yaml
+    epoch: '159'
+    log_path: *id009
+    version: 4
+  state: *id010
+data: &id013 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - dataset_name
+    - custom
+  - !!python/tuple
+    - data_shuffle
+    - false
+  - !!python/tuple
+    - data_dir
+    - /home/ylan/data/DeepGraft/224_128um/
+  - !!python/tuple
+    - label_file
+    - /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+  - !!python/tuple
+    - fold
+    - 0
+  - !!python/tuple
+    - nfold
+    - 3
+  - !!python/tuple
+    - cross_val
+    - false
+  - !!python/tuple
+    - train_dataloader
+    - &id011 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id011
+  - !!python/tuple
+    - test_dataloader
+    - &id012 !!python/object/new:addict.addict.Dict
+      args:
+      - !!python/tuple
+        - batch_size
+        - 1
+      - !!python/tuple
+        - num_workers
+        - 8
+      dictitems:
+        batch_size: 1
+        num_workers: 8
+      state: *id012
+  - !!python/tuple
+    - bag_size
+    - 1024
+  dictitems:
+    bag_size: 1024
+    cross_val: false
+    data_dir: /home/ylan/data/DeepGraft/224_128um/
+    data_shuffle: false
+    dataset_name: custom
+    fold: 0
+    label_file: /home/ylan/DeepGraft/training_tables/split_bt_PAS_tcmr_viral.json
+    nfold: 3
+    test_dataloader: *id012
+    train_dataloader: *id011
+  state: *id013
+log: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- ylan
+- workspace
+- TransMIL-DeepGraft
+- logs
+- DeepGraft
+- TransMIL
+- tcmr_viral
+- _resnet50_CrossEntropyLoss
+loss: &id014 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - base_loss
+    - CrossEntropyLoss
+  dictitems:
+    base_loss: CrossEntropyLoss
+  state: *id014
+model: &id015 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - name
+    - TransMIL
+  - !!python/tuple
+    - n_classes
+    - 2
+  - !!python/tuple
+    - backbone
+    - resnet50
+  - !!python/tuple
+    - in_features
+    - 512
+  - !!python/tuple
+    - out_features
+    - 512
+  dictitems:
+    backbone: resnet50
+    in_features: 512
+    n_classes: 2
+    name: TransMIL
+    out_features: 512
+  state: *id015
+optimizer: &id016 !!python/object/new:addict.addict.Dict
+  args:
+  - !!python/tuple
+    - opt
+    - lookahead_radam
+  - !!python/tuple
+    - lr
+    - 0.0002
+  - !!python/tuple
+    - opt_eps
+    - null
+  - !!python/tuple
+    - opt_betas
+    - null
+  - !!python/tuple
+    - momentum
+    - null
+  - !!python/tuple
+    - weight_decay
+    - 0.01
+  dictitems:
+    lr: 0.0002
+    momentum: null
+    opt: lookahead_radam
+    opt_betas: null
+    opt_eps: null
+    weight_decay: 0.01
+  state: *id016
diff --git a/models/AttMIL.py b/code/models/AttMIL.py
similarity index 98%
rename from models/AttMIL.py
rename to code/models/AttMIL.py
index d1e20eb3e5d465eab0008c5121b21f1fd9394605..d4a5938c1a27531cf992f9cad099ba445e7d093c 100644
--- a/models/AttMIL.py
+++ b/code/models/AttMIL.py
@@ -76,4 +76,4 @@ class AttMIL(nn.Module): #gated attention
         M = torch.mm(A, H)  # KxL
         logits = self.classifier(M)
        
-        return logits
\ No newline at end of file
+        return logits, A
\ No newline at end of file
diff --git a/code/models/DTFDMIL.py b/code/models/DTFDMIL.py
new file mode 100644
index 0000000000000000000000000000000000000000..272c1c7d642d4551baf37f778d64a5eefd0931b1
--- /dev/null
+++ b/code/models/DTFDMIL.py
@@ -0,0 +1,109 @@
+import os
+import logging
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+import pytorch_lightning as pl
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
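+
+# Building blocks of the DTFD-MIL pipeline (used by model_interface_dtfd.py): a gated
+# attention pooling module, a one-layer classifier, an optional residual MLP block, and a
+# dimensionality-reduction head that maps backbone features down to the MIL feature size.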
+
+class Attention_Gated(nn.Module):
+    def __init__(self, features=512, D=128, K=1):
+        super(Attention_Gated, self).__init__()
+
+        self.L = features
+        self.D = D
+        self.K = K
+
+        self.attention_V = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Tanh()
+        )
+
+        self.attention_U = nn.Sequential(
+            nn.Linear(self.L, self.D),
+            nn.Sigmoid()
+        )
+
+        self.attention_weights = nn.Linear(self.D, self.K)
+
+    def forward(self, x, isNorm=True):
+        ## x: N x L
+        # print(x.shape)
+        A_V = self.attention_V(x)  # NxD
+        A_U = self.attention_U(x)  # NxD
+        A = self.attention_weights(A_V * A_U) # NxK
+        A = torch.transpose(A, 1, 0)  # KxN
+
+        if isNorm:
+            A = F.softmax(A, dim=1)  # softmax over N
+
+        return A  ### K x N
+
+class Attention_with_Classifier(nn.Module):
+    def __init__(self, L=512, D=128, K=1, num_cls=2, droprate=0):
+        super(Attention_with_Classifier, self).__init__()
+        self.attention = Attention_Gated(L, D, K)
+        self.classifier = Classifier_1fc(L, num_cls, droprate)
+    def forward(self, x): ## x: N x L
+        AA = self.attention(x)  ## K x N
+        afeat = torch.mm(AA, x) ## K x L
+        pred = self.classifier(afeat) ## K x num_cls
+        return pred
+
+class Classifier_1fc(nn.Module):
+    def __init__(self, n_channels, n_classes, droprate=0.0):
+        super(Classifier_1fc, self).__init__()
+        self.fc = nn.Linear(n_channels, n_classes)
+        self.droprate = droprate
+        if self.droprate != 0.0:
+            self.dropout = torch.nn.Dropout(p=self.droprate)
+
+    def forward(self, x):
+
+        if self.droprate != 0.0:
+            x = self.dropout(x)
+        x = self.fc(x)
+        return x
+
+
+class residual_block(nn.Module):
+    def __init__(self, nChn=512):
+        super(residual_block, self).__init__()
+        self.block = nn.Sequential(
+                nn.Linear(nChn, nChn, bias=False),
+                nn.ReLU(inplace=True),
+                nn.Linear(nChn, nChn, bias=False),
+                nn.ReLU(inplace=True),
+            )
+    def forward(self, x):
+        tt = self.block(x)
+        x = x + tt
+        return x
+
+
+class DimReduction(nn.Module):
+    def __init__(self, n_channels, m_dim=512, numLayer_Res=0):
+        super(DimReduction, self).__init__()
+        self.fc1 = nn.Linear(n_channels, m_dim, bias=False)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.numRes = numLayer_Res
+
+        self.resBlocks = []
+        for ii in range(numLayer_Res):
+            self.resBlocks.append(residual_block(m_dim))
+        self.resBlocks = nn.Sequential(*self.resBlocks)
+
+    def forward(self, x):
+
+        x = self.fc1(x)
+        x = self.relu1(x)
+
+        if self.numRes > 0:
+            x = self.resBlocks(x)
+
+        return x
\ No newline at end of file
diff --git a/models/TransMIL.py b/code/models/TransMIL.py
similarity index 70%
rename from models/TransMIL.py
rename to code/models/TransMIL.py
index ca2a1fbe8240fd1983b246cbbfbdad187c48b214..c78599d34d7faf2cc54fa182f5fb11ef63b13ea0 100755
--- a/models/TransMIL.py
+++ b/code/models/TransMIL.py
@@ -17,13 +17,15 @@ class TransLayer(nn.Module):
             num_landmarks = dim//2,    # number of landmarks
             pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
             residual = True,         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
-            dropout=0.1
+            dropout=0.7 #0.1
         )
 
     def forward(self, x):
-        x = x + self.attn(self.norm(x))
+        out, attn = self.attn(self.norm(x), return_attn=True)
+        x = x + out
+        # x = x + self.attn(self.norm(x))
 
-        return x
+        return x, attn
 
 
 class PPEG(nn.Module):
@@ -44,8 +46,10 @@ class PPEG(nn.Module):
 
 
 class TransMIL(nn.Module):
-    def __init__(self, n_classes, in_features, out_features=384):
+    def __init__(self, n_classes):
         super(TransMIL, self).__init__()
+        in_features = 512
+        out_features=512
         self.pos_layer = PPEG(dim=out_features)
         self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
         # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
@@ -62,42 +66,63 @@ class TransMIL(nn.Module):
         h = x.float() #[B, n, 1024]
         h = self._fc1(h) #[B, n, 512]
         
-        #---->pad
+        # print('Feature Representation: ', h.shape)
+        #---->duplicate pad
         H = h.shape[1]
         _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
         add_length = _H * _W - H
         h = torch.cat([h, h[:,:add_length,:]],dim = 1) #[B, N, 512]
+        
 
         #---->cls_token
         B = h.shape[0]
         cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
         h = torch.cat((cls_tokens, h), dim=1)
 
+
         #---->Translayer x1
-        h = self.layer1(h) #[B, N, 512]
+        h, attn1 = self.layer1(h) #[B, N, 512]
+
+        # print('After first TransLayer: ', h.shape)
 
         #---->PPEG
         h = self.pos_layer(h, _H, _W) #[B, N, 512]
+        # print('After PPEG: ', h.shape)
         
         #---->Translayer x2
-        h = self.layer2(h) #[B, N, 512]
+        h, attn2 = self.layer2(h) #[B, N, 512]
 
+        # print('After second TransLayer: ', h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
         #---->cls_token
-        print(h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
-
-        # tokens = h
+        
         h = self.norm(h)[:,0]
 
         #---->predict
         logits = self._fc2(h) #[B, n_classes]
-        return logits
+        return logits, attn2
 
 if __name__ == "__main__":
     data = torch.randn((1, 6000, 512)).cuda()
-    model = TransMIL(n_classes=2, in_features=512).cuda()
+    model = TransMIL(n_classes=2).cuda()
     print(model.eval())
-    results_dict = model(data)
-    print(results_dict)
+    logits, attn = model(data)
+    cls_attention = attn[:,:, 0, :6000]
+    values, indices = torch.max(cls_attention, 1)
+    mean = values.mean()
+    zeros = torch.zeros(values.shape).cuda()
+    filtered = torch.where(values > mean, values, zeros)
+    
+    # filter = values > values.mean()
+    # filtered_values = values[filter]
+    # values = np.where(values>values.mean(), values, 0)
+
+    print(filtered.shape)
+
+
+    # values = [v if v > values.mean().item() else 0 for v in values]
+    # print(values)
+    # print(len(values))
+
     # logits = results_dict['logits']
     # Y_prob = results_dict['Y_prob']
     # Y_hat = results_dict['Y_hat']
diff --git a/code/models/TransformerMIL.py b/code/models/TransformerMIL.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea249ba371fcfa7113625f6459ceb73421c13182
--- /dev/null
+++ b/code/models/TransformerMIL.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from nystrom_attention import NystromAttention
+
+
+class TransLayer(nn.Module):
+
+    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
+        super().__init__()
+        self.norm = norm_layer(dim)
+        self.attn = NystromAttention(
+            dim = dim,
+            dim_head = dim//8,
+            heads = 8,
+            num_landmarks = dim//2,    # number of landmarks
+            pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
+            residual = True,         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
+            dropout=0.25 #0.1
+        )
+
+    def forward(self, x):
+        out, attn = self.attn(self.norm(x), return_attn=True)
+        x = x + out
+        # x = x + self.attn(self.norm(x))
+
+        return x, attn
+
+
+class PPEG(nn.Module):
+    def __init__(self, dim=512):
+        super(PPEG, self).__init__()
+        self.proj = nn.Conv2d(dim, dim, 7, 1, 7//2, groups=dim)
+        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5//2, groups=dim)
+        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3//2, groups=dim)
+
+    def forward(self, x, H, W):
+        B, _, C = x.shape
+        cls_token, feat_token = x[:, 0], x[:, 1:]
+        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
+        x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat)
+        x = x.flatten(2).transpose(1, 2)
+        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
+        return x
+
+
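+# Variant of TransMIL that keeps only the first TransLayer: the PPEG positional encoding
+# and the second TransLayer are instantiated here but disabled in forward(), so the
+# attention map returned to the caller comes from layer1.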
+class TransformerMIL(nn.Module):
+    def __init__(self, n_classes):
+        super(TransformerMIL, self).__init__()
+        in_features = 512
+        out_features = 512
+        self.pos_layer = PPEG(dim=out_features)
+        self._fc1 = nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
+        # self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
+        self.cls_token = nn.Parameter(torch.randn(1, 1, out_features))
+        self.n_classes = n_classes
+        self.layer1 = TransLayer(dim=out_features)
+        self.layer2 = TransLayer(dim=out_features)
+        self.norm = nn.LayerNorm(out_features)
+        self._fc2 = nn.Linear(out_features, self.n_classes)
+
+
+    def forward(self, x): #, **kwargs
+
+        h = x.float() #[B, n, 1024]
+        h = self._fc1(h) #[B, n, 512]
+        
+        # print('Feature Representation: ', h.shape)
+        #---->duplicate pad
+        H = h.shape[1]
+        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
+        add_length = _H * _W - H
+        h = torch.cat([h, h[:,:add_length,:]],dim = 1) #[B, N, 512]
+        
+
+        #---->cls_token
+        B = h.shape[0]
+        cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
+        h = torch.cat((cls_tokens, h), dim=1)
+
+
+        #---->Translayer x1
+        h, attn1 = self.layer1(h) #[B, N, 512]
+
+        # print('After first TransLayer: ', h.shape)
+
+        #---->PPEG
+        # h = self.pos_layer(h, _H, _W) #[B, N, 512]
+        # # print('After PPEG: ', h.shape)
+        
+        # #---->Translayer x2
+        # h, attn2 = self.layer2(h) #[B, N, 512]
+
+        # print('After second TransLayer: ', h.shape) #[1, 1025, 512] 1025 = cls_token + 1024
+        #---->cls_token
+        
+        h = self.norm(h)[:,0]
+
+        #---->predict
+        logits = self._fc2(h) #[B, n_classes]
+        return logits, attn1
+
+if __name__ == "__main__":
+    data = torch.randn((1, 6000, 512)).cuda()
+    model = TransformerMIL(n_classes=2).cuda()
+    print(model.eval())
+    logits, attn = model(data)
+    cls_attention = attn[:,:, 0, :6000]
+    values, indices = torch.max(cls_attention, 1)
+    mean = values.mean()
+    zeros = torch.zeros(values.shape).cuda()
+    filtered = torch.where(values > mean, values, zeros)
+    
+    # filter = values > values.mean()
+    # filtered_values = values[filter]
+    # values = np.where(values>values.mean(), values, 0)
+
+    print(filtered.shape)
+
+
+    # values = [v if v > values.mean().item() else 0 for v in values]
+    # print(values)
+    # print(len(values))
+
+    # logits = results_dict['logits']
+    # Y_prob = results_dict['Y_prob']
+    # Y_hat = results_dict['Y_hat']
+    # print(F.sigmoid(logits))
diff --git a/models/__init__.py b/code/models/__init__.py
similarity index 100%
rename from models/__init__.py
rename to code/models/__init__.py
diff --git a/models/__pycache__/AttMIL.cpython-39.pyc b/code/models/__pycache__/AttMIL.cpython-39.pyc
similarity index 79%
rename from models/__pycache__/AttMIL.cpython-39.pyc
rename to code/models/__pycache__/AttMIL.cpython-39.pyc
index 5ee3af4e0d05559e5bc7a4d8ed62481a3254c9f1..f0f6f23506b721b78a8172b4ed1b215c65ae45b0 100644
Binary files a/models/__pycache__/AttMIL.cpython-39.pyc and b/code/models/__pycache__/AttMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/DTFDMIL.cpython-39.pyc b/code/models/__pycache__/DTFDMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ca3b6870a8316ae4e8ac9dc503615612cb9bde2
Binary files /dev/null and b/code/models/__pycache__/DTFDMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransMIL.cpython-39.pyc b/code/models/__pycache__/TransMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38ed5a1f3fedfb29af2fc50c7909b312399d39d3
Binary files /dev/null and b/code/models/__pycache__/TransMIL.cpython-39.pyc differ
diff --git a/code/models/__pycache__/TransformerMIL.cpython-39.pyc b/code/models/__pycache__/TransformerMIL.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..999c95ee4caea337f6a77e6cc0eaad0fb3a6c90c
Binary files /dev/null and b/code/models/__pycache__/TransformerMIL.cpython-39.pyc differ
diff --git a/models/__pycache__/__init__.cpython-39.pyc b/code/models/__pycache__/__init__.cpython-39.pyc
similarity index 100%
rename from models/__pycache__/__init__.cpython-39.pyc
rename to code/models/__pycache__/__init__.cpython-39.pyc
diff --git a/code/models/__pycache__/model_interface.cpython-39.pyc b/code/models/__pycache__/model_interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..565cfe10390ef1f7b43c48995fdfeb7ea63773a2
Binary files /dev/null and b/code/models/__pycache__/model_interface.cpython-39.pyc differ
diff --git a/code/models/__pycache__/model_interface_dtfd.cpython-39.pyc b/code/models/__pycache__/model_interface_dtfd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01c221509c05637cfc3eaca68d9fbe8ec26c0416
Binary files /dev/null and b/code/models/__pycache__/model_interface_dtfd.cpython-39.pyc differ
diff --git a/models/__pycache__/resnet50.cpython-39.pyc b/code/models/__pycache__/resnet50.cpython-39.pyc
similarity index 100%
rename from models/__pycache__/resnet50.cpython-39.pyc
rename to code/models/__pycache__/resnet50.cpython-39.pyc
diff --git a/models/__pycache__/vision_transformer.cpython-39.pyc b/code/models/__pycache__/vision_transformer.cpython-39.pyc
similarity index 100%
rename from models/__pycache__/vision_transformer.cpython-39.pyc
rename to code/models/__pycache__/vision_transformer.cpython-39.pyc
diff --git a/models/model_interface.py b/code/models/model_interface.py
similarity index 70%
rename from models/model_interface.py
rename to code/models/model_interface.py
index b3a561eba25656b93f4e4663c88e796022a556d0..586d80c41c22c5e5f3b1f48a95e130fdb0911571 100755
--- a/models/model_interface.py
+++ b/code/models/model_interface.py
@@ -1,5 +1,6 @@
 import sys
 import numpy as np
+import re
 import inspect
 import importlib
 import random
@@ -30,9 +31,10 @@ from torch import optim as optim
 #---->
 import pytorch_lightning as pl
 from .vision_transformer import vit_small
+import torchvision
 from torchvision import models
 from torchvision.models import resnet
-from transformers import AutoFeatureExtractor, ViTModel
+from transformers import AutoFeatureExtractor, ViTModel, SwinModel
 
 from pytorch_grad_cam import GradCAM, EigenGradCAM
 from pytorch_grad_cam.utils.image import show_cam_on_image
@@ -52,12 +54,11 @@ class ModelInterface(pl.LightningModule):
         # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
         # self.loss = 
         # print(self.model)
-        
+        self.model_name = model.name
         
         # self.ecam = EigenGradCAM(model = self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform)
         self.optimizer = optimizer
         self.n_classes = model.n_classes
-        print(self.n_classes)
         self.save_path = kargs['log']
         if Path(self.save_path).parts[3] == 'tcmr':
             temp = list(Path(self.save_path).parts)
@@ -85,7 +86,7 @@ class ModelInterface(pl.LightningModule):
                                                                             num_classes = self.n_classes)])
                                                                             
         else : 
-            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+            self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
 
             metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
                                                                            average = 'micro'),
@@ -97,6 +98,7 @@ class ModelInterface(pl.LightningModule):
                                                      torchmetrics.Precision(average = 'macro',
                                                                             num_classes = 2)])
         self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
+        self.ROC = torchmetrics.ROC(num_classes=self.n_classes)
         # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
         self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
         self.valid_metrics = metrics.clone(prefix = 'val_')
@@ -112,39 +114,42 @@ class ModelInterface(pl.LightningModule):
             self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
             self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
         elif kargs['backbone'] == 'resnet18':
-            resnet18 = models.resnet18(pretrained=True)
-            modules = list(resnet18.children())[:-1]
-            # model_ft.fc = nn.Linear(512, out_features)
-
-            res18 = nn.Sequential(
-                *modules,
-            )
-            for param in res18.parameters():
+            self.model_ft = models.resnet18(pretrained=True)
+            # modules = list(resnet18.children())[:-1]
+            for param in self.model_ft.parameters():
                 param.requires_grad = False
-            self.model_ft = nn.Sequential(
-                res18,
-                nn.AdaptiveAvgPool2d(1),
-                View((-1, 512)),
-                nn.Linear(512, self.out_features),
-                nn.GELU(),
-            )
+            self.model_ft.fc = nn.Linear(512, self.out_features)
+
+            # res18 = nn.Sequential(
+            #     *modules,
+            # )
+            # for param in res18.parameters():
+            #     param.requires_grad = False
+            # self.model_ft = nn.Sequential(
+            #     res18,
+            #     nn.AdaptiveAvgPool2d(1),
+            #     View((-1, 512)),
+            #     nn.Linear(512, self.out_features),
+            #     nn.GELU(),
+            # )
         elif kargs['backbone'] == 'resnet50':
 
-            resnet50 = models.resnet50(pretrained=True)    
-            # model_ft.fc = nn.Linear(1024, out_features)
-            modules = list(resnet50.children())[:-3]
-            res50 = nn.Sequential(
-                *modules,     
-            )
-            for param in res50.parameters():
+            self.model_ft = models.resnet50(pretrained=True)    
+            for param in self.model_ft.parameters():
                 param.requires_grad = False
-            self.model_ft = nn.Sequential(
-                res50,
-                nn.AdaptiveAvgPool2d(1),
-                View((-1, 1024)),
-                nn.Linear(1024, self.out_features),
-                # nn.GELU()
-            )
+            self.model_ft.fc = nn.Linear(2048, self.out_features)
+            # modules = list(resnet50.children())[:-3]
+            # res50 = nn.Sequential(
+            #     *modules,     
+            # )
+            
+            # self.model_ft = nn.Sequential(
+            #     res50,
+            #     nn.AdaptiveAvgPool2d(1),
+            #     View((-1, 1024)),
+            #     nn.Linear(1024, self.out_features),
+            #     # nn.GELU()
+            # )
         elif kargs['backbone'] == 'efficientnet':
             efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
             for param in efficientnet.parameters():
@@ -164,30 +169,22 @@ class ModelInterface(pl.LightningModule):
                 nn.Conv2d(20, 50, kernel_size=5),
                 nn.ReLU(),
                 nn.MaxPool2d(2, stride=2),
-                View((-1, 1024)),
-                nn.Linear(1024, self.out_features),
+                View((-1, 53*53)),
+                nn.Linear(53*53, self.out_features),
                 nn.ReLU(),
             )
         # print(self.model_ft[0].features[-1])
         # print(self.model_ft)
-        if model.name == 'TransMIL':
-            target_layers = [self.model.layer2.norm] # 32x32
-            # target_layers = [self.model_ft[0].features[-1]] # 32x32
-            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
-            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
-        else:
-            target_layers = [self.model.attention_weights]
-            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
 
     def forward(self, x):
-        
+        # print(x.shape)
         feats = self.model_ft(x).unsqueeze(0)
         return self.model(feats)
 
     def step(self, input):
 
         input = input.squeeze(0).float()
-        logits = self(input) 
+        logits, _ = self(input) 
 
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim=1)
@@ -195,10 +192,14 @@ class ModelInterface(pl.LightningModule):
         return logits, Y_prob, Y_hat
 
     def training_step(self, batch, batch_idx):
-        #---->inference
-        
 
         input, label, _= batch
+
+        #random image dropout
+        # bag_size = 500
+        # bag_idxs = torch.randperm(input.squeeze(0).shape[0])[:bag_size]
+        # input = input.squeeze(0)[bag_idxs].unsqueeze(0)
+
         label = label.float()
         
         logits, Y_prob, Y_hat = self.step(input) 
@@ -217,24 +218,31 @@ class ModelInterface(pl.LightningModule):
             # Y = int(label[0])
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (int(Y_hat) == Y)
-        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True)
+        self.log('loss', loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1)
 
         if self.current_epoch % 10 == 0:
 
-            grid = torchvision.utils.make_grid(images)
+            # images = input.squeeze()[:10, :, :, :]
+            # for i in range(10):
+            img = input.squeeze(0)[:10, :, :, :]
+            img = ((img - torch.min(img)) / (torch.max(img) - torch.min(img))) * 255.0  # min-max normalize to [0, 255]
+            
+            # mg = img.cpu().numpy()
+            grid = torchvision.utils.make_grid(img, normalize=True, value_range=(0, 255), scale_each=True)
+            # grid = img.detach().cpu().numpy()
         # log input images 
-        # self.loggers[0].experiment.add_figure(f'{stage}/input', , self.current_epoch)
+            self.loggers[0].experiment.add_image(f'{self.current_epoch}/input', grid)
 
 
-        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': label} 
+        return {'loss': loss, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'label': Y} 
 
     def training_epoch_end(self, training_step_outputs):
         # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in training_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in training_step_outputs])
         # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
-        target = torch.cat([x['label'] for x in training_step_outputs])
-        target = torch.argmax(target, dim=1)
+        target = torch.stack([x['label'] for x in training_step_outputs])
+        # target = torch.argmax(target, dim=1)
         for c in range(self.n_classes):
             count = self.data[c]["count"]
             correct = self.data[c]["correct"]
@@ -248,9 +256,9 @@ class ModelInterface(pl.LightningModule):
         # print('max_probs: ', max_probs)
         # print('probs: ', probs)
         if self.current_epoch % 10 == 0:
-            self.log_confusion_matrix(probs, target, stage='train')
+            self.log_confusion_matrix(max_probs, target, stage='train')
 
-        self.log('Train/auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+        self.log('Train/auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
 
     def validation_step(self, batch, batch_idx):
 
@@ -266,21 +274,24 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
 
 
     def validation_epoch_end(self, val_step_outputs):
         logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
         max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
-        target = torch.cat([x['label'] for x in val_step_outputs])
-        target = torch.argmax(target, dim=1)
+        target = torch.stack([x['label'] for x in val_step_outputs])
+        
+        self.log_dict(self.valid_metrics(logits, target),
+                          on_epoch = True, logger = True)
+        
         #---->
         # logits = logits.long()
         # target = target.squeeze().long()
         # logits = logits.squeeze(0)
         if len(target.unique()) != 1:
-            self.log('val_auc', self.AUROC(probs, target.squeeze()), prog_bar=True, on_epoch=True, logger=True)
+            self.log('val_auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
         else:    
             self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
 
@@ -289,13 +300,16 @@ class ModelInterface(pl.LightningModule):
         self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
         
 
+        precision, recall, thresholds = self.PRC(probs, target)
+
+
+
         # print(max_probs.squeeze(0).shape)
         # print(target.shape)
-        self.log_dict(self.valid_metrics(max_probs.squeeze() , target),
-                          on_epoch = True, logger = True)
+        
 
         #----> log confusion matrix
-        self.log_confusion_matrix(probs, target, stage='val')
+        self.log_confusion_matrix(max_probs, target, stage='val')
         
 
         #---->acc log
@@ -306,7 +320,7 @@ class ModelInterface(pl.LightningModule):
                 acc = None
             else:
                 acc = float(correct) / count
-            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+            print('val class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
         self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
         
         #---->random, if shuffle data, change seed
@@ -315,34 +329,73 @@ class ModelInterface(pl.LightningModule):
             random.seed(self.count*50)
 
     def test_step(self, batch, batch_idx):
+
         torch.set_grad_enabled(True)
-        data, label, name = batch
+        data, label, (wsi_name, batch_names) = batch
+        wsi_name = wsi_name[0]
         label = label.float()
         # logits, Y_prob, Y_hat = self.step(data) 
         # print(data.shape)
         data = data.squeeze(0).float()
-        logits = self(data).detach() 
+        logits, attn = self(data)
+        attn = attn.detach()
+        logits = logits.detach()
 
         Y = torch.argmax(label)
         Y_hat = torch.argmax(logits, dim=1)
         Y_prob = F.softmax(logits, dim = 1)
         
-        #----> Get Topk tiles 
+        #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
+        if self.model_name == 'TransMIL':
+           
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
 
-        target = [ClassifierOutputTarget(Y)]
 
         data_ft = self.model_ft(data).unsqueeze(0).float()
-        # data_ft = self.model_ft(data).unsqueeze(0).float()
-        # print(data_ft.shape)
-        # print(target)
+        instance_count = data.size(0)
+        target = [ClassifierOutputTarget(Y)]
         grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
-        # grayscale_ecam = self.ecam(input_tensor=data_ft, targets=target)
+        grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :]
+
+        # attention_map = grayscale_cam[:, :, 1].squeeze()
+        # attention_map = F.relu(attention_map)
+        # mask = torch.zeros((instance_count, 3, 256, 256)).to(self.device)
+        # for i, v in enumerate(attention_map):
+        #     mask[i, :, :, :] = v
+
+        # mask = self.assemble(mask, batch_names)
+        # mask = (mask - mask.min())/(mask.max()-mask.min())
+        # mask = mask.cpu().numpy()
+        # wsi = self.assemble(data, batch_names)
+        # wsi = wsi.cpu().numpy()
+
+        # def show_cam_on_image(img, mask):
+        #     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+        #     heatmap = np.float32(heatmap) / 255
+        #     cam = heatmap*0.4 + np.float32(img)
+        #     cam = cam / np.max(cam)
+        #     return cam
+
+        # wsi = show_cam_on_image(wsi, mask)
+        # wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
+        
+        # img = Image.fromarray(wsi)
+        # img = img.convert('RGB')
+        
 
-        # print(grayscale_cam)
+        # output_path = self.save_path / str(Y.item())
+        # output_path.mkdir(parents=True, exist_ok=True)
+        # img.save(f'{output_path}/{wsi_name}.jpg')
 
-        summed = torch.mean(torch.Tensor(grayscale_cam), dim=2)
-        print(summed)
-        print(summed.shape)
+
+        #----> Get Topk Tiles and Topk Patients
+        summed = torch.mean(grayscale_cam, dim=2)
         topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
         topk_data = data[topk_indices].detach()
         
@@ -368,20 +421,33 @@ class ModelInterface(pl.LightningModule):
         self.data[Y]["count"] += 1
         self.data[Y]["correct"] += (Y_hat.item() == Y)
 
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name, 'topk_data': topk_data} #
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
         # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
 
     def test_epoch_end(self, output_results):
+        logits = torch.cat([x['logits'] for x in output_results], dim = 0)
         probs = torch.cat([x['Y_prob'] for x in output_results])
         max_probs = torch.stack([x['Y_hat'] for x in output_results])
         # target = torch.stack([x['label'] for x in output_results], dim = 0)
-        target = torch.cat([x['label'] for x in output_results])
-        target = torch.argmax(target, dim=1)
+        target = torch.stack([x['label'] for x in output_results])
+        # target = torch.argmax(target, dim=1)
         patients = [x['name'] for x in output_results]
         topk_tiles = [x['topk_data'] for x in output_results]
         #---->
-        auc = self.AUROC(probs, target.squeeze())
-        metrics = self.test_metrics(max_probs.squeeze() , target)
+        auc = self.AUROC(probs, target)
+        fpr, tpr, thresholds = self.ROC(probs, target)
+        fpr = fpr.cpu().numpy()
+        tpr = tpr.cpu().numpy()
+
+        plt.figure(1)
+        plt.plot(fpr, tpr)
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+        plt.title('ROC curve')
+        plt.savefig(f'{self.save_path}/roc.jpg')
+        # self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        metrics = self.test_metrics(logits , target)
 
 
         # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
@@ -390,16 +456,20 @@ class ModelInterface(pl.LightningModule):
         # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
 
         #---->get highest scoring patients for each class
-        test_path = Path(self.save_path) / 'most_predictive'
+        # test_path = Path(self.save_path) / 'most_predictive' 
+        
+        # Path.mkdir(output_path, exist_ok=True)
         topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
         for n in range(self.n_classes):
             print('class: ', n)
+            
             topk_patients = [patients[i[n]] for i in topk_indices]
             topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
             for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
                 print(p, x[n])
-                patient = p[0]
-                outpath = test_path / str(n) / patient 
+                patient = p
+                # outpath = test_path / str(n) / patient 
+                outpath = Path(self.save_path) / str(n) / patient
                 outpath.mkdir(parents=True, exist_ok=True)
                 for i in range(len(t)):
                     tile = t[i]
@@ -408,7 +478,7 @@ class ModelInterface(pl.LightningModule):
                     tile = tile.astype(np.uint8)
                     img = Image.fromarray(tile)
                     
-                    img.save(f'{test_path}/{n}/{patient}/{i}_gradcam.jpg')
+                    img.save(f'{outpath}/{i}.jpg')
 
             
             
@@ -448,7 +518,7 @@ class ModelInterface(pl.LightningModule):
         #[tp, fp, tn, fn, tp+fn]
 
 
-        self.log_confusion_matrix(probs, target, stage='test')
+        self.log_confusion_matrix(max_probs, target, stage='test')
         #---->
         result = pd.DataFrame([metrics])
         result.to_csv(Path(self.save_path) / f'test_result.csv', mode='a', header=not Path(self.save_path).exists())
@@ -462,8 +532,13 @@ class ModelInterface(pl.LightningModule):
         optimizer = create_optimizer(self.optimizer, self.model)
         return optimizer     
 
-    def reshape_transform(self, tensor, h=32, w=32):
-        result = tensor[:, 1:, :].reshape(tensor.size(0), h, w, tensor.size(2))
+    def reshape_transform(self, tensor):
+        # print(tensor.shape)
+        H = tensor.shape[1]
+        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
+        add_length = _H * _W - H
+        tensor = torch.cat([tensor, tensor[:,:add_length,:]],dim = 1)
+        result = tensor[:, :, :].reshape(tensor.size(0), _H, _W, tensor.size(2))
         result = result.transpose(2,3).transpose(1,2)
         # print(result.shape)
         return result
@@ -510,24 +585,42 @@ class ModelInterface(pl.LightningModule):
 
 
     def log_confusion_matrix(self, max_probs, target, stage):
-        confmat = self.confusion_matrix(max_probs.squeeze(), target)
+        confmat = self.confusion_matrix(max_probs, target)
         print(confmat)
         df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-        # plt.figure()
-        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-        # plt.close(fig_)
-        # plt.savefig(f'{self.save_path}/cm_e{self.current_epoch}')
+        fig_ = sns.heatmap(df_cm, annot=True, fmt='d', cmap='Spectral').get_figure()
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        else:
+            fig_.savefig(f'{self.loggers[0].log_dir}/cm_test.png', dpi=400)
+
+    def log_roc_curve(self, probs, target, stage):
+
+        fpr_list, tpr_list, thresholds = self.ROC(probs, target)
+
+        plt.figure(1)
+        if self.n_classes > 2:
+            for i in range(len(fpr_list)):
+                fpr = fpr_list[i].cpu().numpy()
+                tpr = tpr_list[i].cpu().numpy()
+                plt.plot(fpr, tpr, label=f'class_{i}')
+        else: 
+            print(fpr_list)
+            fpr = fpr_list.cpu().numpy()
+            tpr = tpr_list.cpu().numpy()
+            plt.plot(fpr, tpr)
         
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+        plt.title('ROC curve')
+        plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg')
 
         if stage == 'train':
-            # print(self.save_path)
-            # plt.savefig(f'{self.save_path}/cm_test')
-
-            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+            self.loggers[0].experiment.add_figure(f'{stage}/ROC', plt.gcf(), self.current_epoch)
         else:
-            fig_.savefig(f'{self.save_path}/cm_test.png', dpi=400)
-        # plt.close(fig_)
-        # self.logger[0].experiment.add_figure('Confusion matrix', fig_, self.current_epoch)
+            plt.savefig(f'{self.loggers[0].log_dir}/roc.jpg', dpi=400)
+
+    
 
 class View(nn.Module):
     def __init__(self, shape):
@@ -538,8 +631,6 @@ class View(nn.Module):
         '''
         Reshapes the input according to the shape saved in the view data structure.
         '''
-        # batch_size = input.size(0)
-        # shape = (batch_size, *self.shape)
         out = input.view(*self.shape)
         return out
 
diff --git a/code/models/model_interface_dtfd.py b/code/models/model_interface_dtfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b503ccdfcc92bc3f1a2bdeadd5a7a4f4955385c
--- /dev/null
+++ b/code/models/model_interface_dtfd.py
@@ -0,0 +1,707 @@
+import sys
+import numpy as np
+import re
+import inspect
+import importlib
+import random
+import pandas as pd
+import seaborn as sns
+from pathlib import Path
+from matplotlib import pyplot as plt
+import cv2
+from PIL import Image
+
+#---->
+from MyOptimizer import create_optimizer
+from MyLoss import create_loss
+from utils.utils import cross_entropy_torch
+from timm.loss import AsymmetricLossSingleLabel
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
+
+#---->
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchmetrics
+from torchmetrics.functional import stat_scores
+from torch import optim as optim
+# from sklearn.metrics import roc_curve, auc, roc_curve_score
+
+
+#---->
+import pytorch_lightning as pl
+from .vision_transformer import vit_small
+import torchvision
+from torchvision import models
+from torchvision.models import resnet
+from transformers import AutoFeatureExtractor, ViTModel
+
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+from captum.attr import LayerGradCam
+from models.DTFDMIL import Attention_Gated, Classifier_1fc, DimReduction, Attention_with_Classifier
+
+class ModelInterface_DTFD(pl.LightningModule):
+
+    #---->init
+    def __init__(self, model, loss, optimizer, **kargs):
+        super(ModelInterface_DTFD, self).__init__()
+        self.save_hyperparameters()
+        # self.load_model()
+        self.loss = create_loss(loss)
+        # self.asl = AsymmetricLossSingleLabel()
+        # self.loss = LabelSmoothingCrossEntropy(smoothing=0.1)
+        # self.loss = 
+        # print(self.model)
+        self.model_name = model.name
+        
+        self.optimizer = optimizer
+        self.n_classes = model.n_classes
+        self.save_path = kargs['log']
+        if Path(self.save_path).parts[3] == 'tcmr':
+            temp = list(Path(self.save_path).parts)
+            temp[3] = 'tcmr_viral'
+            self.save_path = '/'.join(temp)
+
+        #---->acc
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        #---->Metrics
+        if self.n_classes > 2: 
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+                                                                           average='micro'),
+                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(num_classes = self.n_classes,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = self.n_classes),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(average = 'macro',
+                                                                            num_classes = self.n_classes)])
+                                                                            
+        else : 
+            self.AUROC = torchmetrics.AUROC(num_classes=self.n_classes, average = 'weighted')
+
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
+                                                                           average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = 2),
+                                                     torchmetrics.F1Score(num_classes = 2,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = 2),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = 2)])
+        self.PRC = torchmetrics.PrecisionRecallCurve(num_classes = self.n_classes)
+        self.ROC = torchmetrics.ROC(num_classes=self.n_classes)
+        # self.pr_curve = torchmetrics.BinnedPrecisionRecallCurve(num_classes = self.n_classes, thresholds=10)
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)                                                                    
+        self.valid_metrics = metrics.clone(prefix = 'val_')
+        self.test_metrics = metrics.clone(prefix = 'test_')
+
+        #--->random
+        self.shuffle = kargs['data'].data_shuffle
+        self.count = 0
+        self.backbone = kargs['backbone']
+
+        self.out_features = 1024
+        if kargs['backbone'] == 'dino':
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+            self.model_ft = ViTModel.from_pretrained('facebook/dino-vitb16')
+        elif kargs['backbone'] == 'resnet18':
+            self.model_ft = models.resnet18(pretrained=True)
+            # modules = list(resnet18.children())[:-1]
+            for param in self.model_ft.parameters():
+                param.requires_grad = False
+            self.model_ft.fc = nn.Linear(512, self.out_features)
+
+            # res18 = nn.Sequential(
+            #     *modules,
+            # )
+            # for param in res18.parameters():
+            #     param.requires_grad = False
+            # self.model_ft = nn.Sequential(
+            #     res18,
+            #     nn.AdaptiveAvgPool2d(1),
+            #     View((-1, 512)),
+            #     nn.Linear(512, self.out_features),
+            #     nn.GELU(),
+            # )
+        elif kargs['backbone'] == 'resnet50':
+
+            self.model_ft = models.resnet50(pretrained=True)    
+            for param in self.model_ft.parameters():
+                param.requires_grad = False
+            self.model_ft.fc = nn.Linear(2048, self.out_features)
+
+            # modules = list(resnet50.children())[:-3]
+            # res50 = nn.Sequential(
+            #     *modules,     
+            # )
+            
+            # self.model_ft = nn.Sequential(
+            #     res50,
+            #     nn.AdaptiveAvgPool2d(1),
+            #     View((-1, 1024)),
+            #     nn.Linear(1024, self.out_features),
+            #     # nn.GELU()
+            # )
+        elif kargs['backbone'] == 'efficientnet':
+            efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_widese_b0', pretrained=True)
+            for param in efficientnet.parameters():
+                param.requires_grad = False
+            # efn = list(efficientnet.children())[:-1]
+            efficientnet.classifier.fc = nn.Linear(1280, self.out_features)
+            self.model_ft = nn.Sequential(
+                efficientnet,
+                nn.GELU(),
+            )
+        self.classifier = Classifier_1fc(n_channels=512, n_classes=self.n_classes)
+        self.attention = Attention_Gated(features=512)
+        self.dimreduction = DimReduction(n_channels=self.out_features, m_dim=512)
+        self.attCls = Attention_with_Classifier(L=512, num_cls=self.n_classes)
+        self.trainable_parameters = []
+        self.trainable_parameters += list(self.classifier.parameters())
+        self.trainable_parameters += list(self.attention.parameters())
+        self.trainable_parameters += list(self.dimreduction.parameters())
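+        # Note: attCls (the tier-2 classifier) is not added to this list; training_step
+        # receives optimizer_idx, which suggests it is updated by a separate optimizer,
+        # as in the original DTFD-MIL two-tier training scheme.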
+        
+        # print(self.model_ft[0].features[-1])
+        # print(self.model_ft)
+
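+    # DTFD-MIL two-tier forward pass: instance features are extracted once, the bag is split
+    # into up to 8 random pseudo-bags of `bag_size` instances, each pseudo-bag is scored by a
+    # tier-1 attention classifier, and the attention-pooled pseudo-bag features are then
+    # aggregated by the tier-2 attention classifier (attCls) into the slide-level prediction.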
+    def forward(self, x, bag_size=120):
+        # print(x.shape)
+        x = x.float()
+        max_pseudo_bags = x.squeeze(0).shape[0] // bag_size
+        max_pseudo_bags = min(8, max_pseudo_bags)
+        
+        slide_pseudo_feat = []
+        sub_predictions = []
+
+        input = x.squeeze(0)
+        features = self.model_ft(input) # [N, out_features] instance features
+        features = self.dimreduction(features)
+        randomized_idx = torch.randperm(features.shape[0])
+
+        
+        for n in range(max_pseudo_bags):
+
+            bag_idxs = randomized_idx[bag_size*n:bag_size*(n+1)] #torch.randperm(x.squeeze(0).shape[0])
+            bag_features = features.squeeze(0)[bag_idxs]
+            
+            t1AA = self.attention(bag_features).squeeze(0)
+            # print('features: ', features.shape)
+            # print('t1AA: ', t1AA.shape)
+            t1attFeats = torch.einsum('ns, n->ns', bag_features, t1AA)
+            # print('t1attFeats: ', t1attFeats.shape)
+            t1attFeats_tensor = torch.sum(t1attFeats, dim=0).unsqueeze(0)
+            # print('t1attFeats_tensor: ', t1attFeats_tensor.shape)
+            t1Predict = self.classifier(t1attFeats_tensor)
+            sub_predictions.append(t1Predict)
+
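+            # NOTE: get_cam_1d is neither defined nor imported in this file; it is assumed to
+            # come from the original DTFD-MIL utilities (per-instance CAM logits from the
+            # tier-1 classifier).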
+            patch_pred_logits = get_cam_1d(self.classifier, t1attFeats.unsqueeze(0)).squeeze(0)
+            patch_pred_logits = torch.transpose(patch_pred_logits, 0, 1)  ## n x cls
+            # patch_pred_softmax = torch.softmax(patch_pred_logits, dim=1)  ## n x cls
+
+            af_inst_feat = t1attFeats_tensor
+            slide_pseudo_feat.append(af_inst_feat)
+
+        slide_pseudo_feat = torch.cat(slide_pseudo_feat, dim=0)
+
+        ## optimization for first tier
+        sub_predictions = torch.cat(sub_predictions, dim=0)
+
+        ## optimization for second tier
+        slide_prediction = self.attCls(slide_pseudo_feat)
+
+        Y_hat = torch.argmax(slide_prediction, dim=1)
+        Y_prob = F.softmax(slide_prediction, dim=1)
+
+        
+        
+        return sub_predictions, slide_prediction, Y_prob, Y_hat
+
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+
+        input, label, _= batch
+
+
+        sub_predictions, slide_prediction, Y_prob, Y_hat = self(input)
+
+        # print(sub_predictions.size(0))
+        label = label.float()
+        sub_labels = [label] * sub_predictions.size(0)         
+        sub_labels = torch.cat(sub_labels, dim=0)
+        
+        
+        sub_loss = self.loss(sub_predictions, sub_labels)
+        slide_loss = self.loss(slide_prediction, label)
+
+        
+        Y = torch.argmax(label)
+            # Y = int(label[0])
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (int(Y_hat) == Y)
+        self.log('sub_loss', sub_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1)
+        self.log('slide_loss', slide_loss, prog_bar=True, on_epoch=True, logger=True, batch_size=1)
+
+        # if self.current_epoch % 10 == 0:
+
+        #     # images = input.squeeze()[:10, :, :, :]
+        #     # for i in range(10):
+        #     img = input.squeeze(0)[:10, :, :, :]
+        #     img = (img - torch.min(img)/(torch.max(img)-torch.min(img)))*255.0
+            
+        #     # mg = img.cpu().numpy()
+        #     grid = torchvision.utils.make_grid(img, normalize=True, value_range=(0, 255), scale_each=True)
+        #     # grid = img.detach().cpu().numpy()
+        # # log input images 
+        #     self.loggers[0].experiment.add_image(f'{self.current_epoch}/input', grid)
+
+
+        # print(Y_prob)
+        # print(Y_prob.shape)
+        
+        total_loss = (sub_loss + slide_loss)/2
+        # print(sub_predictions)
+        # print(sub_labels)
+        # sub_probs = sub_predictions
+        # sub_targets = torch.argmax(sub_labels, dim=1)
+        # if len(sub_targets.unique()) != 1:
+        #     self.log('Train/sub_auc', self.AUROC(sub_predictions, sub_targets), prog_bar=True, on_epoch=True, logger=True)
+
+        # else:    
+        #     self.log('Train/sub_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+
+        return {'loss': total_loss, 'Y_prob': Y_prob.detach(), 'Y_hat': Y_hat.detach(), 'label': Y} 
+
+    def training_epoch_end(self, training_step_outputs):
+        # print(training_step_outputs)
+        # for x in training_step_outputs:
+        #     print(x)
+            # print(x['Y_prob'])
+        # logits = torch.cat([x['logits'] for x in training_step_outputs], dim = 0)
+        probs = torch.cat([x[0]['Y_prob'] for x in training_step_outputs])
+        max_probs = torch.stack([x[0]['Y_hat'] for x in training_step_outputs])
+        # target = torch.stack([x['label'] for x in training_step_outputs], dim = 0)
+        target = torch.stack([x[0]['label'] for x in training_step_outputs])
+
+        # sub_probs = torch.cat(x[0]['sub_probs'] for x in training_step_outputs)
+        # sub_targets = torch.cat(x[0]['sub_targets'] for x in training_step_outputs)
+        # target = torch.argmax(target, dim=1)
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        # print('max_probs: ', max_probs)
+        # print('probs: ', probs)
+        if self.current_epoch % 10 == 0:
+            self.log_confusion_matrix(max_probs, target, stage='train')
+
+        self.log('Train/auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
+        # self.log('Train/sub_auc', self.AUROC(sub_probs, sub_targets), prog_bar=True, on_epoch=True, logger=True)
+
+    def validation_step(self, batch, batch_idx):
+
+        input, label, _= batch
+        label = label.float()
+
+        sub_predictions, slide_prediction, Y_prob, Y_hat = self(input)
+
+        # print(sub_predictions.size(0))
+        sub_labels = [label] * sub_predictions.size(0)         
+
+        
+        sub_labels = torch.stack(sub_labels).squeeze()
+        # print(sub_labels.shape)
+        # print(sub_predictions.shape)
+        
+        sub_loss = self.loss(sub_predictions, sub_labels)
+        slide_loss = self.loss(slide_prediction, label)
+
+        #---->acc log
+        Y = torch.argmax(label)
+
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'val_sub_loss': sub_loss, 'val_slide_loss': slide_loss, 'logits' : slide_prediction, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y}
+
+
+    def validation_epoch_end(self, val_step_outputs):
+        logits = torch.cat([x['logits'] for x in val_step_outputs], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in val_step_outputs])
+        max_probs = torch.stack([x['Y_hat'] for x in val_step_outputs])
+        target = torch.stack([x['label'] for x in val_step_outputs])
+        
+        self.log_dict(self.valid_metrics(logits, target),
+                          on_epoch = True, logger = True)
+        
+        #---->
+        # logits = logits.long()
+        # target = target.squeeze().long()
+        # logits = logits.squeeze(0)
+        if len(target.unique()) != 1:
+            self.log('val_auc', self.AUROC(probs, target), prog_bar=True, on_epoch=True, logger=True)
+        else:    
+            self.log('val_auc', 0.0, prog_bar=True, on_epoch=True, logger=True)
+
+        
+
+        self.log('val_loss', cross_entropy_torch(logits, target), prog_bar=True, on_epoch=True, logger=True)
+        
+
+        precision, recall, thresholds = self.PRC(probs, target)
+
+
+
+        # print(max_probs.squeeze(0).shape)
+        # print(target.shape)
+        
+
+        #----> log confusion matrix
+        self.log_confusion_matrix(max_probs, target, stage='val')
+        
+
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('val class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        
+        #---->random: if the data is shuffled, advance the seed each epoch
+        if self.shuffle:
+            self.count = self.count + 1
+            random.seed(self.count * 50)
+
+    def test_step(self, batch, batch_idx):
+
+        torch.set_grad_enabled(True)
+        data, label, (wsi_name, batch_names) = batch
+        wsi_name = wsi_name[0]
+        label = label.float()
+        # logits, Y_prob, Y_hat = self.step(data) 
+        # print(data.shape)
+        data = data.squeeze(0).float()
+        logits, attn = self(data)
+        attn = attn.detach()
+        logits = logits.detach()
+
+        Y = torch.argmax(label)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        
+        #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
+        if self.model_name == 'TransMIL':
+           
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
+
+        data_ft = self.model_ft(data).unsqueeze(0).float()
+        instance_count = data.size(0)
+        target = [ClassifierOutputTarget(Y)]
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
+        grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :]
+
+        # attention_map = grayscale_cam[:, :, 1].squeeze()
+        # attention_map = F.relu(attention_map)
+        # mask = torch.zeros((instance_count, 3, 256, 256)).to(self.device)
+        # for i, v in enumerate(attention_map):
+        #     mask[i, :, :, :] = v
+
+        # mask = self.assemble(mask, batch_names)
+        # mask = (mask - mask.min())/(mask.max()-mask.min())
+        # mask = mask.cpu().numpy()
+        # wsi = self.assemble(data, batch_names)
+        # wsi = wsi.cpu().numpy()
+
+        # def show_cam_on_image(img, mask):
+        #     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+        #     heatmap = np.float32(heatmap) / 255
+        #     cam = heatmap*0.4 + np.float32(img)
+        #     cam = cam / np.max(cam)
+        #     return cam
+
+        # wsi = show_cam_on_image(wsi, mask)
+        # wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
+        
+        # img = Image.fromarray(wsi)
+        # img = img.convert('RGB')
+        
+
+        # output_path = self.save_path / str(Y.item())
+        # output_path.mkdir(parents=True, exist_ok=True)
+        # img.save(f'{output_path}/{wsi_name}.jpg')
+
+
+        #----> Get Topk Tiles and Topk Patients
+        summed = torch.mean(grayscale_cam, dim=2)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_data = data[topk_indices].detach()
+        
+        # target_ft = 
+        # grayscale_cam_ft = self.cam_ft(input_tensor=data, )
+        # for i in range(data.shape[0]):
+            
+            # vis_img = data[i, :, :, :].cpu().numpy()
+            # vis_img = np.transpose(vis_img, (1,2,0))
+            # print(vis_img.shape)
+            # cam_img = grayscale_cam.squeeze(0)
+        # cam_img = self.reshape_transform(grayscale_cam)
+
+        # print(cam_img.shape)
+            
+            # visualization = show_cam_on_image(vis_img, cam_img, use_rgb=True)
+            # visualization = ((visualization/visualization.max())*255.0).astype(np.uint8)
+            # print(visualization)
+        # cv2.imwrite(f'{test_path}/{Y}/{name}/gradcam.jpg', cam_img)
+
+        #---->acc log
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
+        # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
+
+    def test_epoch_end(self, output_results):
+        logits = torch.cat([x['logits'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.stack([x['label'] for x in output_results])
+        # target = torch.argmax(target, dim=1)
+        patients = [x['name'] for x in output_results]
+        topk_tiles = [x['topk_data'] for x in output_results]
+        #---->
+        auc = self.AUROC(probs, target)
+        fpr, tpr, thresholds = self.ROC(probs, target)
+        fpr = fpr.cpu().numpy()
+        tpr = tpr.cpu().numpy()
+
+        plt.figure(1)
+        plt.plot(fpr, tpr)
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+        plt.title('ROC curve')
+        plt.savefig(f'{self.save_path}/roc.jpg')
+        # self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+        metrics = self.test_metrics(logits , target)
+
+
+        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+        metrics['test_auc'] = auc
+
+        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+        #---->get highest scoring patients for each class
+        # test_path = Path(self.save_path) / 'most_predictive' 
+        
+        # Path.mkdir(output_path, exist_ok=True)
+        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+        for n in range(self.n_classes):
+            print('class: ', n)
+            
+            topk_patients = [patients[i[n]] for i in topk_indices]
+            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+                print(p, x[n])
+                patient = p
+                # outpath = test_path / str(n) / patient 
+                outpath = Path(self.save_path) / str(n) / patient
+                outpath.mkdir(parents=True, exist_ok=True)
+                for i in range(len(t)):
+                    tile = t[i]
+                    tile = tile.cpu().numpy().transpose(1,2,0)
+                    tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+                    tile = tile.astype(np.uint8)
+                    img = Image.fromarray(tile)
+                    
+                    img.save(f'{outpath}/{i}.jpg')
+
+            
+            
+        #----->visualize top predictive tiles
+        
+        
+
+        
+                # img = img.squeeze(0).cpu().numpy()
+                # img = np.transpose(img, (1,2,0))
+                # # print(img)
+                # # print(grayscale_cam.shape)
+                # visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
+
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+        #---->plot auroc curve
+        # stats = stat_scores(probs, target, reduce='macro', num_classes=self.n_classes)
+        # fpr = {}
+        # tpr = {}
+        # for n in self.n_classes: 
+
+        # fpr, tpr, thresh = roc_curve(target.cpu().numpy(), probs.cpu().numpy())
+        #[tp, fp, tn, fn, tp+fn]
+
+
+        self.log_confusion_matrix(max_probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        csv_path = Path(self.save_path) / 'test_result.csv'
+        result.to_csv(csv_path, mode='a', header=not csv_path.exists())  # write the header only on first append
+
+        # with open(f'{self.save_path}/test_metrics.txt', 'a') as f:
+
+        #     f.write([metrics])
+
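+    # Two optimizers, one per tier: optimizer0 updates the parameters gathered in
+    # self.trainable_parameters (presumably the feature extractor, attention and tier-1 classifier),
+    # optimizer1 updates the tier-2 slide classifier attCls. Both use Adam with lr 1e-4,
+    # weight decay 1e-2 and a single LR step (gamma 0.2) at epoch 100.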
+    def configure_optimizers(self):
+        # optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=self.optimizer.lr*0.1)
+        optimizer0 = torch.optim.Adam(self.trainable_parameters, lr=1e-4, weight_decay=1e-2)
+        optimizer1 = torch.optim.Adam(self.attCls.parameters(), lr=1e-4, weight_decay=1e-2)
+
+        scheduler0 = torch.optim.lr_scheduler.MultiStepLR(optimizer0, [100], gamma=0.2)
+        scheduler1 = torch.optim.lr_scheduler.MultiStepLR(optimizer1, [100], gamma=0.2)
+        return [optimizer0, optimizer1], [scheduler0, scheduler1]     
+
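+    # GradCAM helper for TransMIL: pads the token sequence to the next square number by wrapping
+    # tokens around, then reshapes it into a 2D grid so CAMs can be computed per token.
+    # e.g. 512 tokens -> ceil(sqrt(512)) = 23, so 17 tokens are repeated and the grid is 23x23.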
+    def reshape_transform(self, tensor):
+        # print(tensor.shape)
+        H = tensor.shape[1]
+        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
+        add_length = _H * _W - H
+        tensor = torch.cat([tensor, tensor[:,:add_length,:]],dim = 1)
+        result = tensor[:, :, :].reshape(tensor.size(0), _H, _W, tensor.size(2))
+        result = result.transpose(2,3).transpose(1,2)
+        # print(result.shape)
+        return result
+
+    def load_model(self):
+        name = self.hparams.model.name
+        # Map the model file name to its class name, e.g. a file `trans_unet.py` is expected to
+        # define a class `TransUnet`: keep file names in snake_case and class names in CamelCase.
+        if '_' in name:
+            camel_name = ''.join([i.capitalize() for i in name.split('_')])
+        else:
+            camel_name = name
+        try:
+            Model = getattr(importlib.import_module(
+                f'models.{name}'), camel_name)
+        except (ImportError, AttributeError):
+            raise ValueError('Invalid Module File Name or Invalid Class Name!')
+        self.model = self.instancialize(Model)
+
+    def instancialize(self, Model, **other_args):
+        """ Instantiate a model using the corresponding parameters from the
+            self.hparams dictionary. Any keyword arguments passed in here
+            overwrite the corresponding values from self.hparams.
+        """
+        class_args = inspect.getfullargspec(Model.__init__).args[1:]
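+        # Example: with a config entry like Model.name = 'TransMIL' and Model.n_classes = 2, only the
+        # keys matching the target class's __init__ signature are forwarded; values passed through
+        # **other_args take precedence over the config.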
+        inkeys = self.hparams.model.keys()
+        args1 = {}
+        for arg in class_args:
+            if arg in inkeys:
+                args1[arg] = getattr(self.hparams.model, arg)
+        args1.update(other_args)
+        return Model(**args1)
+
+    def log_image(self, tensor, stage, name):
+
+        # normalize the CHW tensor to a uint8 HWC array and log it as an image
+        tile = tensor.cpu().numpy().transpose(1, 2, 0)
+        tile = (tile - tile.min()) / (tile.max() - tile.min()) * 255
+        tile = tile.astype(np.uint8)
+        self.loggers[0].experiment.add_image(f'{stage}/{name}', tile, self.current_epoch, dataformats='HWC')
+
+
+    def log_confusion_matrix(self, max_probs, target, stage):
+        confmat = self.confusion_matrix(max_probs, target)
+        print(confmat)
+        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+        fig_ = sns.heatmap(df_cm, annot=True, fmt='d', cmap='Spectral').get_figure()
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+        else:
+            fig_.savefig(f'{self.loggers[0].log_dir}/cm_test.png', dpi=400)
+
+    def log_roc_curve(self, probs, target, stage):
+
+        fpr_list, tpr_list, thresholds = self.ROC(probs, target)
+
+        fig, ax = plt.subplots()
+        if self.n_classes > 2:
+            for i in range(len(fpr_list)):
+                fpr = fpr_list[i].cpu().numpy()
+                tpr = tpr_list[i].cpu().numpy()
+                ax.plot(fpr, tpr, label=f'class_{i}')
+            ax.legend()
+        else:
+            fpr = fpr_list.cpu().numpy()
+            tpr = tpr_list.cpu().numpy()
+            ax.plot(fpr, tpr)
+
+        ax.set_xlabel('False positive rate')
+        ax.set_ylabel('True positive rate')
+        ax.set_title('ROC curve')
+
+        # log the figure to TensorBoard during training, save it to disk otherwise
+        if stage == 'train':
+            self.loggers[0].experiment.add_figure(f'{stage}/ROC', fig, self.current_epoch)
+        else:
+            fig.savefig(f'{self.loggers[0].log_dir}/roc.jpg', dpi=400)
+
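+# Projects instance features onto the weight matrix of the classifier's final linear layer to get
+# per-instance class scores (a 1D CAM). This assumes the second-to-last parameter of the classifier
+# is that linear layer's weight, which holds for a head ending in nn.Linear(feat_dim, n_classes).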
+def get_cam_1d(classifier, features):
+    tweight = list(classifier.parameters())[-2]
+    cam_maps = torch.einsum('bgf,cf->bcg', [features, tweight])
+    return cam_maps
+
+class View(nn.Module):
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, input):
+        '''
+        Reshapes the input according to the shape saved in the view data structure.
+        '''
+        out = input.view(*self.shape)
+        return out
+
diff --git a/models/resnet50.py b/code/models/resnet50.py
similarity index 100%
rename from models/resnet50.py
rename to code/models/resnet50.py
diff --git a/models/vision_transformer.py b/code/models/vision_transformer.py
similarity index 100%
rename from models/vision_transformer.py
rename to code/models/vision_transformer.py
diff --git a/code/test_visualize.py b/code/test_visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..82f6cb4c6dc9d14404aece06d93c729206b70ad6
--- /dev/null
+++ b/code/test_visualize.py
@@ -0,0 +1,437 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import glob
+import re
+
+from sklearn.model_selection import KFold
+from scipy.interpolate import griddata
+
+# from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
+from datasets import JPGMILDataloader, MILDataModule
+from models.model_interface import ModelInterface
+import models.vision_transformer as vits
+from utils.utils import *
+
+# pytorch_lightning
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+import torch
+
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+import cv2
+from PIL import Image
+from matplotlib import pyplot as plt
+import pandas as pd
+
+#--->Setting parameters
+def make_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--stage', default='test', type=str)
+    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
+    parser.add_argument('--version', default=0,type=int)
+    parser.add_argument('--epoch', default='0',type=str)
+    parser.add_argument('--gpus', default = 2, type=int)
+    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
+    parser.add_argument('--fold', default = 0)
+    parser.add_argument('--bag_size', default = 1024, type=int)
+
+    args = parser.parse_args()
+    return args
+
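+# Example invocation (paths and values are illustrative):
+#   python test_visualize.py --config DeepGraft/TransMIL.yaml --gpus 0 --epoch last --bag_size 1024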
+class custom_test_module(ModelInterface):
+
+    def test_step(self, batch, batch_idx):
+
+        torch.set_grad_enabled(True)
+        input_data, label, (wsi_name, batch_names) = batch
+        wsi_name = wsi_name[0]
+        label = label.float()
+        # logits, Y_prob, Y_hat = self.step(data) 
+        # print(data.shape)
+        input_data = input_data.squeeze(0).float()
+        logits, attn = self(input_data)
+        attn = attn.detach()
+        logits = logits.detach()
+
+        Y = torch.argmax(label)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim=1)
+        
+        #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
+        if self.model_name == 'TransMIL':
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
+        data_ft = self.model_ft(input_data).unsqueeze(0).float()
+        instance_count = input_data.size(0)
+        target = [ClassifierOutputTarget(Y)]
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
+        grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :] #.to(self.device)
+
+        #----------------------------------------------------
+        # Get Topk Tiles and Topk Patients
+        #----------------------------------------------------
+        summed = torch.mean(grayscale_cam, dim=2)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_data = input_data[topk_indices].detach()
+        
+        #----------------------------------------------------
+        # Log Correct/Count
+        #----------------------------------------------------
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        #----------------------------------------------------
+        # Tile Level Attention Maps
+        #----------------------------------------------------
+
+        self.save_attention_map(wsi_name, input_data, batch_names, grayscale_cam, target=Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
+        # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
+
+    def test_epoch_end(self, output_results):
+
+        logits = torch.cat([x['logits'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.stack([x['label'] for x in output_results])
+        # target = torch.argmax(target, dim=1)
+        patients = [x['name'] for x in output_results]
+        topk_tiles = [x['topk_data'] for x in output_results]
+        #---->
+
+        auc = self.AUROC(probs, target)
+        metrics = self.test_metrics(logits , target)
+
+
+        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+        metrics['test_auc'] = auc
+
+        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+        #---->get highest scoring patients for each class
+        # test_path = Path(self.save_path) / 'most_predictive' 
+        
+        # Path.mkdir(output_path, exist_ok=True)
+        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+        for n in range(self.n_classes):
+            print('class: ', n)
+            
+            topk_patients = [patients[i[n]] for i in topk_indices]
+            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+                print(p, x[n])
+                patient = p
+                # outpath = test_path / str(n) / patient 
+                outpath = Path(self.save_path) / str(n) / patient
+                outpath.mkdir(parents=True, exist_ok=True)
+                for i in range(len(t)):
+                    tile = t[i]
+                    tile = tile.cpu().numpy().transpose(1,2,0)
+                    tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+                    tile = tile.astype(np.uint8)
+                    img = Image.fromarray(tile)
+                    
+                    img.save(f'{outpath}/{i}.jpg')
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+
+        # self.log_roc_curve(probs, target, 'test')
+        self.log_confusion_matrix(max_probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        csv_path = Path(self.save_path) / 'test_result.csv'
+        result.to_csv(csv_path, mode='a', header=not csv_path.exists())  # write the header only on first append
+
+    def save_attention_map(self, wsi_name, data, batch_names, grayscale_cam, target):
+
+        def get_coords(batch_names):  # TODO: change this to use precise tile coordinates
+            coords = []
+            
+            for tile_name in batch_names: 
+                pos = re.findall(r'\((.*?)\)', tile_name[0])
+                x, y = pos[0].split('_')
+                coords.append((int(x),int(y)))
+            return coords
+        
+        coords = get_coords(batch_names)
+        # temp_data = data.cpu()
+        # print(data.shape)
+        wsi = self.assemble(data, coords).cpu().numpy()
+        # wsi = (wsi-wsi.min())/(wsi.max()-wsi.min())
+        # wsi = wsi
+
+        #--> Get interpolated mask from GradCam
+        W, H = wsi.shape[0], wsi.shape[1]
+        
+        
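+        # Build a coarse attention grid with one cell per tile coordinate, then upsample it
+        # bilinearly to the assembled WSI size before blending it onto the image.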
+        attention_map = grayscale_cam[:, :, 1].squeeze()
+        attention_map = F.relu(attention_map)
+        # print(attention_map)
+        input_h = 256
+        
+        mask = torch.ones(( int(W/input_h), int(H/input_h))).to(self.device)
+
+        for i, (x,y) in enumerate(coords):
+            mask[y][x] = attention_map[i]
+        mask = mask.unsqueeze(0).unsqueeze(0)
+        # mask = torch.stack([mask, mask, mask]).unsqueeze(0)
+
+        mask = F.interpolate(mask, (W,H), mode='bilinear')
+        mask = mask.squeeze(0).permute(1,2,0)
+
+        mask = (mask - mask.min())/(mask.max()-mask.min())
+        mask = mask.cpu().numpy()
+        
+        def show_cam_on_image(img, mask):
+            heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+            heatmap = np.float32(heatmap) / 255
+            cam = heatmap*0.4 + np.float32(img)
+            cam = cam / np.max(cam)
+            return cam
+
+        wsi_cam = show_cam_on_image(wsi, mask)
+        wsi_cam = ((wsi_cam-wsi_cam.min())/(wsi_cam.max()-wsi_cam.min()) * 255.0).astype(np.uint8)
+        
+        img = Image.fromarray(wsi_cam)
+        img = img.convert('RGB')
+        output_path = self.save_path / str(target.item())
+        output_path.mkdir(parents=True, exist_ok=True)
+        img.save(f'{output_path}/{wsi_name}_gradcam.jpg')
+
+        wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
+        img = Image.fromarray(wsi)
+        img = img.convert('RGB')
+        output_path = self.save_path / str(target.item())
+        output_path.mkdir(parents=True, exist_ok=True)
+        img.save(f'{output_path}/{wsi_name}.jpg')
+
+
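+    # Re-assembles the bag of tiles into one WSI-sized image: tiles are grouped into columns by
+    # their x coordinate, missing grid positions are filled with white (all-ones) tiles, and the
+    # columns are concatenated in x order.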
+    def assemble(self, tiles, coords): # with coordinates (x-y)
+        
+        def getPosition(img_name):
+            pos = re.findall(r'\((.*?)\)', img_name) #get strings in brackets (0-0)
+            a = int(pos[0].split('-')[0])
+            b = int(pos[0].split('-')[1])
+            return a, b
+
+        position_dict = {}
+        assembled = []
+        # for tile in self.predictions:
+        count = 0
+        # max_x = max(coords, key = lambda t: t[0])[0]
+        d = tiles[0,:,:,:].permute(1,2,0).shape
+        print(d)
+        white_value = 0
+        x_max = max([x[0] for x in coords])
+        y_max = max([x[1] for x in coords])
+
+        for i, (x,y) in enumerate(coords):
+
+            # name = n[0]
+            # image = tiles[i,:,:,:].permute(1,2,0)
+            
+            # d = image.shape
+            # print(image.min())
+            # print(image.max())
+            # if image.max() > white_value:
+            #     white_value = image.max()
+            # # print(image.shape)
+            
+            # tile_position = '-'.join(name.split('_')[-2:])
+            # x,y = getPosition(tile_position)
+            
+            # y_max = y if y > y_max else y_max
+            if x not in position_dict.keys():
+                position_dict[x] = [(y, i)]
+            else: position_dict[x].append((y, i))
+            # count += 1
+        print(position_dict.keys())
+        x_positions = sorted(position_dict.keys())
+        # print(x_max)
+        # complete_image = torch.zeros([x_max, y_max, 3])
+
+
+        for i in range(x_max+1):
+
+            # if i in position_dict.keys():
+            #     print(i)
+            column = [None]*(int(y_max+1))
+            # if len(d) == 3:
+            # empty_tile = torch.zeros(d).to(self.device)
+            # else:
+            # empty_tile = torch.ones(d)
+            empty_tile = torch.ones(d).to(self.device)
+            # print(i)
+            if i in position_dict.keys():
+                # print(i)
+                for j in position_dict[i]:
+                    print(j)
+                    sample_idx = j[1]
+                    print(sample_idx)
+                    # img = tiles[sample_idx, :, :, :].permute(1,2,0)
+                    column[int(j[0])] = tiles[sample_idx, :, :, :]
+            column = [empty_tile if i is None else i for i in column]
+            print(column)
+            # for c in column:
+            #     print(c.shape)
+            # column = torch.vstack(column)
+            # print(column)
+            column = torch.stack(column)
+            assembled.append((i, column))
+        
+        assembled = sorted(assembled, key=lambda x: x[0])
+
+        stack = [i[1] for i in assembled]
+        # print(stack)
+        img_compl = torch.hstack(stack)
+        print(img_compl)
+        return img_compl
+
+
+#---->main
+def main(cfg):
+
+    torch.set_num_threads(16)
+
+    #---->Initialize seed
+    pl.seed_everything(cfg.General.seed)
+
+    #---->load loggers
+    # cfg.load_loggers = load_loggers(cfg)
+
+    # print(cfg.load_loggers)
+    # save_path = Path(cfg.load_loggers[0].log_dir) 
+
+    #---->load callbacks
+    # cfg.callbacks = load_callbacks(cfg, save_path)
+
+    home = Path.cwd().parts[1]
+    cfg.Data.label_file = '/home/ylan/DeepGraft/training_tables/split_PAS_tcmr_viral.json'
+    cfg.Data.data_dir = '/home/ylan/data/DeepGraft/224_128um/'
+    DataInterface_dict = {
+                'data_root': cfg.Data.data_dir,
+                'label_path': cfg.Data.label_file,
+                'batch_size': cfg.Data.train_dataloader.batch_size,
+                'num_workers': cfg.Data.train_dataloader.num_workers,
+                'n_classes': cfg.Model.n_classes,
+                'backbone': cfg.Model.backbone,
+                'bag_size': cfg.Data.bag_size,
+                }
+
+    dm = MILDataModule(**DataInterface_dict)
+    
+
+    #---->Define Model
+    ModelInterface_dict = {'model': cfg.Model,
+                            'loss': cfg.Loss,
+                            'optimizer': cfg.Optimizer,
+                            'data': cfg.Data,
+                            'log': cfg.log_path,
+                            'backbone': cfg.Model.backbone,
+                            }
+    # model = ModelInterface(**ModelInterface_dict)
+    model = custom_test_module(**ModelInterface_dict)
+    # model.save_path = cfg.log_path
+    #---->Instantiate Trainer
+    
+    trainer = Trainer(
+        num_sanity_val_steps=0, 
+        # logger=cfg.load_loggers,
+        # callbacks=cfg.callbacks,
+        max_epochs= cfg.General.epochs,
+        min_epochs = 200,
+        gpus=cfg.General.gpus,
+        # gpus = [0,2],
+        # strategy='ddp',
+        amp_backend='native',
+        # amp_level=cfg.General.amp_level,  
+        precision=cfg.General.precision,  
+        accumulate_grad_batches=cfg.General.grad_acc,
+        # fast_dev_run = True,
+        
+        # deterministic=True,
+        check_val_every_n_epoch=10,
+    )
+
+    #---->train or test
+    log_path = Path(cfg.log_path) / 'checkpoints'
+    # print(log_path)
+    # log_path = Path('lightning_logs/2/checkpoints')
+    model_paths = list(log_path.glob('*.ckpt'))
+
+    if cfg.epoch == 'last':
+        model_paths = [str(model_path) for model_path in model_paths if 'last' in str(model_path)]
+    else:
+        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+
+    # model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
+    # model_paths = [f'lightning_logs/0/.ckpt']
+    # model_paths = [f'{log_path}/last.ckpt']
+    if not model_paths:
+        print('No checkpoints available!')
+    for path in model_paths:
+        # with open(f'{log_path}/test_metrics.txt', 'w') as f:
+        #     f.write(str(path) + '\n')
+        print(path)
+        new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
+        new_model.save_path = Path(cfg.log_path) / 'visualization'
+        trainer.test(model=new_model, datamodule=dm)
+    
+    # Top 5 scoring patches for patient
+    # GradCam
+
+
+if __name__ == '__main__':
+
+    args = make_parse()
+    cfg = read_yaml(args.config)
+
+    #---->update
+    cfg.config = args.config
+    cfg.General.gpus = [args.gpus]
+    cfg.General.server = args.stage
+    cfg.Data.fold = args.fold
+    cfg.Loss.base_loss = args.loss
+    cfg.Data.bag_size = args.bag_size
+    cfg.version = args.version
+    cfg.epoch = args.epoch
+
+    config_path = '/'.join(Path(cfg.config).parts[1:])
+    log_path = Path(cfg.General.log_path) / str(Path(config_path).parent)
+
+    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
+    log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
+    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
+    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
+    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' 
+    
+    
+
+    #---->main
+    main(cfg)
+ 
\ No newline at end of file
diff --git a/train.py b/code/train.py
similarity index 85%
rename from train.py
rename to code/train.py
index 5e30394352f771017d43e31cc8585ed4ef9aeb07..ddd37e443ffee175cdb37b0e2d2e6728728dc5ba 100644
--- a/train.py
+++ b/code/train.py
@@ -7,6 +7,7 @@ from sklearn.model_selection import KFold
 
 from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
 from models.model_interface import ModelInterface
+from models.model_interface_dtfd import ModelInterface_DTFD
+import shutil  # used below to snapshot the code directory into the log dir
 import models.vision_transformer as vits
 from utils.utils import *
 
@@ -80,11 +81,14 @@ def main(cfg):
                             'log': cfg.log_path,
                             'backbone': cfg.Model.backbone,
                             }
-    model = ModelInterface(**ModelInterface_dict)
+    if cfg.Model.name == 'DTFDMIL':
+        model = ModelInterface_DTFD(**ModelInterface_dict)
+    else:
+        model = ModelInterface(**ModelInterface_dict)
     
     #---->Instantiate Trainer
     trainer = Trainer(
-        num_sanity_val_steps=0, 
+        # num_sanity_val_steps=0, 
         logger=cfg.load_loggers,
         callbacks=cfg.callbacks,
         max_epochs= cfg.General.epochs,
@@ -96,11 +100,25 @@ def main(cfg):
         # amp_level=cfg.General.amp_level,  
         precision=cfg.General.precision,  
         accumulate_grad_batches=cfg.General.grad_acc,
+        gradient_clip_val=0.0,
         # fast_dev_run = True,
+        # limit_train_batches=1,
         
         # deterministic=True,
-        check_val_every_n_epoch=10,
+        check_val_every_n_epoch=5,
     )
+    # print(cfg.log_path)
+    # print(trainer.loggers[0].log_dir)
+    # print(trainer.loggers[1].log_dir)
+    #----> Copy Code
+    copy_path = Path(trainer.loggers[0].log_dir) / 'code'
+    copy_path.mkdir(parents=True, exist_ok=True)
+    copy_origin = '/' / Path('/'.join(cfg.log_path.parts[1:5])) / 'code'
+    # print(copy_origin)
+    shutil.copytree(copy_origin, copy_path, dirs_exist_ok=True)
+
+    
+    # print(trainer.loggers[0].log_dir)
 
     #---->train or test
     if cfg.resume_training:
@@ -148,7 +166,11 @@ if __name__ == '__main__':
     cfg.Data.bag_size = args.bag_size
     cfg.version = args.version
 
-    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
+    config_path = '/'.join(Path(cfg.config).parts[1:])
+    log_path = Path(cfg.General.log_path) / str(Path(config_path).parent)
+    # print(log_path)
+
+
     Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
     log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
     task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
diff --git a/code/train_loop.py b/code/train_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..d198abd443a49f53d3da58ba87c293dab4c4acd5
--- /dev/null
+++ b/code/train_loop.py
@@ -0,0 +1,496 @@
+from pytorch_lightning import LightningModule
+import torch
+import torch.nn.functional as F
+from torchmetrics.classification.accuracy import Accuracy
+import os.path as osp
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from pytorch_lightning import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+from datasets.data_interface import BaseKFoldDataModule
+from typing import Any, Dict, List, Optional, Type
+import torchmetrics
+import numpy as np
+from PIL import Image
+import cv2
+import re
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from test_visualize import custom_test_module
+from pytorch_grad_cam import GradCAM, EigenGradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pathlib import Path
+
+
+
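+# Evaluation-time companion to KFoldLoop (defined below): it loads the per-fold checkpoints and
+# evaluates them with the usual test metrics (accuracy, AUROC, confusion matrix). A simpler
+# prediction-averaging test_step/test_epoch_end pair is kept further down, commented out.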
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
+        super().__init__()
+        # Create `num_folds` models with their associated fold weights
+        self.n_classes = n_classes
+        self.log_path = log_path
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+        self.test_acc = Accuracy()
+        if self.n_classes > 2: 
+            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
+                                                                           average='micro'),
+                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
+                                                     torchmetrics.F1Score(num_classes = self.n_classes,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = self.n_classes),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = self.n_classes),
+                                                     torchmetrics.Specificity(average = 'macro',
+                                                                            num_classes = self.n_classes)])
+                                                                            
+        else : 
+            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
+            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
+                                                                           average = 'micro'),
+                                                     torchmetrics.CohenKappa(num_classes = 2),
+                                                     torchmetrics.F1Score(num_classes = 2,
+                                                                     average = 'macro'),
+                                                     torchmetrics.Recall(average = 'macro',
+                                                                         num_classes = 2),
+                                                     torchmetrics.Precision(average = 'macro',
+                                                                            num_classes = 2)])
+        self.test_metrics = metrics.clone(prefix = 'test_')
+        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
+    def test_step(self, batch, batch_idx):
+
+        torch.set_grad_enabled(True)
+        data, label, (wsi_name, batch_names) = batch
+        wsi_name = wsi_name[0]
+        label = label.float()
+        # logits, Y_prob, Y_hat = self.step(data) 
+        # print(data.shape)
+        data = data.squeeze(0).float()
+        logits, attn = self(data)
+        attn = attn.detach()
+        logits = logits.detach()
+
+        Y = torch.argmax(label)
+        Y_hat = torch.argmax(logits, dim=1)
+        Y_prob = F.softmax(logits, dim = 1)
+        
+        #----> Get GradCam maps, map each instance to attention value, assemble, overlay on original WSI 
+        if self.model_name == 'TransMIL':
+           
+            target_layers = [self.model.layer2.norm] # 32x32
+            # target_layers = [self.model_ft[0].features[-1]] # 32x32
+            self.cam = GradCAM(model=self.model, target_layers = target_layers, use_cuda=True, reshape_transform=self.reshape_transform) #, reshape_transform=self.reshape_transform
+            # self.cam_ft = GradCAM(model=self.model, target_layers = target_layers_ft, use_cuda=True) #, reshape_transform=self.reshape_transform
+        else:
+            target_layers = [self.model.attention_weights]
+            self.cam = GradCAM(model = self.model, target_layers = target_layers, use_cuda=True)
+
+
+        data_ft = self.model_ft(data).unsqueeze(0).float()
+        instance_count = data.size(0)
+        target = [ClassifierOutputTarget(Y)]
+        grayscale_cam = self.cam(input_tensor=data_ft, targets=target)
+        grayscale_cam = torch.Tensor(grayscale_cam)[:instance_count, :] #.to(self.device)
+
+        #----------------------------------------------------
+        # Get Topk Tiles and Topk Patients
+        #----------------------------------------------------
+        summed = torch.mean(grayscale_cam, dim=2)
+        topk_tiles, topk_indices = torch.topk(summed.squeeze(0), 5, dim=0)
+        topk_data = data[topk_indices].detach()
+        
+        #----------------------------------------------------
+        # Log Correct/Count
+        #----------------------------------------------------
+        Y = torch.argmax(label)
+        self.data[Y]["count"] += 1
+        self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+        #----------------------------------------------------
+        # Tile Level Attention Maps
+        #----------------------------------------------------
+
+        # self.save_attention_map(wsi_name, data, batch_names, grayscale_cam, Y)
+
+        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : Y, 'name': wsi_name, 'topk_data': topk_data} #
+        # return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label, 'name': name} #, 'topk_data': topk_data
+
+    def test_epoch_end(self, output_results):
+
+        logits = torch.cat([x['logits'] for x in output_results], dim = 0)
+        probs = torch.cat([x['Y_prob'] for x in output_results])
+        max_probs = torch.stack([x['Y_hat'] for x in output_results])
+        # target = torch.stack([x['label'] for x in output_results], dim = 0)
+        target = torch.stack([x['label'] for x in output_results])
+        # target = torch.argmax(target, dim=1)
+        patients = [x['name'] for x in output_results]
+        topk_tiles = [x['topk_data'] for x in output_results]
+        #---->
+        
+
+
+        auc = self.AUROC(probs, target)
+        metrics = self.test_metrics(logits , target)
+
+
+        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+        metrics['test_auc'] = auc
+
+        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+        #---->get highest scoring patients for each class
+        # test_path = Path(self.save_path) / 'most_predictive' 
+        
+        # Path.mkdir(output_path, exist_ok=True)
+        topk, topk_indices = torch.topk(probs.squeeze(0), 5, dim=0)
+        for n in range(self.n_classes):
+            print('class: ', n)
+            
+            topk_patients = [patients[i[n]] for i in topk_indices]
+            topk_patient_tiles = [topk_tiles[i[n]] for i in topk_indices]
+            for x, p, t in zip(topk, topk_patients, topk_patient_tiles):
+                print(p, x[n])
+                patient = p
+                # outpath = test_path / str(n) / patient 
+                outpath = Path(self.save_path) / str(n) / patient
+                outpath.mkdir(parents=True, exist_ok=True)
+                for i in range(len(t)):
+                    tile = t[i]
+                    tile = tile.cpu().numpy().transpose(1,2,0)
+                    tile = (tile - tile.min())/ (tile.max() - tile.min()) * 255
+                    tile = tile.astype(np.uint8)
+                    img = Image.fromarray(tile)
+                    
+                    img.save(f'{outpath}/{i}.jpg')
+
+        for keys, values in metrics.items():
+            print(f'{keys} = {values}')
+            metrics[keys] = values.cpu().numpy()
+        #---->acc log
+        for c in range(self.n_classes):
+            count = self.data[c]["count"]
+            correct = self.data[c]["correct"]
+            if count == 0: 
+                acc = None
+            else:
+                acc = float(correct) / count
+            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+
+        # self.log_roc_curve(probs, target, 'test')
+        self.log_confusion_matrix(max_probs, target, stage='test')
+        #---->
+        result = pd.DataFrame([metrics])
+        csv_path = Path(self.save_path) / 'test_result.csv'
+        result.to_csv(csv_path, mode='a', header=not csv_path.exists())  # write the header only on first append
+
+    def save_attention_map(self, wsi_name, data, batch_names, grayscale_cam, target):
+
+        def get_coords(batch_names):  # TODO: change this to use precise tile coordinates
+            coords = []
+            
+            for tile_name in batch_names: 
+                pos = re.findall(r'\((.*?)\)', tile_name[0])
+                x, y = pos[0].split('_')
+
+                coords.append((int(x),int(y)))
+            return coords
+
+        
+        coords = get_coords(batch_names)
+        # temp_data = data.cpu()
+        print(data.shape)
+        wsi = self.assemble(data, coords).cpu().numpy()
+        wsi = (wsi-wsi.min())/(wsi.max()-wsi.min())
+        # wsi = wsi
+
+        #--> Get interpolated mask from GradCam
+        W, H = wsi.shape[0], wsi.shape[1]
+        
+        
+        attention_map = grayscale_cam[:, :, 1].squeeze()
+        attention_map = F.relu(attention_map)
+        # print(attention_map)
+        input_h = 256
+        
+        mask = torch.ones(( int(W/input_h), int(H/input_h))).to(self.device)
+
+        for i, (x,y) in enumerate(coords):
+            mask[y][x] = attention_map[i]
+        mask = mask.unsqueeze(0).unsqueeze(0)
+        # mask = torch.stack([mask, mask, mask]).unsqueeze(0)
+
+        mask = F.interpolate(mask, (W,H), mode='bilinear')
+        mask = mask.squeeze(0).permute(1,2,0)
+
+        mask = (mask - mask.min())/(mask.max()-mask.min())
+        mask = mask.cpu().numpy()
+        
+        def show_cam_on_image(img, mask):
+            heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+            heatmap = np.float32(heatmap) / 255
+            cam = heatmap*0.4 + np.float32(img)
+            cam = cam / np.max(cam)
+            return cam
+
+        wsi_cam = show_cam_on_image(wsi, mask)
+        wsi_cam = ((wsi_cam-wsi_cam.min())/(wsi_cam.max()-wsi_cam.min()) * 255.0).astype(np.uint8)
+        
+        img = Image.fromarray(wsi_cam)
+        img = img.convert('RGB')
+        output_path = self.save_path / str(target)
+        output_path.mkdir(parents=True, exist_ok=True)
+        img.save(f'{output_path}/{wsi_name}_gradcam.jpg')
+
+        wsi = ((wsi-wsi.min())/(wsi.max()-wsi.min()) * 255.0).astype(np.uint8)
+        img = Image.fromarray(wsi)
+        img = img.convert('RGB')
+        output_path = self.save_path / str(target)
+        output_path.mkdir(parents=True, exist_ok=True)
+        img.save(f'{output_path}/{wsi_name}.jpg')
+
+
+    def assemble(self, tiles, coords): # with coordinates (x-y)
+        
+        def getPosition(img_name):
+            pos = re.findall(r'\((.*?)\)', img_name) #get strings in brackets (0-0)
+            a = int(pos[0].split('-')[0])
+            b = int(pos[0].split('-')[1])
+            return a, b
+
+        position_dict = {}
+        assembled = []
+        y_max = 0
+        # for tile in self.predictions:
+        count = 0
+        max_x = max(coords, key = lambda t: t[0])[0]
+        d = 0
+        white_value = 0
+
+        for i, (x,y) in enumerate(coords):
+
+            # name = n[0]
+            image = tiles[i,:,:,:].permute(1,2,0)
+            
+            d = image.shape
+            # print(image.min())
+            # print(image.max())
+            # if image.max() > white_value:
+            #     white_value = image.max()
+            # # print(image.shape)
+            
+            # tile_position = '-'.join(name.split('_')[-2:])
+            # x,y = getPosition(tile_position)
+            
+            y_max = y if y > y_max else y_max
+            if x not in position_dict.keys():
+                position_dict[x] = [(y, image)]
+            else: position_dict[x].append((y, image))
+            count += 1
+        
+
+        for i in range(max_x+1):
+            column = [None]*(int(y_max+1))
+            # if len(d) == 3:
+            #     empty_tile = torch.zeros(d).to(self.device)
+            # else:
+            empty_tile = torch.ones(d).to(self.device)
+            if i in position_dict.keys():
+                for j in position_dict[i]:
+                    sample = j[1]
+                    column[int(j[0])] = sample
+            column = [empty_tile if i is None else i for i in column]
+            # for c in column:
+            #     print(c.shape)
+            # column = torch.vstack(column)
+            column = torch.stack(column)
+            assembled.append((i, column))
+
+
+
+        # for key in position_dict.keys():
+        #     column = [None]*(int(y_max+1))
+        #     # print(key)
+        #     for i in position_dict[key]:
+        #         sample = i[1]
+        #         d = sample.shape
+        #         # print(d) # [3,256,256]
+        #         if len(d) == 3:
+        #             empty_tile = torch.ones(d).to(self.device)
+        #         else:
+        #             empty_tile = torch.zeros(d).to(self.device)
+        #         column[int(i[0])] = sample
+        #     column = [empty_tile if i is None else i for i in column]
+        #     column = torch.vstack(column)
+        #     assembled.append((key, column))
+        # print(len(assembled))
+        
+        assembled = sorted(assembled, key=lambda x: x[0])
+
+        stack = [i[1] for i in assembled]
+        # print(stack)
+        img_compl = torch.hstack(stack)
+        return img_compl
+    # def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+    #     # Compute the averaged predictions over the `num_folds` models.
+    #     # print(batch[0].shape)
+    #     input, label, _ = batch
+    #     label = label.float()
+    #     input = input.squeeze(0).float()
+
+            
+    #     logits = torch.stack([m(input) for m in self.models]).mean(0)
+    #     Y_hat = torch.argmax(logits, dim=1)
+    #     Y_prob = F.softmax(logits, dim = 1)
+    #     # #---->acc log
+    #     Y = torch.argmax(label)
+    #     self.data[Y]["count"] += 1
+    #     self.data[Y]["correct"] += (Y_hat.item() == Y)
+
+    #     return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
+
+    # def test_epoch_end(self, output_results):
+    #     probs = torch.cat([x['Y_prob'] for x in output_results])
+    #     max_probs = torch.stack([x['Y_hat'] for x in output_results])
+    #     # target = torch.stack([x['label'] for x in output_results], dim = 0)
+    #     target = torch.cat([x['label'] for x in output_results])
+    #     target = torch.argmax(target, dim=1)
+        
+    #     #---->
+    #     auc = self.AUROC(probs, target.squeeze())
+    #     metrics = self.test_metrics(max_probs.squeeze() , target)
+
+
+    #     # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
+    #     metrics['test_auc'] = auc
+
+    #     # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
+
+    #     # print(max_probs.squeeze(0).shape)
+    #     # print(target.shape)
+    #     # self.log_dict(metrics, logger = True)
+    #     for keys, values in metrics.items():
+    #         print(f'{keys} = {values}')
+    #         metrics[keys] = values.cpu().numpy()
+    #     #---->acc log
+    #     for c in range(self.n_classes):
+    #         count = self.data[c]["count"]
+    #         correct = self.data[c]["correct"]
+    #         if count == 0: 
+    #             acc = None
+    #         else:
+    #             acc = float(correct) / count
+    #         print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
+    #     self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
+
+    #     self.log_confusion_matrix(probs, target, stage='test')
+    #     #---->
+    #     result = pd.DataFrame([metrics])
+    #     result.to_csv(self.log_path / 'result.csv')
+
+
+    def log_confusion_matrix(self, max_probs, target, stage):
+            confmat = self.confusion_matrix(max_probs.squeeze(), target)
+            df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
+            plt.figure()
+            fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
+            # plt.close(fig_)
+            # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
+            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
+
+            if stage == 'test':
+                plt.savefig(f'{self.log_path}/cm_test')
+            plt.close(fig_)
+
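+# Custom Lightning Loop replacing the default FitLoop for k-fold cross-validation: each fold
+# re-selects the fold indices on the datamodule, runs fit and test, saves a model.<fold>.pt
+# checkpoint and restores the original weights; on_run_end then evaluates an EnsembleVotingModel
+# built from all fold checkpoints.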
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, export_path: str, **kargs) -> None:
+        super().__init__()
+        self.num_folds = num_folds
+        self.current_fold: int = 0
+        self.export_path = export_path
+        self.n_classes = kargs["model"].n_classes
+        self.log_path = kargs["log"]
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def connect(self, fit_loop: FitLoop) -> None:
+        self.fit_loop = fit_loop
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to the run a fitting and testing on the current hold."""
+        self._reset_fitting()  # requires to reset the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires to reset the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.strategy.setup_optimizers(self.trainer)
+        self.replace(fit_loop=FitLoop)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
+        voting_model.trainer = self.trainer
+        # Connect the new model to the trainer strategy and move it to the right device.
+        self.trainer.strategy.connect(voting_model)
+        self.trainer.strategy.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # needed so attribute lookups fall through to the wrapped fit loop.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
\ No newline at end of file
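For reference, the KFoldLoop added above follows the PyTorch Lightning k-fold loop example: it replaces the trainer's default FitLoop, re-runs fitting and testing once per fold, and finally evaluates an EnsembleVotingModel built from the per-fold checkpoints. A minimal sketch of how it might be attached to a Trainer is shown below; the config attributes, `model`, and `dm` objects are illustrative assumptions, not values taken from this diff.

    from pytorch_lightning import Trainer

    # model, dm and cfg are assumed to exist already (a LightningModule, a
    # BaseKFoldDataModule and a read_yaml config, respectively).
    trainer = Trainer(max_epochs=cfg.General.epochs, num_sanity_val_steps=0)
    internal_fit_loop = trainer.fit_loop                  # keep the default single-fit loop
    trainer.fit_loop = KFoldLoop(cfg.Data.nfold,          # outer loop: one pass per fold
                                 export_path=str(cfg.log_path),
                                 model=cfg.Model,         # must expose .n_classes
                                 log=cfg.log_path)
    trainer.fit_loop.connect(internal_fit_loop)           # inner loop trains a single fold
    trainer.fit(model, datamodule=dm)                     # dm must implement BaseKFoldDataModule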
diff --git a/utils/__init__.py b/code/utils/__init__.py
similarity index 100%
rename from utils/__init__.py
rename to code/utils/__init__.py
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/code/utils/__pycache__/__init__.cpython-39.pyc
similarity index 100%
rename from utils/__pycache__/__init__.cpython-39.pyc
rename to code/utils/__pycache__/__init__.cpython-39.pyc
diff --git a/code/utils/__pycache__/utils.cpython-39.pyc b/code/utils/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e86955d7bdf2df00f04cb9e61d452bc3f04da04
Binary files /dev/null and b/code/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/code/utils/extract_features.py b/code/utils/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec600e3cf5ee46cd48d57d55173c01ce8bbc7fd
--- /dev/null
+++ b/code/utils/extract_features.py
@@ -0,0 +1,38 @@
+## Choose Model and extract features from (augmented) image patches and save as .pt file
+
+# from datasets.custom_dataloader import HDF5MILDataloader
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from datasets import JPGMILDataloader
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+def extract_features(input_dir, label_path, output_dir, n_classes, out_features, batch_size):
+
+    dataset = JPGMILDataloader(input_dir, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
+
+    # frozen ResNet50 backbone; the final fc layer projects the 2048-dim features to out_features
+    model = models.resnet50(pretrained=True)
+    for param in model.parameters():
+        param.requires_grad = False
+    model.fc = nn.Linear(2048, out_features)
+
+    model = model.to(device)
+    model.eval()
+
+    # TODO: iterate over the dataset, run the model and save the per-patch features to output_dir
+
+
+if __name__ == '__main__':
+
+    # input_dir, output_dir
+    # initiate data loader
+    # use data loader to load and augment images 
+    # prediction from model
+    # choose save as bag or not (needed?)
+    
+    # features = torch.from_numpy(features)
+    # torch.save(features, output_path + '.pt') 
+
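The extraction step itself is still marked as a TODO above. A minimal sketch of what it could look like is given below, assuming the dataloader yields one (patches, label, slide_name) tuple per bag as other parts of this repository do; the helper name `save_bag_features`, the batch format and the per-slide output naming are illustrative assumptions.

    import os
    import torch
    from torch.utils.data import DataLoader

    @torch.no_grad()
    def save_bag_features(dataset, model, output_dir, device):
        # Run every bag through the frozen backbone and save one .pt file per slide.
        os.makedirs(output_dir, exist_ok=True)
        loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False)
        for bag, label, name in loader:              # assumed (patches, label, slide name)
            bag = bag.squeeze(0).float().to(device)  # (num_patches, C, H, W)
            features = model(bag)                    # (num_patches, out_features)
            torch.save(features.cpu(), os.path.join(output_dir, f'{name[0]}.pt'))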
diff --git a/utils/utils.py b/code/utils/utils.py
similarity index 95%
rename from utils/utils.py
rename to code/utils/utils.py
index 814cf1df0d1258fa19663208e91cd23ece06f6eb..7e992cffd810cd7cf54babc70477e8dc174d82e3 100755
--- a/utils/utils.py
+++ b/code/utils/utils.py
@@ -15,6 +15,7 @@ from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.states import TrainerFn
 from typing import Any, Dict, List, Optional, Type
+import shutil
 
 #---->read yaml
 import yaml
@@ -26,6 +27,7 @@ def read_yaml(fpath=None):
 
 #---->load Loggers
 from pytorch_lightning import loggers as pl_loggers
+
 def load_loggers(cfg):
 
     # log_path = cfg.General.log_path
@@ -40,9 +42,12 @@ def load_loggers(cfg):
         tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
                                                   # version = f'fold{cfg.Data.fold}'
                                                 log_graph = True, default_hp_metric = False)
+        # print(tb_logger.version)
+        version = tb_logger.version
         #---->CSV
-        csv_logger = pl_loggers.CSVLogger(cfg.log_path,
+        csv_logger = pl_loggers.CSVLogger(cfg.log_path, version = version
                                         ) # version = f'fold{cfg.Data.fold}', 
+        # print(csv_logger.version)
     else:  
         cfg.log_path = Path(cfg.log_path) / f'test'
         tb_logger = pl_loggers.TensorBoardLogger(cfg.log_path,
@@ -98,7 +103,7 @@ def load_callbacks(cfg, save_path):
         # save_path = Path(cfg.log_path) / 'lightning_logs' / f'version_{cfg.resume_version}' / last.ckpt
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_loss',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_loss:.4f}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 1,
@@ -106,7 +111,7 @@ def load_callbacks(cfg, save_path):
                                          save_weights_only = True))
         Mycallbacks.append(ModelCheckpoint(monitor = 'val_auc',
                                          dirpath = str(output_path),
-                                         filename = '{epoch:02d}-{val_auc:.4f}',
+                                         filename = '{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}',
                                          verbose = True,
                                          save_last = True,
                                          save_top_k = 1,
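The checkpoint filename templates above are resolved by Lightning from the logged metrics, so both callbacks now encode loss and AUC in the checkpoint name; this only works if the LightningModule logs `val_loss` and `val_auc` during validation. A small sketch of the AUC-monitoring callback, where the `mode='max'` choice is an assumption not visible in this diff:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Resolves to names like 'epoch=09-val_loss=0.4321-val_auc=0.8765.ckpt',
    # provided the module calls self.log('val_loss', ...) and self.log('val_auc', ...).
    auc_checkpoint = ModelCheckpoint(monitor='val_auc',
                                     mode='max',          # assumption: higher AUC is better
                                     filename='{epoch:02d}-{val_loss:.4f}-{val_auc:.4f}',
                                     save_top_k=1,
                                     save_last=True,
                                     save_weights_only=True)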
diff --git a/datasets/__init__.py b/datasets/__init__.py
deleted file mode 100644
index e752eeb66fddc402cda5441a62aa6fc082f47be4..0000000000000000000000000000000000000000
--- a/datasets/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-from .data_interface import DataInterface
\ No newline at end of file
diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 13006df55083dc070f46953e4505bfef1ac8b198..0000000000000000000000000000000000000000
Binary files a/datasets/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc b/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc
deleted file mode 100644
index 20aefffd1a14afe3c499f646e9eb674810896281..0000000000000000000000000000000000000000
Binary files a/datasets/__pycache__/custom_jpg_dataloader.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/TransMIL.cpython-39.pyc b/models/__pycache__/TransMIL.cpython-39.pyc
deleted file mode 100644
index a329e23346d376b150c7ba792b8cd3593a128d95..0000000000000000000000000000000000000000
Binary files a/models/__pycache__/TransMIL.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/model_interface.cpython-39.pyc b/models/__pycache__/model_interface.cpython-39.pyc
deleted file mode 100644
index e9d22d7ddccbb2a3d2ab59b98e066f9ac5d42285..0000000000000000000000000000000000000000
Binary files a/models/__pycache__/model_interface.cpython-39.pyc and /dev/null differ
diff --git a/test_visualize.py b/test_visualize.py
deleted file mode 100644
index 7ef56d3fec1a0c3d926428e95b051c13bdb1da96..0000000000000000000000000000000000000000
--- a/test_visualize.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import argparse
-from pathlib import Path
-import numpy as np
-import glob
-
-from sklearn.model_selection import KFold
-
-from datasets.data_interface import DataInterface, MILDataModule, CrossVal_MILDataModule
-from models.model_interface import ModelInterface
-import models.vision_transformer as vits
-from utils.utils import *
-
-# pytorch_lightning
-import pytorch_lightning as pl
-from pytorch_lightning import Trainer
-import torch
-from train_loop import KFoldLoop
-
-#--->Setting parameters
-def make_parse():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--stage', default='train', type=str)
-    parser.add_argument('--config', default='DeepGraft/TransMIL.yaml',type=str)
-    parser.add_argument('--version', default=0,type=int)
-    parser.add_argument('--epoch', default='0',type=str)
-    parser.add_argument('--gpus', default = 2, type=int)
-    parser.add_argument('--loss', default = 'CrossEntropyLoss', type=str)
-    parser.add_argument('--fold', default = 0)
-    parser.add_argument('--bag_size', default = 1024, type=int)
-
-    args = parser.parse_args()
-    return args
-
-#---->main
-def main(cfg):
-
-    torch.set_num_threads(16)
-
-    #---->Initialize seed
-    pl.seed_everything(cfg.General.seed)
-
-    #---->load loggers
-    # cfg.load_loggers = load_loggers(cfg)
-
-    # print(cfg.load_loggers)
-    # save_path = Path(cfg.load_loggers[0].log_dir) 
-
-    #---->load callbacks
-    # cfg.callbacks = load_callbacks(cfg, save_path)
-
-    home = Path.cwd().parts[1]
-    DataInterface_dict = {
-                'data_root': cfg.Data.data_dir,
-                'label_path': cfg.Data.label_file,
-                'batch_size': cfg.Data.train_dataloader.batch_size,
-                'num_workers': cfg.Data.train_dataloader.num_workers,
-                'n_classes': cfg.Model.n_classes,
-                'backbone': cfg.Model.backbone,
-                'bag_size': cfg.Data.bag_size,
-                }
-
-    dm = MILDataModule(**DataInterface_dict)
-    
-
-    #---->Define Model
-    ModelInterface_dict = {'model': cfg.Model,
-                            'loss': cfg.Loss,
-                            'optimizer': cfg.Optimizer,
-                            'data': cfg.Data,
-                            'log': cfg.log_path,
-                            'backbone': cfg.Model.backbone,
-                            }
-    model = ModelInterface(**ModelInterface_dict)
-    
-    #---->Instantiate Trainer
-    trainer = Trainer(
-        num_sanity_val_steps=0, 
-        # logger=cfg.load_loggers,
-        # callbacks=cfg.callbacks,
-        max_epochs= cfg.General.epochs,
-        min_epochs = 200,
-        gpus=cfg.General.gpus,
-        # gpus = [0,2],
-        # strategy='ddp',
-        amp_backend='native',
-        # amp_level=cfg.General.amp_level,  
-        precision=cfg.General.precision,  
-        accumulate_grad_batches=cfg.General.grad_acc,
-        # fast_dev_run = True,
-        
-        # deterministic=True,
-        check_val_every_n_epoch=10,
-    )
-
-    #---->train or test
-    log_path = cfg.log_path
-    # print(log_path)
-    # log_path = Path('lightning_logs/2/checkpoints')
-    model_paths = list(log_path.glob('*.ckpt'))
-
-    if cfg.epoch == 'last':
-        model_paths = [str(model_path) for model_path in model_paths if f'last' in str(model_path)]
-    else:
-        model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
-
-    # model_paths = [str(model_path) for model_path in model_paths if f'epoch={cfg.epoch}' in str(model_path)]
-    # model_paths = [f'lightning_logs/0/.ckpt']
-    # model_paths = [f'{log_path}/last.ckpt']
-    if not model_paths: 
-        print('No Checkpoints vailable!')
-    for path in model_paths:
-        # with open(f'{log_path}/test_metrics.txt', 'w') as f:
-        #     f.write(str(path) + '\n')
-        print(path)
-        new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
-        trainer.test(model=new_model, datamodule=dm)
-    
-    # Top 5 scoring patches for patient
-    # GradCam
-
-
-if __name__ == '__main__':
-
-    args = make_parse()
-    cfg = read_yaml(args.config)
-
-    #---->update
-    cfg.config = args.config
-    cfg.General.gpus = [args.gpus]
-    cfg.General.server = args.stage
-    cfg.Data.fold = args.fold
-    cfg.Loss.base_loss = args.loss
-    cfg.Data.bag_size = args.bag_size
-    cfg.version = args.version
-    cfg.epoch = args.epoch
-
-    log_path = Path(cfg.General.log_path) / str(Path(cfg.config).parent)
-    Path(cfg.General.log_path).mkdir(exist_ok=True, parents=True)
-    log_name =  f'_{cfg.Model.backbone}' + f'_{cfg.Loss.base_loss}'
-    task = '_'.join(Path(cfg.config).name[:-5].split('_')[2:])
-    # task = Path(cfg.config).name[:-5].split('_')[2:][0]
-    cfg.log_path = log_path / f'{cfg.Model.name}' / task / log_name / 'lightning_logs' / f'version_{cfg.version}' / 'checkpoints'
-    
-    
-
-    #---->main
-    main(cfg)
- 
\ No newline at end of file
diff --git a/train_loop.py b/train_loop.py
deleted file mode 100644
index f9236814589b30d30aab061aaf52599db5a130df..0000000000000000000000000000000000000000
--- a/train_loop.py
+++ /dev/null
@@ -1,212 +0,0 @@
-from pytorch_lightning import LightningModule
-import torch
-import torch.nn.functional as F
-from torchmetrics.classification.accuracy import Accuracy
-import os.path as osp
-from abc import ABC, abstractmethod
-from copy import deepcopy
-from pytorch_lightning import LightningModule
-from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.trainer.states import TrainerFn
-from datasets.data_interface import BaseKFoldDataModule
-from typing import Any, Dict, List, Optional, Type
-import torchmetrics
-import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-
-
-class EnsembleVotingModel(LightningModule):
-    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str], n_classes, log_path) -> None:
-        super().__init__()
-        # Create `num_folds` models with their associated fold weights
-        self.n_classes = n_classes
-        self.log_path = log_path
-        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
-        self.test_acc = Accuracy()
-        if self.n_classes > 2: 
-            self.AUROC = torchmetrics.AUROC(num_classes = self.n_classes, average = 'weighted')
-            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = self.n_classes,
-                                                                           average='micro'),
-                                                     torchmetrics.CohenKappa(num_classes = self.n_classes),
-                                                     torchmetrics.F1Score(num_classes = self.n_classes,
-                                                                     average = 'macro'),
-                                                     torchmetrics.Recall(average = 'macro',
-                                                                         num_classes = self.n_classes),
-                                                     torchmetrics.Precision(average = 'macro',
-                                                                            num_classes = self.n_classes),
-                                                     torchmetrics.Specificity(average = 'macro',
-                                                                            num_classes = self.n_classes)])
-                                                                            
-        else : 
-            self.AUROC = torchmetrics.AUROC(num_classes=2, average = 'weighted')
-            metrics = torchmetrics.MetricCollection([torchmetrics.Accuracy(num_classes = 2,
-                                                                           average = 'micro'),
-                                                     torchmetrics.CohenKappa(num_classes = 2),
-                                                     torchmetrics.F1Score(num_classes = 2,
-                                                                     average = 'macro'),
-                                                     torchmetrics.Recall(average = 'macro',
-                                                                         num_classes = 2),
-                                                     torchmetrics.Precision(average = 'macro',
-                                                                            num_classes = 2)])
-        self.test_metrics = metrics.clone(prefix = 'test_')
-        self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes = self.n_classes)
-
-    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
-        # Compute the averaged predictions over the `num_folds` models.
-        # print(batch[0].shape)
-        input, label, _ = batch
-        label = label.float()
-        input = input.squeeze(0).float()
-
-            
-        logits = torch.stack([m(input) for m in self.models]).mean(0)
-        Y_hat = torch.argmax(logits, dim=1)
-        Y_prob = F.softmax(logits, dim = 1)
-        # #---->acc log
-        Y = torch.argmax(label)
-        self.data[Y]["count"] += 1
-        self.data[Y]["correct"] += (Y_hat.item() == Y)
-
-        return {'logits' : logits, 'Y_prob' : Y_prob, 'Y_hat' : Y_hat, 'label' : label}
-
-    def test_epoch_end(self, output_results):
-        probs = torch.cat([x['Y_prob'] for x in output_results])
-        max_probs = torch.stack([x['Y_hat'] for x in output_results])
-        # target = torch.stack([x['label'] for x in output_results], dim = 0)
-        target = torch.cat([x['label'] for x in output_results])
-        target = torch.argmax(target, dim=1)
-        
-        #---->
-        auc = self.AUROC(probs, target.squeeze())
-        metrics = self.test_metrics(max_probs.squeeze() , target)
-
-
-        # metrics = self.test_metrics(max_probs.squeeze() , torch.argmax(target.squeeze(), dim=1))
-        metrics['test_auc'] = auc
-
-        # self.log('auc', auc, prog_bar=True, on_epoch=True, logger=True)
-
-        # print(max_probs.squeeze(0).shape)
-        # print(target.shape)
-        # self.log_dict(metrics, logger = True)
-        for keys, values in metrics.items():
-            print(f'{keys} = {values}')
-            metrics[keys] = values.cpu().numpy()
-        #---->acc log
-        for c in range(self.n_classes):
-            count = self.data[c]["count"]
-            correct = self.data[c]["correct"]
-            if count == 0: 
-                acc = None
-            else:
-                acc = float(correct) / count
-            print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))
-        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]
-
-        self.log_confusion_matrix(probs, target, stage='test')
-        #---->
-        result = pd.DataFrame([metrics])
-        result.to_csv(self.log_path / 'result.csv')
-
-
-    def log_confusion_matrix(self, max_probs, target, stage):
-            confmat = self.confusion_matrix(max_probs.squeeze(), target)
-            df_cm = pd.DataFrame(confmat.cpu().numpy(), index=range(self.n_classes), columns=range(self.n_classes))
-            plt.figure()
-            fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
-            # plt.close(fig_)
-            # plt.savefig(f'{self.log_path}/cm_e{self.current_epoch}')
-            self.loggers[0].experiment.add_figure(f'{stage}/Confusion matrix', fig_, self.current_epoch)
-
-            if stage == 'test':
-                plt.savefig(f'{self.log_path}/cm_test')
-            plt.close(fig_)
-
-class KFoldLoop(Loop):
-    def __init__(self, num_folds: int, export_path: str, **kargs) -> None:
-        super().__init__()
-        self.num_folds = num_folds
-        self.current_fold: int = 0
-        self.export_path = export_path
-        self.n_classes = kargs["model"].n_classes
-        self.log_path = kargs["log"]
-
-    @property
-    def done(self) -> bool:
-        return self.current_fold >= self.num_folds
-
-    def connect(self, fit_loop: FitLoop) -> None:
-        self.fit_loop = fit_loop
-
-    def reset(self) -> None:
-        """Nothing to reset in this loop."""
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
-        model."""
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_folds(self.num_folds)
-        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
-
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
-        print(f"STARTING FOLD {self.current_fold}")
-        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
-        self.trainer.datamodule.setup_fold_index(self.current_fold)
-
-    def advance(self, *args: Any, **kwargs: Any) -> None:
-        """Used to the run a fitting and testing on the current hold."""
-        self._reset_fitting()  # requires to reset the tracking stage.
-        self.fit_loop.run()
-
-        self._reset_testing()  # requires to reset the tracking stage.
-        self.trainer.test_loop.run()
-        self.current_fold += 1  # increment fold tracking number.
-
-    def on_advance_end(self) -> None:
-        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
-        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
-        # restore the original weights + optimizers and schedulers.
-        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
-        self.trainer.strategy.setup_optimizers(self.trainer)
-        self.replace(fit_loop=FitLoop)
-
-    def on_run_end(self) -> None:
-        """Used to compute the performance of the ensemble model on the test set."""
-        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
-        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths, n_classes=self.n_classes, log_path=self.log_path)
-        voting_model.trainer = self.trainer
-        # This requires to connect the new model and move it the right device.
-        self.trainer.strategy.connect(voting_model)
-        self.trainer.strategy.model_to_device()
-        self.trainer.test_loop.run()
-
-    def on_save_checkpoint(self) -> Dict[str, int]:
-        return {"current_fold": self.current_fold}
-
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
-        self.current_fold = state_dict["current_fold"]
-
-    def _reset_fitting(self) -> None:
-        self.trainer.reset_train_dataloader()
-        self.trainer.reset_val_dataloader()
-        self.trainer.state.fn = TrainerFn.FITTING
-        self.trainer.training = True
-
-    def _reset_testing(self) -> None:
-        self.trainer.reset_test_dataloader()
-        self.trainer.state.fn = TrainerFn.TESTING
-        self.trainer.testing = True
-
-    def __getattr__(self, key) -> Any:
-        # requires to be overridden as attributes of the wrapped loop are being accessed.
-        if key not in self.__dict__:
-            return getattr(self.fit_loop, key)
-        return self.__dict__[key]
-
-    def __setstate__(self, state: Dict[str, Any]) -> None:
-        self.__dict__.update(state)
\ No newline at end of file
diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc
deleted file mode 100644
index 33d3d005a578948dd3d5b58938a815f86c4e92f5..0000000000000000000000000000000000000000
Binary files a/utils/__pycache__/utils.cpython-39.pyc and /dev/null differ
diff --git a/utils/extract_features.py b/utils/extract_features.py
deleted file mode 100644
index fb040a175099d6e6612a7634a10eee07c4345cde..0000000000000000000000000000000000000000
--- a/utils/extract_features.py
+++ /dev/null
@@ -1,27 +0,0 @@
-## Choose Model and extract features from (augmented) image patches and save as .pt file
-
-from datasets.custom_dataloader import HDF5MILDataloader
-
-
-def extract_features(input_dir, output_dir, model, batch_size):
-
-
-    dataset = HDF5MILDataloader(data_root, label_path=label_path, mode='train', load_data=False, n_classes=n_classes)
-    if model == 'resnet50':
-        model = Resnet50_baseline(pretrained = True)       
-        model = model.to(device)
-        model.eval()
-
-
-
-if __name__ == '__main__':
-
-    # input_dir, output_dir
-    # initiate data loader
-    # use data loader to load and augment images 
-    # prediction from model
-    # choose save as bag or not (needed?)
-    
-    # features = torch.from_numpy(features)
-    # torch.save(features, output_path + '.pt') 
-